"""Gradio chat Space for the Qwen workflow-planner model.

The tokenizer and model are loaded once at import time (authenticated with
the ``HF_TOKEN`` environment variable); ``chat_fn`` then answers each turn.
"""

import os
from typing import Any, List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Balab2021/qwen-workflow-planner-qwen2p5-lora"

# Fallbacks used when Gradio passes None for an additional input
# (e.g. while running examples).
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_NEW_TOKENS = 512

# Temperatures at or below this value switch generation to greedy decoding.
GREEDY_TEMPERATURE_THRESHOLD = 0.01

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is missing. Please add it in Space Settings → Secrets.")

# Load model at startup
print(f"Loading model: {MODEL_ID} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # pad_token_id is passed to generate() below; fall back to EOS when the
    # tokenizer defines no dedicated pad token.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)


def build_messages(history: List[Tuple[str, str] | dict[str, Any]], user_message: str):
    """Convert Gradio chat history plus the new user turn into chat messages.

    Both Gradio history formats are accepted:

    * "tuples": ``(user_text, assistant_text)`` pairs, where either side may
      be empty or ``None``;
    * "messages": ``{"role": ..., "content": ...}`` dicts.

    Empty turns and non-string content (e.g. file attachments) are skipped,
    because the chat template expects plain text content.

    Args:
        history: Previous turns in either format above.
        user_message: The new user message, appended last.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts suitable for
        ``tokenizer.apply_chat_template``.
    """
    messages = []

    def _add(role, content):
        # Only forward non-empty text; anything else would break the template.
        if isinstance(content, str) and content:
            messages.append({"role": role, "content": content})

    for turn in history:
        if isinstance(turn, dict):
            role = turn.get("role")
            if role in ("user", "assistant"):
                _add(role, turn.get("content"))
            continue
        user_text, assistant_text = turn
        _add("user", user_text)
        _add("assistant", assistant_text)

    messages.append({"role": "user", "content": user_message})
    return messages


def _generation_kwargs(temperature: float, max_new_tokens: int) -> dict[str, Any]:
    """Build the keyword arguments passed to ``model.generate``.

    Sampling-only parameters (temperature, top_p, top_k) are included only
    when sampling is enabled; with ``do_sample=False`` transformers ignores
    them and emits a warning.
    """
    do_sample = temperature > GREEDY_TEMPERATURE_THRESHOLD
    kwargs: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "renormalize_logits": True,
    }
    if do_sample:
        kwargs.update(temperature=temperature, top_p=0.9, top_k=40)
    return kwargs


def chat_fn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float | None = DEFAULT_TEMPERATURE,
    max_new_tokens: int | None = DEFAULT_MAX_NEW_TOKENS,
) -> str:
    """Generate the assistant's reply to ``message`` given prior ``history``.

    Args:
        message: The new user message.
        history: Previous turns, in any format accepted by ``build_messages``.
        temperature: Sampling temperature. Values at or below
            ``GREEDY_TEMPERATURE_THRESHOLD`` use greedy decoding. ``None``
            falls back to ``DEFAULT_TEMPERATURE``.
        max_new_tokens: Maximum number of tokens to generate. ``None`` falls
            back to ``DEFAULT_MAX_NEW_TOKENS``.

    Returns:
        The decoded reply with special tokens and surrounding whitespace
        stripped.
    """
    # Gradio may pass None for additional inputs (e.g. from example caching).
    if temperature is None:
        temperature = DEFAULT_TEMPERATURE
    if max_new_tokens is None:
        max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    # Coerce in case the Slider delivers a float; generate() expects an int.
    max_new_tokens = int(max_new_tokens)

    messages = build_messages(history, message)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The chat template already inserted the model's special tokens; don't
    # let the tokenizer add them a second time.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(**inputs, **_generation_kwargs(temperature, max_new_tokens))

    # generate() returns prompt + completion for decoder-only models; keep
    # only the newly generated tokens.
    generated_ids = output_ids[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()


demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(32, 2048, value=512, step=32, label="Max New Tokens"),
    ],
    additional_inputs_accordion=gr.Accordion("Generation Settings", open=False),
    title="Qwen Workflow Planner Chat",
    description=f"Model: {MODEL_ID}",
    examples=[
        ["Plan a simple content creation workflow"],
        ["How to automate a daily report generation process?"],
    ],
    # Recommended on HF Spaces with additional inputs
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()