import json
import re
import time
import uuid
from typing import Optional
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repository that both the tokenizer and the weights are loaded from.
MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
# Short identifier reported to clients (model listing, completion responses, page title).
MODEL_ALIAS = "qwen3-coder-30b-a3b-instruct-fp8"
# The tokenizer and model are loaded once, at import time, and are shared by
# every request handler defined below.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loading model {MODEL_ID} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # let transformers pick the dtype instead of hard-coding one
    device_map="auto",  # let the library decide device placement
)
# Put the module in evaluation mode; this app only runs generation.
model.eval()
print("Model ready.")
# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------
@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Generate one assistant reply for the Gradio chat widget.

    Args:
        message: The user's newest message.
        history: Earlier turns as supplied by ``gr.ChatInterface``.  Both
            history layouts are accepted: legacy ``[user, assistant]`` pairs
            and "messages"-style ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The decoded completion with special tokens removed.
    """
    hf_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" layout: the role is stated explicitly.
            content = turn.get("content")
            if content is not None:
                hf_messages.append({"role": turn.get("role", "user"), "content": content})
        else:
            # Legacy layout: the position inside the pair decides the role.
            # A missing side (None) is skipped rather than sent to the template.
            for role, content in zip(("user", "assistant"), turn):
                if content is not None:
                    hf_messages.append({"role": role, "content": content})
    hf_messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        hf_messages, tokenize=False, add_generation_prompt=True
        # NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The returned sequence starts with the prompt tokens; keep only what
    # follows them.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Run one generation pass and return only the completion text.

    Args:
        prompt: Prompt string to tokenize and feed to the model.
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        The tokens produced after the prompt, decoded with special tokens
        skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    # Drop the leading prompt tokens so only the new text is decoded.
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------
def list_models() -> str:
    """Return a JSON string describing the single model served by this app.

    The payload is an object with ``"object": "list"`` and a one-element
    ``"data"`` array; ``created`` is the current Unix time at call time.
    """
    model_entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [model_entry]})
# ---------------------------------------------------------------------------
# Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>)
# ---------------------------------------------------------------------------
# Wrapper the model emits around every tool invocation.
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
# Qwen3-Coder XML payload: <function=NAME> … <parameter=KEY>VALUE</parameter> …
_FUNCTION_RE = re.compile(r"<function=([^>]+)>")
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def _new_tool_call(name: str, arguments) -> dict:
    """Build one OpenAI-format tool call entry with a fresh ``call_…`` id."""
    return {
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {
            "name": name,
            # OpenAI clients expect the arguments as a JSON *string*.
            "arguments": json.dumps(arguments),
        },
    }


def _parse_tool_calls(text: str):
    """Detect and extract tool calls from model output.

    Two payload formats are recognised inside ``<tool_call>…</tool_call>``:

    Format A — Hermes JSON:
        <tool_call>{"name": "fn", "arguments": {...}}</tool_call>
    Format B — XML parameters (Qwen3-Coder):
        <tool_call><function=fn><parameter=key>value</parameter></function></tool_call>

    Args:
        text: Raw decoded model output.

    Returns:
        ``(tool_calls, remaining_content)`` where ``tool_calls`` is a list in
        OpenAI format and ``remaining_content`` is the text outside the
        tool-call blocks (``None`` when nothing is left), or ``(None, text)``
        if no tool call could be extracted.
    """
    payloads = _TOOL_CALL_RE.findall(text)
    if not payloads:
        return None, text

    tool_calls = []
    for payload in payloads:
        stripped = payload.strip()

        # ── Format A: JSON object inside the tool_call block ─────────────
        try:
            call = json.loads(stripped)
        except json.JSONDecodeError:
            call = None
        # Valid JSON that is not an object (a number, a list, …) is not a
        # tool call; fall through to the XML attempt.
        if isinstance(call, dict):
            arguments = call.get("arguments", call.get("parameters", {}))
            tool_calls.append(_new_tool_call(call.get("name", ""), arguments))
            continue

        # ── Format B: <function=name><parameter=key>value</parameter>… ──
        fn_match = _FUNCTION_RE.search(stripped)
        if fn_match is None:
            continue
        args = {}
        for param in _PARAMETER_RE.finditer(stripped):
            key = param.group(1).strip()
            val = param.group(2).strip()
            # Coerce JSON-looking values (numbers, booleans, objects);
            # anything else stays a plain string.
            try:
                val = json.loads(val)
            except ValueError:
                pass
            args[key] = val
        tool_calls.append(_new_tool_call(fn_match.group(1).strip(), args))

    if not tool_calls:
        return None, text
    remaining = _TOOL_CALL_RE.sub("", text).strip()
    return tool_calls, remaining or None
def _to_template_messages(messages: list) -> list:
    """Convert OpenAI-style chat messages into the dicts given to the chat template."""
    converted = []
    for m in messages:
        role = m["role"]
        if role == "tool":
            converted.append({
                "role": "tool",
                "content": m.get("content") or "",
                "tool_call_id": m.get("tool_call_id", ""),
            })
        elif role == "assistant" and m.get("tool_calls"):
            # Normalise tool_calls: apply_chat_template needs arguments as a dict,
            # but OpenAI format (and our own output) stores them as a JSON string.
            normalised_tool_calls = []
            for tc in m["tool_calls"]:
                fn = tc.get("function") or {}
                raw_args = fn.get("arguments", "{}")
                if isinstance(raw_args, str):
                    try:
                        parsed_args = json.loads(raw_args)
                    except json.JSONDecodeError:
                        parsed_args = {}
                else:
                    parsed_args = raw_args  # already decoded
                if not isinstance(parsed_args, dict):
                    # null / list / scalar arguments cannot be used as a dict.
                    parsed_args = {}
                normalised_tool_calls.append({
                    "id": tc.get("id", ""),
                    "type": "function",
                    "function": {
                        "name": fn.get("name", ""),
                        "arguments": parsed_args,  # dict, not string
                    },
                })
            converted.append({
                "role": "assistant",
                "content": m.get("content") or "",
                "tool_calls": normalised_tool_calls,
            })
        else:
            # `or ""` also covers an explicit JSON null content.
            converted.append({"role": role, "content": m.get("content") or ""})
    return converted


def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tools_json: str = "",
) -> str:
    """Non-streaming chat completions returning an OpenAI-compatible JSON string.

    NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.

    Args:
        messages_json: JSON array of ``{role, content}`` objects (assistant
            messages may carry ``tool_calls``; tool messages ``tool_call_id``).
        max_tokens: Maximum number of new tokens; coerced to ``int`` because
            the Gradio number input may deliver a float.
        temperature: Sampling temperature, clamped to at least 0.01.
        top_p: Nucleus sampling threshold.
        tools_json: Optional JSON array of OpenAI-format tool definitions.

    Returns:
        A JSON string: either a ``chat.completion`` object or
        ``{"error": "..."}`` when the input is invalid or a step fails.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps({"error": "Invalid messages_json: expected a JSON array of messages"})

    tools = None
    if tools_json and tools_json.strip():
        # Report malformed tool definitions instead of silently answering
        # without the tools the caller asked for.
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            return json.dumps({"error": f"Invalid tools_json: {e}"})
        if tools is not None and not isinstance(tools, list):
            return json.dumps({"error": "Invalid tools_json: expected a JSON array of tool definitions"})

    try:
        max_new_tokens = int(max_tokens)
    except (TypeError, ValueError) as e:
        return json.dumps({"error": f"Invalid max_tokens: {e}"})

    try:
        hf_messages = _to_template_messages(messages)
        template_kwargs = dict(tokenize=False, add_generation_prompt=True)
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:  # API boundary: report the failure as JSON
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # Sampling is always on, so keep the temperature strictly positive.
        temperature=max(temperature, 0.01),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:  # API boundary: report the failure as JSON
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        message = {"role": "assistant", "content": raw}
        finish_reason = "stop"
    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        # Token counts are not tracked; -1 marks them as unavailable.
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a JSON health-check string with an "ok" status and the model id."""
    status_payload = {"status": "ok", "model": MODEL_ID}
    return json.dumps(status_payload)
# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------
# One Blocks app hosts both the visible chat UI and the hidden components
# whose click handlers expose the named API endpoints.
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API
Endpoints (via Gradio built-in API):
| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |
Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.
You can also chat directly below.
""")
    gr.ChatInterface(fn=gradio_chat)
    # The rows below are never shown; they exist only so each handler is
    # registered under a stable api_name.
    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")
        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")
    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes the component deliver an int, which is what the
        # token limit needs.
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        # Input order must match the positional parameters of chat_completions.
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue()
    demo.launch(**launch_options)