fomext's picture
Upload app.py
0bb09b3 verified
Raw
History Blame Contribute Delete
10.3 kB
import json
import re
import time
import uuid
from typing import Optional
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hub repository that is loaded, and the short model name reported to API clients.
MODEL_ID = "Qwen/Qwen3-30B-A3B"
MODEL_ALIAS = "qwen3-30b-a3b"

# Tokenizer and model are created once, at import time, as module-level globals.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID} …")
# bfloat16 weights; device_map="auto" leaves device placement to the library.
# eval() returns the module itself, so it can be chained onto the load call.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
print("Model ready.")
# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------
@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Answer one chat turn for the Gradio ``ChatInterface``.

    Args:
        message: The newest user message.
        history: Earlier turns. Each entry may be a ``[user, assistant]`` pair
            or a ``{"role": ..., "content": ...}`` dict; entries whose content
            is not a string (e.g. ``None``) are skipped.

    Returns:
        The decoded reply: only the newly generated tokens, with special
        tokens removed.
    """
    hf_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Dict-style history: the role is stated explicitly.
            role, content = turn.get("role"), turn.get("content")
            if isinstance(role, str) and isinstance(content, str):
                hf_messages.append({"role": role, "content": content})
            continue
        # Pair-style history: the position inside the pair decides the role.
        # Non-text slots are dropped instead of being sent to the template.
        for role, content in zip(("user", "assistant"), turn):
            if isinstance(content, str):
                hf_messages.append({"role": role, "content": content})
    hf_messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        hf_messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The output sequence starts with the prompt tokens; keep only what follows.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Generate a completion for an already-rendered prompt string.

    Args:
        prompt: Full prompt text to tokenize and feed to the model.
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        The decoded text of the tokens that follow the prompt, with special
        tokens removed.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    # Drop the leading prompt tokens so only the continuation is decoded.
    completion_ids = sequences[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------
def list_models() -> str:
    """Return a JSON string describing the single model served by this app.

    The payload is an OpenAI-style list object with one entry whose ``id`` is
    ``MODEL_ALIAS`` and whose ``created`` field is the current Unix time.
    """
    entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [entry]})
# ---------------------------------------------------------------------------
# Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>)
# ---------------------------------------------------------------------------
def _parse_tool_calls(text: str):
    """
    Detect and extract Hermes-style tool calls from model output.

    Args:
        text: Decoded model output, possibly containing one or more
            ``<tool_call>{...}</tool_call>`` sections.

    Returns:
        ``(tool_calls, remaining_content)`` where ``tool_calls`` is a list in
        OpenAI format and ``remaining_content`` is the text left after removing
        the tool-call and ``<think>`` sections (``None`` if nothing is left),
        or ``(None, text)`` with ``text`` unchanged if no usable tool call is
        found.
    """
    pattern = r'<tool_call>(.*?)</tool_call>'
    tool_calls = []
    for payload in re.findall(pattern, text, re.DOTALL):
        try:
            call = json.loads(payload.strip())
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (list, string, number) is not a
        # tool call; skip it instead of failing on ``.get``.
        if not isinstance(call, dict):
            continue

        arguments = call.get("arguments", call.get("parameters", {}))
        # If the arguments are already a JSON-encoded object string, pass that
        # string through so it is not encoded a second time.
        already_encoded = False
        if isinstance(arguments, str):
            try:
                already_encoded = isinstance(json.loads(arguments), dict)
            except json.JSONDecodeError:
                already_encoded = False
        arguments_json = arguments if already_encoded else json.dumps(arguments)

        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": arguments_json,
            },
        })

    if not tool_calls:
        return None, text

    remaining = re.sub(pattern, '', text, flags=re.DOTALL)
    remaining = re.sub(r'<think>.*?</think>', '', remaining, flags=re.DOTALL).strip()
    return tool_calls, remaining or None
def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    enable_thinking: bool = False,
    tools_json: str = "",
) -> str:
    """
    Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    Args:
        messages_json: JSON array of {role, content} objects. ``tool`` messages
            keep their ``tool_call_id``; assistant messages keep ``tool_calls``.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        top_p: Nucleus-sampling probability mass.
        enable_thinking: Enable Qwen3 chain-of-thought reasoning (default False
            for tool use).
        tools_json: JSON array of OpenAI-format tool definitions (optional).

    Returns:
        A JSON string. On success it is a ``chat.completion`` object whose
        ``usage`` counts are the placeholder ``-1``; on any failure it is
        ``{"error": "<description>"}``.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps(
            {"error": "Invalid messages_json: expected a JSON array of messages"}
        )

    tools = None
    if tools_json and tools_json.strip():
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            # Report malformed tool definitions instead of silently running
            # the request without tools.
            return json.dumps({"error": f"Invalid tools_json: {e}"})

    try:
        hf_messages = []
        for m in messages:
            role = m["role"]
            if role == "tool":
                hf_messages.append({
                    "role": "tool",
                    "content": m.get("content", ""),
                    "tool_call_id": m.get("tool_call_id", ""),
                })
            elif role == "assistant" and m.get("tool_calls"):
                hf_messages.append({
                    "role": "assistant",
                    # Content may be null alongside tool calls; use "" instead.
                    "content": m.get("content") or "",
                    "tool_calls": m["tool_calls"],
                })
            else:
                hf_messages.append({"role": role, "content": m.get("content", "")})
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Prompt build failed: {e}"})

    try:
        gen_kwargs = dict(
            # Numeric UI inputs may deliver floats; the token limit must be an int.
            max_new_tokens=int(max_tokens),
            temperature=max(float(temperature), 0.01),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    except (TypeError, ValueError) as e:
        return json.dumps({"error": f"Invalid sampling parameter: {e}"})

    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # Strip think tags from plain responses when thinking was enabled
        content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"

    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a JSON health-check string with a fixed "ok" status and ``MODEL_ID``."""
    payload = dict(status="ok", model=MODEL_ID)
    return json.dumps(payload)
# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------
# In Gradio 4.x (pre-5), gr.api() does not exist. The correct way to expose
# named API endpoints is to wire invisible component events with api_name=.
# Each .click(fn=..., api_name="name") registers /gradio_api/call/<name>.
# ---------------------------------------------------------------------------
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    # Landing text shown above the chat box. The body is kept flush-left so the
    # Markdown source carries no leading indentation.
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API
Endpoints (via Gradio built-in API):
| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |
Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.
You can also chat directly below.
""")
    gr.ChatInterface(fn=gradio_chat)

    # ------------------------------------------------------------------
    # Hidden API wiring — invisible rows that register named endpoints.
    # Gradio 4.41 exposes /gradio_api/call/<api_name> for every event
    # that has api_name set, regardless of whether the components are
    # visible in the UI.
    # ------------------------------------------------------------------
    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")

        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        # Input order below must match the positional parameters of
        # chat_completions().
        _cc_messages = gr.Textbox(label="messages_json")
        _cc_max_tokens = gr.Number(label="max_tokens", value=512)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_thinking = gr.Checkbox(label="enable_thinking", value=False)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_thinking, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    # queue() hands back the same Blocks object, so launch() can be chained.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)