# agentflow / app.py
# Balab2021 - "Update app.py" (commit b34e624, verified) - 3.09 kB
import os
from typing import List, Tuple
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository that provides both the tokenizer and the model weights.
MODEL_ID = "Balab2021/qwen-workflow-planner-qwen2p5-lora"

# Access token read from the environment; startup aborts when it is absent or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is missing. Please add it in Space Settings → Secrets.")

# Load model at startup
print(f"Loading model: {MODEL_ID} ...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# generate() is later given pad_token_id, so reuse EOS when the tokenizer
# does not define a dedicated padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): trust_remote_code=True lets the repository run its own Python
# code while loading - confirm this model actually requires it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
def build_messages(history: List[Tuple[str, str]], user_message: str):
messages = []
for user_text, assistant_text in history:
if user_text:
messages.append({"role": "user", "content": user_text})
if assistant_text:
messages.append({"role": "assistant", "content": assistant_text})
messages.append({"role": "user", "content": user_message})
return messages
def chat_fn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float | None = 0.7,
    max_new_tokens: int | None = 512,
) -> str:
    """Generate the assistant reply for ``message`` given the earlier ``history``.

    Args:
        message: The latest user message.
        history: Earlier conversation turns, forwarded as-is to
            ``build_messages``.
        temperature: Sampling temperature; ``None`` falls back to 0.7. Values
            of 0.01 or below disable sampling (``do_sample=False``).
        max_new_tokens: Upper bound on the number of generated tokens;
            ``None`` falls back to 512.

    Returns:
        The newly generated text only (prompt tokens removed), decoded without
        special tokens and stripped of surrounding whitespace.
    """
    # The extra inputs can arrive as None (e.g. when an example row does not
    # supply them), so fall back to the same defaults as the signature.
    temperature = 0.7 if temperature is None else temperature
    # Coerce to int so a float-valued UI input is still a valid token count.
    max_new_tokens = 512 if max_new_tokens is None else int(max_new_tokens)

    messages = build_messages(history, message)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    do_sample = temperature > 0.01
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "renormalize_logits": True,
    }
    if do_sample:
        # temperature / top_p / top_k only take effect when sampling; passing
        # them for a non-sampling run (e.g. temperature=0.0) has no effect on
        # the output and only produces generation-config warnings.
        generation_kwargs.update(temperature=temperature, top_p=0.9, top_k=40)

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + completion; keep only the completion tokens.
    generated_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Chat UI: chat_fn receives the message, the history, and the two slider
# values (temperature, max new tokens) in that order.
demo = gr.ChatInterface(
    fn=chat_fn,
    title="Qwen Workflow Planner Chat",
    description=f"Model: {MODEL_ID}",
    additional_inputs=[
        gr.Slider(minimum=0.0, maximum=1.5, step=0.05, value=0.7, label="Temperature"),
        gr.Slider(minimum=32, maximum=2048, step=32, value=512, label="Max New Tokens"),
    ],
    additional_inputs_accordion=gr.Accordion(label="Generation Settings", open=False),
    # Each example row supplies only the chat message, not the slider values.
    examples=[
        [prompt]
        for prompt in (
            "Plan a simple content creation workflow",
            "How to automate a daily report generation process?",
        )
    ],
    # Example caching stays off (recommended on HF Spaces with additional inputs).
    cache_examples=False,
)

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()