# qwen-api / app.py
# Author: Ngixdev — commit 15dcc64 (verified): "Switch to transformers with Qwen2.5-7B-Instruct"
import os
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model repo served by this Space.
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# Module-level cache, populated lazily by load_model() on the first request
# so the Space starts fast and the model is only loaded inside the GPU worker.
tokenizer = None
model = None
def load_model():
    """Lazily load and cache the tokenizer/model as module-level singletons.

    Returns:
        (tokenizer, model) tuple, loading both on the first call and reusing
        the cached objects on every subsequent call.
    """
    global tokenizer, model

    # Already loaded on a previous call — hand back the cached pair.
    if model is not None:
        return tokenizer, model

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    print("Model loaded!")

    return tokenizer, model
@spaces.GPU(duration=120)
def generate_response(
    message: str,
    history: list,
    system_prompt: str = "",
    temperature: float = 0.7,
    top_p: float = 0.8,
    top_k: int = 20,
    max_tokens: int = 1024,
) -> str:
    """Generate a single chat reply with Qwen2.5-7B-Instruct.

    Args:
        message: Latest user message.
        history: Prior turns as [user_msg, assistant_msg] pairs; either slot
            may be falsy and is then skipped.
        system_prompt: Optional system instruction prepended to the chat.
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        top_p: Nucleus-sampling cutoff (ignored when decoding greedily).
        top_k: Top-k sampling cutoff (ignored when decoding greedily).
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded assistant reply with special tokens stripped.
    """
    tok, mdl = load_model()

    # Rebuild the conversation in the chat-template message format.
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok([text], return_tensors="pt").to(mdl.device)

    # BUG FIX: the UI slider (and API) allow temperature == 0.0, but
    # generate(do_sample=True, temperature=0.0) makes transformers raise
    # ("`temperature` has to be strictly positive"). Fall back to greedy
    # decoding in that case; sampling behavior is unchanged for temperature > 0.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "pad_token_id": tok.eos_token_id,
    }
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    outputs = mdl.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return tok.decode(generated, skip_special_tokens=True)
@spaces.GPU(duration=120)
def api_generate(
    prompt: str,
    system_prompt: str = "",
    temperature: float = 0.7,
    top_p: float = 0.8,
    max_tokens: int = 1024,
) -> dict:
    """
    API endpoint for text generation.
    Args:
        prompt: The user prompt/question
        system_prompt: Optional system instruction
        temperature: Sampling temperature (0.0-2.0)
        top_p: Nucleus sampling parameter (0.0-1.0)
        max_tokens: Maximum tokens to generate
    Returns:
        Dictionary with 'response' key containing generated text
    """
    # Delegate to the chat generator with an empty history; any failure is
    # reported in the JSON payload rather than raised to the client.
    try:
        text = generate_response(
            message=prompt,
            history=[],
            system_prompt=system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
    except Exception as exc:
        return {"response": None, "status": "error", "error": str(exc)}
    return {"response": text, "status": "success"}
# --- Gradio UI: a "Chat" tab for interactive use and an "API" tab that
# exposes the same model as a JSON endpoint (`/api_generate`). ---
with gr.Blocks(title="Qwen API", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Qwen2.5-7B-Instruct API
        Powered by [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) on ZeroGPU
        """
    )

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(height=450, label="Conversation")
        with gr.Row():
            msg = gr.Textbox(label="Message", placeholder="Type here...", scale=4, lines=2)
            submit_btn = gr.Button("Send", variant="primary", scale=1)
        # Sampling controls, collapsed by default.
        with gr.Accordion("Settings", open=False):
            system_prompt = gr.Textbox(label="System Prompt", placeholder="Optional", lines=2)
            with gr.Row():
                temperature = gr.Slider(0.0, 2.0, 0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, 0.8, step=0.05, label="Top P")
            with gr.Row():
                top_k = gr.Slider(1, 100, 20, step=1, label="Top K")
                max_tokens = gr.Slider(64, 2048, 1024, step=64, label="Max Tokens")
        clear_btn = gr.Button("Clear")

        def user_submit(message, history):
            # Append the user's turn ([message, None] — assistant slot filled
            # later by bot_response) and clear the input textbox.
            return "", history + [[message, None]]

        def bot_response(history, system_prompt, temperature, top_p, top_k, max_tokens):
            # Generate the reply for the most recent user turn and write it
            # into the pending assistant slot of the last history entry.
            if not history:
                return history
            message = history[-1][0]
            # Pass only completed turns as history; the last entry is the
            # in-flight message being answered.
            history_without_last = history[:-1]
            response = generate_response(message, history_without_last, system_prompt, temperature, top_p, top_k, max_tokens)
            history[-1][1] = response
            return history

        # Enter key and Send button run the same two-step flow: echo the user
        # message into the chatbot first, then generate the model reply.
        msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, system_prompt, temperature, top_p, top_k, max_tokens], chatbot
        )
        submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, system_prompt, temperature, top_p, top_k, max_tokens], chatbot
        )
        # Reset the conversation to an empty history.
        clear_btn.click(lambda: [], None, chatbot)

    with gr.Tab("API"):
        gr.Markdown(
            """
            ## API Usage
            ```python
            from gradio_client import Client
            client = Client("Ngixdev/qwen-api")
            result = client.predict(
                prompt="Hello!",
                system_prompt="You are helpful.",
                temperature=0.7,
                top_p=0.8,
                max_tokens=1024,
                api_name="/api_generate"
            )
            print(result)
            ```
            """
        )
        with gr.Row():
            with gr.Column():
                api_prompt = gr.Textbox(label="Prompt", lines=3)
                api_system = gr.Textbox(label="System Prompt", lines=2)
                with gr.Row():
                    api_temp = gr.Slider(0.0, 2.0, 0.7, step=0.1, label="Temperature")
                    api_top_p = gr.Slider(0.0, 1.0, 0.8, step=0.05, label="Top P")
                api_max_tokens = gr.Slider(64, 2048, 1024, step=64, label="Max Tokens")
                api_submit = gr.Button("Generate", variant="primary")
            with gr.Column():
                api_output = gr.JSON(label="Response")
        # api_name="api_generate" registers this handler as the public
        # `/api_generate` endpoint used by gradio_client above.
        api_submit.click(
            api_generate,
            [api_prompt, api_system, api_temp, api_top_p, api_max_tokens],
            api_output,
            api_name="api_generate",
        )

demo.launch()