| | import gradio as gr |
| | import spaces |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
# Hugging Face Hub repository id of the chat model served by this app.
MODEL_ID: str = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Tokenizer for MODEL_ID, loaded eagerly when this module is imported.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Shared model instance; stays None until load_model() populates it.
model = None
| |
|
| |
|
def load_model():
    """Return the shared causal-LM instance, creating it on first use.

    The instance is cached in the module-level ``model`` global, so later
    calls within the same process return the already-loaded object. Weights
    are requested in float16 and placed with ``device_map="auto"``.

    NOTE(review): on ZeroGPU the ``@spaces.GPU``-decorated caller may execute
    in a separate worker process, in which case this cache would not survive
    between requests and the weights would be reloaded each time. Confirm;
    the ZeroGPU documentation loads models at module import time instead.

    Returns:
        The cached ``AutoModelForCausalLM`` instance.
    """
    global model
    if model is not None:
        # Already loaded earlier in this process: reuse it.
        return model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    return model
| |
|
| |
|
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both history shapes ``gr.ChatInterface`` can supply:

    * "tuples": a sequence of ``[user_text, assistant_text]`` pairs, where
      either side may be ``None``/empty (e.g. a pending bot reply);
    * "messages": a sequence of ``{"role": ..., "content": ...}`` dicts.

    Turns with empty content are skipped. In the dict form, only ``user`` and
    ``assistant`` turns with plain-string content are kept, because other
    roles or non-text content cannot be rendered by the text chat template.

    Args:
        history: The history object passed by Gradio; ``None`` is treated as
            an empty history.

    Returns:
        A new list of ``{"role", "content"}`` dicts in conversation order.
    """
    converted = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                converted.append({"role": role, "content": content})
            continue
        # Tuple/list pair form. Unpacking a dict here would silently yield its
        # keys ("role", "content"), which is why dicts are handled above.
        user_msg, assistant_msg = entry
        if user_msg:
            converted.append({"role": "user", "content": user_msg})
        if assistant_msg:
            converted.append({"role": "assistant", "content": assistant_msg})
    return converted


@spaces.GPU(duration=120)
def generate_response(message, history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat UI.

    Args:
        message: The new user message.
        history: Prior turns as supplied by ``gr.ChatInterface``, either as
            ``[user, assistant]`` pairs or as role/content dicts.
        system_message: System prompt. It is omitted from the conversation
            when empty/``None``, so the chat template's own handling of a
            missing system turn applies instead of an empty system turn.
        max_tokens: Upper bound on newly generated tokens (cast to ``int``).
        temperature: Sampling temperature (cast to ``float``).
        top_p: Nucleus-sampling probability mass (cast to ``float``).

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens stripped.
    """
    loaded_model = load_model()

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(loaded_model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = loaded_model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The generated sequence starts with the prompt tokens; decode only the
    # continuation that follows them.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
| |
|
| |
|
# Extra input controls for the chat UI. gr.ChatInterface passes their values to
# generate_response after (message, history), in this order:
# system_message, max_tokens, temperature, top_p.
_chat_controls = [
    gr.Textbox(
        value="You are Qwen, a helpful coding assistant. You excel at writing clean, efficient code and explaining programming concepts clearly.",
        label="System message",
        lines=2,
    ),
    gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max tokens"),
    gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
]

# Starter prompts; each example is a one-element list holding only the message.
_example_prompts = [
    [prompt]
    for prompt in (
        "Hello! What programming languages are you best at?",
        "Write a Python function to check if a string is a palindrome",
        "Explain the difference between async/await and promises in JavaScript",
        "Help me optimize this SQL query: SELECT * FROM users WHERE name LIKE '%john%'",
    )
]

demo = gr.ChatInterface(
    generate_response,
    title="Qwen2.5 Coder 7B",
    description="A coding assistant powered by Qwen2.5-Coder-7B-Instruct on ZeroGPU",
    additional_inputs=_chat_controls,
    examples=_example_prompts,
)

if __name__ == "__main__":
    # Launch the web server only when executed as a script, not on import.
    demo.launch()
| |
|