import json
import re
import time
import uuid
from typing import Optional
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hub repository to load, and the public alias reported by the API responses.
MODEL_ID = "Qwen/Qwen3-14B"
MODEL_ALIAS = "qwen3-14b-4bit"

# The tokenizer is loaded first; it is needed for chat templating and decoding.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID} in 4-bit …")
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)
# device_map="auto" delegates device placement to the library.
# .eval() returns the module itself, so chaining it still binds the model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
).eval()
print("Model ready.")
# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------
@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Generate one assistant reply for the chat UI.

    Args:
        message: The newest user message.
        history: Earlier turns of the conversation. Two layouts are accepted:
            a list of ``[user, assistant]`` pairs, or a list of
            ``{"role": ..., "content": ...}`` dicts. Entries whose content is
            ``None`` are skipped.

    Returns:
        The newly generated text, decoded with special tokens removed.
    """
    hf_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages-style history: the role is carried by the entry itself.
            role = turn.get("role")
            content = turn.get("content")
            if role and content is not None:
                hf_messages.append({"role": role, "content": content})
        else:
            # Pair-style history: position in the pair decides the role, so a
            # missing half (None) cannot shift the roles of later messages.
            for role, content in zip(("user", "assistant"), turn):
                if content is not None:
                    hf_messages.append({"role": role, "content": content})
    hf_messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        hf_messages,
        tokenize=False,
        add_generation_prompt=True,
        # Thinking is disabled here as it is in chat_completions, so the reply
        # budget below is spent on the answer rather than on reasoning text.
        enable_thinking=False,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + completion; keep only the completion tokens.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Run generation for an already-templated prompt.

    Args:
        prompt: Prompt text, tokenized here as-is.
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        Only the newly generated text, decoded with special tokens removed.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    # The output sequence starts with the prompt tokens; slice them off.
    prompt_length = encoded["input_ids"].shape[1]
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------
def list_models() -> str:
    """Return a JSON string listing the single model served here.

    The payload is an object of type "list" whose "data" array holds one
    entry for MODEL_ALIAS; "created" is the time of the call, in whole seconds.
    """
    model_entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [model_entry]})
# ---------------------------------------------------------------------------
# Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>)
# ---------------------------------------------------------------------------
def _parse_tool_calls(text: str):
    """
    Detect and extract Hermes-style tool calls from model output.

    A tool call is a JSON object wrapped in <tool_call>...</tool_call> tags,
    carrying a "name" and either "arguments" or "parameters".

    Returns (tool_calls, remaining_content) where tool_calls is a list in
    OpenAI format, or (None, text) if no tool calls are found. Blocks whose
    payload is not a JSON object are ignored. remaining_content is the text
    with tool-call and <think> blocks removed, or None when nothing is left.
    """
    pattern = r'<tool_call>(.*?)</tool_call>'
    matches = re.findall(pattern, text, re.DOTALL)
    if not matches:
        return None, text
    tool_calls = []
    for match in matches:
        try:
            call = json.loads(match.strip())
        except json.JSONDecodeError:
            continue
        if not isinstance(call, dict):
            # Valid JSON that is not an object cannot describe a call.
            continue
        arguments = call.get("arguments", call.get("parameters", {}))
        # OpenAI format carries arguments as a JSON string; a value that is
        # already a string is passed through rather than encoded a second time.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments)
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": arguments,
            },
        })
    if not tool_calls:
        return None, text
    # Strip tool call blocks and think tags from remaining content
    remaining = re.sub(pattern, '', text, flags=re.DOTALL)
    remaining = re.sub(r'<think>.*?</think>', '', remaining, flags=re.DOTALL).strip()
    return tool_calls, remaining or None
def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tools_json: str = "",
) -> str:
    """
    Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    messages_json: JSON array of {role, content} objects
    max_tokens: upper bound on generated tokens (coerced to int)
    temperature: sampling temperature; values below 0.01 are raised to 0.01
    top_p: nucleus sampling threshold
    tools_json: JSON array of OpenAI-format tool definitions (optional)

    Every failure (bad input JSON, prompt construction, generation) is
    reported as a JSON object with a single "error" key instead of raising.

    NOTE: Qwen3-14B supports thinking mode. enable_thinking is set to False
    here for reliable tool call formatting.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps(
            {"error": "Invalid messages_json: expected a JSON array of messages"}
        )

    tools = None
    if tools_json and tools_json.strip():
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            # Report malformed tool definitions the same way as malformed
            # messages instead of silently answering without tools.
            return json.dumps({"error": f"Invalid tools_json: {e}"})

    try:
        # UI number inputs may deliver floats; generation needs an int.
        max_new_tokens = int(max_tokens)
    except (TypeError, ValueError) as e:
        return json.dumps({"error": f"Invalid max_tokens: {e}"})

    try:
        hf_messages = []
        for m in messages:
            role = m["role"]
            if role == "tool":
                # Tool results keep their role and carry the tool_call_id along.
                hf_messages.append({
                    "role": "tool",
                    "content": m.get("content", ""),
                    "tool_call_id": m.get("tool_call_id", ""),
                })
            elif role == "assistant" and m.get("tool_calls"):
                # Assistant turns that issued tool calls may have null content.
                hf_messages.append({
                    "role": "assistant",
                    "content": m.get("content") or "",
                    "tool_calls": m["tool_calls"],
                })
            else:
                hf_messages.append({"role": role, "content": m.get("content", "")})
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:
        # API boundary: any malformed message shape is reported, not raised.
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # Sampling is always on, so the temperature must stay strictly positive.
        temperature=max(temperature, 0.01),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # Strip any stray think tags from plain responses
        content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"
    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        # Token counts are not tracked; -1 is a placeholder value.
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a JSON string with an "ok" status and the loaded model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return json.dumps(payload)
# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------
# One Blocks app hosts the visible chat UI plus hidden components whose click
# handlers exist only to register named API endpoints.
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    # The Markdown body stays at column 0 so the rendered text has no
    # leading indentation.
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API
Endpoints (via Gradio built-in API):
| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |
Call them at `/gradio_api/call/{{api_name}}` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.
You can also chat directly below.
""")
    gr.ChatInterface(fn=gradio_chat)

    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")
        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        # Input order below must match the chat_completions parameter order.
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes the component round to and deliver an int, which
        # is what a token count has to be.
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Bind on all interfaces at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    # Enable the request queue before starting the server.
    demo.queue()
    demo.launch(**launch_options)