# coder_v2_trial2 / app.py
# Uploaded by fomext ("Upload app.py"), commit 8f79d0d (verified), 11.7 kB.
import json
import re
import time
import uuid
from typing import Optional
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repository the tokenizer and weights are loaded from
# (also reported by health()).
MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
# Public model name reported by list_models(), in chat completion responses,
# and in the UI title.
MODEL_ALIAS = "qwen3-coder-30b-a3b-instruct-fp8"

# Tokenizer and model are loaded once at import time; every handler below
# reuses these module-level objects.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loading model {MODEL_ID} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # "auto": let transformers choose the dtype from the checkpoint.
    # NOTE(review): newer transformers releases rename this argument to
    # `dtype`; confirm against the pinned transformers version.
    torch_dtype="auto",
    # "auto": let transformers/accelerate decide device placement.
    device_map="auto",
)
# Inference only: switch off training-mode behavior such as dropout.
model.eval()
print("Model ready.")
# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------
def _gradio_content_text(content) -> str:
    """Return best-effort plain text for one Gradio chat content value.

    Handles plain strings, ``None`` (returned as ""), dicts carrying a
    ``"text"`` key, and lists/tuples of such parts (non-empty pieces joined
    with newlines).  Anything else is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        pieces = (_gradio_content_text(part) for part in content)
        return "\n".join(piece for piece in pieces if piece)
    return str(content)


def _history_to_hf_messages(history, message) -> list:
    """Build the chat-template message list from Gradio history plus *message*.

    Both history layouts gr.ChatInterface can hand over are accepted:

    * pairs:    ``[[user_text, assistant_text], ...]``
    * messages: ``[{"role": "user" | "assistant", "content": ...}, ...]``

    Turns whose content is ``None`` (e.g. a missing reply) are skipped so the
    chat template never receives ``None`` content.
    """
    hf_messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role in ("system", "user", "assistant") and content is not None:
                hf_messages.append({"role": role, "content": _gradio_content_text(content)})
        else:
            # Pair layout: first element is the user turn, second the reply.
            for role, content in zip(("user", "assistant"), entry):
                if content is not None:
                    hf_messages.append({"role": role, "content": _gradio_content_text(content)})
    hf_messages.append({"role": "user", "content": _gradio_content_text(message)})
    return hf_messages


@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Answer one gr.ChatInterface turn with a fixed sampling configuration.

    Args:
        message: The new user message.
        history: Prior conversation as supplied by gr.ChatInterface, in either
            the pair layout or the role/content messages layout.

    Returns:
        The decoded reply, with prompt tokens and special tokens removed.
    """
    hf_messages = _history_to_hf_messages(history, message)
    # Qwen3-Coder is non-thinking only; enable_thinking is not supported, so
    # no extra template kwargs are passed.
    prompt = tokenizer.apply_chat_template(
        hf_messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the tokens generated after the prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Run one non-streaming generation and return only the newly produced text.

    Args:
        prompt: A fully rendered chat-template string.
        gen_kwargs: Keyword arguments forwarded verbatim to ``model.generate``.

    Returns:
        The decoded completion, excluding the prompt tokens and any special
        tokens.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(**encoded, **gen_kwargs)
    # Row 0 of the batch, everything after the prompt.
    completion_ids = generated[0, prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------
def list_models() -> str:
    """Describe the single served model as an OpenAI-style model list, JSON-encoded."""
    model_card = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [model_card]})
# ---------------------------------------------------------------------------
# Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>)
# ---------------------------------------------------------------------------
def _parse_tool_calls(text: str):
    """
    Detect and extract tool calls from model output.

    Handles two formats:
    Format A — Hermes JSON (14b, 30b):
        <tool_call>{"name": "fn", "arguments": {...}}</tool_call>
    Format B — XML parameters (Qwen3-Coder):
        <tool_call>
        <function=fn_name>
        <parameter=param1>
        value1
        </parameter>
        </function>
        </tool_call>

    Args:
        text: Raw decoded model output.

    Returns:
        ``(tool_calls, remaining_content)`` where ``tool_calls`` is a list of
        OpenAI-format tool-call dicts (``function.arguments`` is a JSON
        string) and ``remaining_content`` is the text outside the
        ``<tool_call>`` blocks, or ``None`` if nothing is left.
        ``(None, text)`` when no ``<tool_call>`` block could be parsed.
    """
    pattern = r'<tool_call>(.*?)</tool_call>'
    matches = re.findall(pattern, text, re.DOTALL)
    if not matches:
        return None, text

    def _make_call(name, arguments_json):
        # One OpenAI-format tool-call entry with a fresh random id.
        return {
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments_json},
        }

    tool_calls = []
    for match in matches:
        stripped = match.strip()

        # ── Format A: JSON object inside tool_call ───────────────────
        try:
            call = json.loads(stripped)
        except json.JSONDecodeError:
            call = None
        # Only a JSON *object* is a call; lists/scalars fall through instead
        # of crashing on .get().
        if isinstance(call, dict):
            arguments = call.get("arguments", call.get("parameters", {}))
            # Arguments that are already a JSON string are passed through;
            # re-encoding them would double-encode.
            if not isinstance(arguments, str):
                arguments = json.dumps(arguments)
            tool_calls.append(_make_call(call.get("name", ""), arguments))
            continue

        # ── Format B: XML <function=name><parameter=k>v</parameter> ──
        fn_match = re.search(r'<function=([^>]+)>', stripped)
        if not fn_match:
            continue
        fn_name = fn_match.group(1).strip()
        args = {}
        for param in re.finditer(r'<parameter=([^>]+)>(.*?)</parameter>', stripped, re.DOTALL):
            key = param.group(1).strip()
            val = param.group(2)
            # Values are typically framed by one newline on each side (as in
            # the Qwen3-Coder template).  Remove only that framing so
            # indentation and trailing newlines inside code values survive.
            if val.startswith("\n"):
                val = val[1:]
            if val.endswith("\n"):
                val = val[:-1]
            # Coerce JSON-looking values (numbers, booleans, null, arrays,
            # objects); otherwise keep the string.
            # NOTE(review): the tool schema is not consulted, so a string
            # parameter whose text happens to be valid JSON is coerced too.
            try:
                val = json.loads(val)
            except (json.JSONDecodeError, ValueError):
                pass
            args[key] = val
        tool_calls.append(_make_call(fn_name, json.dumps(args)))

    if not tool_calls:
        return None, text
    remaining = re.sub(pattern, '', text, flags=re.DOTALL).strip()
    return tool_calls, remaining or None
