| import os |
| from typing import List, Tuple |
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Hugging Face Hub repository id of the checkpoint this app serves.
MODEL_ID = "Balab2021/qwen-workflow-planner-qwen2p5-lora"

# The Hub access token is read from the environment. Startup is aborted with
# an actionable message when the variable is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN environment variable is missing. Please add it in Space Settings → Secrets."
    )

print(f"Loading model: {MODEL_ID} ...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

# generate() is later given tokenizer.pad_token_id, so make sure a padding
# token exists by reusing the end-of-sequence token when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keyword options forwarded to from_pretrained() when loading the weights.
# NOTE(review): trust_remote_code=True permits running Python code shipped in
# the model repository — confirm this checkpoint actually requires it.
_MODEL_LOAD_OPTIONS = {
    "token": HF_TOKEN,
    "torch_dtype": "auto",
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_MODEL_LOAD_OPTIONS)
|
|
| def build_messages(history: List[Tuple[str, str]], user_message: str): |
| messages = [] |
| for user_text, assistant_text in history: |
| if user_text: |
| messages.append({"role": "user", "content": user_text}) |
| if assistant_text: |
| messages.append({"role": "assistant", "content": assistant_text}) |
| messages.append({"role": "user", "content": user_message}) |
| return messages |
|
|
def chat_fn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float | None = 0.7,
    max_new_tokens: int | None = 512,
) -> str:
    """Generate the assistant's reply to ``message`` given the chat history.

    Args:
        message: The new user input.
        history: Previous conversation turns, forwarded to ``build_messages``.
        temperature: Sampling temperature. ``None`` falls back to 0.7. Values
            at or below 0.01 disable sampling altogether.
        max_new_tokens: Upper bound on the number of generated tokens.
            ``None`` falls back to 512.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped.
    """
    # The UI may deliver None (no value supplied) or a float where an int is
    # expected, so normalise both settings before they reach generate().
    temperature = 0.7 if temperature is None else float(temperature)
    max_new_tokens = 512 if max_new_tokens is None else int(max_new_tokens)

    messages = build_messages(history, message)

    # Render the conversation as one prompt string that ends with the
    # assistant generation prompt.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "renormalize_logits": True,
    }
    if temperature > 0.01:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            top_k=40,
        )
    else:
        # Sampling-only settings are left out when not sampling: they would
        # have no effect, and a non-positive temperature is not a valid value.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
|
|
|
# Extra controls for the chat UI. They are listed in the order of chat_fn's
# trailing parameters: temperature first, then max_new_tokens.
_temperature_slider = gr.Slider(
    minimum=0.0,
    maximum=1.5,
    value=0.7,
    step=0.05,
    label="Temperature",
)
_max_tokens_slider = gr.Slider(
    minimum=32,
    maximum=2048,
    value=512,
    step=32,
    label="Max New Tokens",
)

# Collapsed-by-default container that holds the controls above.
_settings_accordion = gr.Accordion(label="Generation Settings", open=False)

# Example prompts offered in the UI; each row supplies only the message text.
_EXAMPLE_PROMPTS = [
    ["Plan a simple content creation workflow"],
    ["How to automate a daily report generation process?"],
]

# Chat application wired to chat_fn; example outputs are not pre-computed.
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[_temperature_slider, _max_tokens_slider],
    additional_inputs_accordion=_settings_accordion,
    title="Qwen Workflow Planner Chat",
    description=f"Model: {MODEL_ID}",
    examples=_EXAMPLE_PROMPTS,
    cache_examples=False,
)


# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()