# Agentworkflow / app.py
# Author: Balab2021 — commit bf3d8f9 ("Update app.py"), 2.31 kB
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "Balab2021/qwen-workflow-planner-qwen2p5-lora"

# On Hugging Face Spaces, values configured under Settings → Secrets are exposed
# to the app as environment variables; this app requires one named HF_TOKEN and
# refuses to start (at import time) when it is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError(
        "HF_TOKEN environment variable is missing. Please add it in Space Settings → Secrets."
    )
def build_messages(
    history: List[Union[Tuple[Optional[str], Optional[str]], Dict[str, Any]]],
    user_message: str,
) -> List[Dict[str, Any]]:
    """Convert Gradio chat history plus the new user turn into chat-template messages.

    Two history shapes are accepted:

    * "tuples" format: a sequence of ``(user_text, assistant_text)`` pairs
      (tuples or 2-item lists), where either side may be ``None`` or empty;
    * "messages" format: a sequence of ``{"role": ..., "content": ...}`` dicts.

    Accepting both keeps the app working if the chat interface delivers history
    as message dicts. Turns with empty content are skipped.

    Args:
        history: Prior conversation turns in either format above; ``None`` is
            treated as an empty history.
        user_message: The latest user message, appended as the final turn.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, ending with the new
        user message, suitable for ``tokenizer.apply_chat_template``.
    """
    messages: List[Dict[str, Any]] = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages format: copy only role/content, since extra keys
            # (e.g. "metadata") are not part of the chat-template schema.
            role = turn.get("role")
            content = turn.get("content")
            if role and content:
                messages.append({"role": role, "content": content})
            continue

        # Tuples format: one (user, assistant) exchange per entry.
        user_text, assistant_text = turn
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": user_message})
    return messages
# Load the tokenizer and model once, when the module is imported, so every
# chat request reuses the same objects. Both downloads authenticate with HF_TOKEN.
print("Loading model:", MODEL_ID, "...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)
# dtype and device placement are left to transformers ("auto" for both).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    token=HF_TOKEN,
)
def chat_fn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_new_tokens: int,
) -> str:
    """Generate one assistant reply for the Gradio chat UI.

    Uses the module-level ``tokenizer`` and ``model`` loaded at startup.

    Args:
        message: The newest user message.
        history: Prior turns as supplied by ``gr.ChatInterface``.
        temperature: Sampling temperature from the UI slider; a value ``<= 0``
            selects greedy (non-sampling) decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply, with special tokens removed and surrounding
        whitespace stripped.
    """
    messages = build_messages(history, message)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already contains all the special tokens the model
    # needs; letting the tokenizer add its own (e.g. a BOS) would duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        # Coerce defensively: the UI slider may deliver the value as a float.
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        # Greedy decoding: temperature is unused here, and passing it makes
        # transformers warn about a sampling-only flag.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    generated_ids = output_ids[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Chat UI wiring: the two sliders are passed to chat_fn after (message, history),
# in order, as its `temperature` and `max_new_tokens` arguments.
demo = gr.ChatInterface(
    title="Qwen Workflow Planner Chat",
    description=f"Model: {MODEL_ID}",
    fn=chat_fn,
    additional_inputs=[
        gr.Slider(minimum=0.0, maximum=1.5, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(minimum=32, maximum=2048, value=512, step=32, label="Max New Tokens"),
    ],
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()