def _message_text(content) -> str:
    """Flatten an OpenAI message ``content`` value into plain text.

    ``content`` may be a string, ``None`` (as OpenAI allows for assistant
    turns that only carry tool calls), or a list of content parts such as
    ``{"type": "text", "text": "..."}``.  Only text parts are kept (joined
    with newlines); any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)
    return str(content)


def _normalise_tool_calls(tool_calls) -> list:
    """Convert OpenAI-format tool calls into the shape apply_chat_template needs.

    OpenAI clients (and _parse_tool_calls) store ``function.arguments`` as a
    JSON string, but the chat template needs a dict.  Undecodable or
    non-object arguments become ``{}``.
    """
    normalised = []
    for tc in tool_calls:
        fn = tc.get("function") or {}
        raw_args = fn.get("arguments", "{}")
        if isinstance(raw_args, str):
            try:
                parsed_args = json.loads(raw_args)
            except json.JSONDecodeError:
                parsed_args = {}
        else:
            parsed_args = raw_args  # already decoded
        if not isinstance(parsed_args, dict):
            parsed_args = {}
        normalised.append({
            "id": tc.get("id", ""),
            "type": "function",
            "function": {
                "name": fn.get("name", ""),
                "arguments": parsed_args,  # dict, not string
            },
        })
    return normalised


def _build_hf_messages(messages: list) -> list:
    """Translate OpenAI chat messages into the dicts passed to apply_chat_template."""
    hf_messages = []
    for m in messages:
        role = m["role"]
        if role == "tool":
            hf_messages.append({
                "role": "tool",
                "content": _message_text(m.get("content")),
                "tool_call_id": m.get("tool_call_id", ""),
            })
        elif role == "assistant" and m.get("tool_calls"):
            hf_messages.append({
                "role": "assistant",
                "content": _message_text(m.get("content")),
                "tool_calls": _normalise_tool_calls(m["tool_calls"]),
            })
        else:
            hf_messages.append({"role": role, "content": _message_text(m.get("content"))})
    return hf_messages


def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tools_json: str = "",
) -> str:
    """
    Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    messages_json: JSON array of {role, content} objects (content may be a
        string, null, or a list of text parts).
    max_tokens: maximum new tokens; coerced to int (gr.Number delivers
        floats), minimum 1, None -> 512.
    temperature: sampling temperature; <= 0 selects greedy decoding,
        None -> 0.7.
    top_p: nucleus-sampling mass, ignored for greedy decoding; None -> 0.9.
    tools_json: JSON array of OpenAI-format tool definitions (optional;
        empty string or "null" means no tools).

    Every failure is reported as a JSON string {"error": "..."} instead of
    raising.

    NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps({"error": "Invalid messages_json: expected a JSON array of messages"})

    tools = None
    if isinstance(tools_json, str) and tools_json.strip():
        # Report bad tool definitions instead of silently dropping them,
        # which would quietly turn a tool-calling request into plain chat.
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            return json.dumps({"error": f"Invalid tools_json: {e}"})
        if tools is not None and not isinstance(tools, list):
            return json.dumps({"error": "Invalid tools_json: expected a JSON array of tool definitions"})

    try:
        max_new_tokens = max(1, int(512 if max_tokens is None else max_tokens))
        temperature = 0.7 if temperature is None else float(temperature)
        top_p = 0.9 if top_p is None else float(top_p)
    except (TypeError, ValueError, OverflowError) as e:
        return json.dumps({"error": f"Invalid generation parameters: {e}"})

    try:
        hf_messages = _build_hf_messages(messages)
        template_kwargs = dict(tokenize=False, add_generation_prompt=True)
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        # Keep the 0.01 floor so logits are never divided by a near-zero
        # temperature.
        gen_kwargs.update(do_sample=True, temperature=max(temperature, 0.01), top_p=top_p)
    else:
        # temperature <= 0 requests deterministic output: greedy decoding.
        gen_kwargs["do_sample"] = False

    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        message = {"role": "assistant", "content": raw}
        finish_reason = "stop"
    result = {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        # Token counts are not computed; -1 marks them as unknown.
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a static "ok" status plus the loaded model's Hub id, JSON-encoded.

    The model itself is not probed; reaching this handler is the signal.
    """
    status = {"status": "ok", "model": MODEL_ID}
    return json.dumps(status)
# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    # Blank lines in this Markdown matter: without them the prose lines after
    # the table are rendered as extra table rows.
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API

Endpoints (via Gradio built-in API):

| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |

Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.

You can also chat directly below.
""")
    # Interactive chat UI backed by gradio_chat.
    gr.ChatInterface(fn=gradio_chat)

    # Hidden components: they exist only to register named API endpoints.
    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")
        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes Gradio round and deliver an int, which is what
        # max_new_tokens expects (the default precision yields a float).
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    # Enable Gradio's request queue before starting the server.
    demo.queue()
    demo.launch(**launch_options)