import json
import re
import time
import uuid
from typing import Optional

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
MODEL_ALIAS = "qwen3-coder-30b-a3b-instruct-fp8"

# Sampling defaults shared by the chat UI and the chat_completions API.
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9

print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
model.eval()
print("Model ready.")


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def _content_to_text(content) -> str:
    """Flatten a message ``content`` value into the plain string the chat template expects.

    - ``str`` is returned unchanged; ``None`` becomes ``""``.
    - A list/tuple (e.g. OpenAI "content parts") is reduced to its text parts —
      plain strings and dicts carrying a string ``"text"`` — joined with newlines.
    - Anything else (e.g. a dict returned by a tool) is serialised as JSON.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)
    return json.dumps(content, ensure_ascii=False)


def _generate_text(prompt: str, gen_kwargs: dict) -> str:
    """Run ``model.generate`` on an already-templated prompt and return only the new text.

    Deliberately NOT decorated with ``@spaces.GPU``: it is called from inside the
    decorated entry points below, so it never requests a second GPU allocation.
    """
    # The chat template already inserted every special token the model needs;
    # letting the tokenizer add its own on top could duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # generate() returns prompt + completion; keep only the completion tokens.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)


def _history_to_messages(history) -> list:
    """Convert ``gr.ChatInterface`` history into chat-template messages.

    Accepts both history shapes Gradio can pass:
    - "tuples":   ``[[user_msg, assistant_msg], ...]`` (``None`` entries are skipped);
    - "messages": ``[{"role": ..., "content": ...}, ...]``.
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            if role in ("system", "user", "assistant"):
                messages.append({"role": role, "content": _content_to_text(item.get("content"))})
            continue
        for role, content in zip(("user", "assistant"), item):
            if content is not None:
                messages.append({"role": role, "content": _content_to_text(content)})
    return messages


# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------


@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Answer one turn of the built-in chat UI and return the assistant reply text."""
    hf_messages = _history_to_messages(history)
    hf_messages.append({"role": "user", "content": _content_to_text(message)})
    # NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.
    prompt = tokenizer.apply_chat_template(
        hf_messages, tokenize=False, add_generation_prompt=True
    )
    return _generate_text(
        prompt,
        dict(
            max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=DEFAULT_TEMPERATURE,
            top_p=DEFAULT_TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        ),
    )


@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """GPU entry point for the API: generate a completion for a templated prompt."""
    return _generate_text(prompt, gen_kwargs)


# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------


def list_models() -> str:
    """Returns a JSON string listing available models."""
    result = {
        "object": "list",
        "data": [{"id": MODEL_ALIAS, "object": "model", "created": int(time.time()), "owned_by": "qwen"}],
    }
    return json.dumps(result)


# ---------------------------------------------------------------------------
# Tool call parsing (<tool_call>...</tool_call> blocks)
# ---------------------------------------------------------------------------

# One complete tool-call block; its body is either Hermes JSON or Qwen3-Coder XML.
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
# Qwen3-Coder XML function header: <function=NAME>
_FUNCTION_RE = re.compile(r"<function=([^>]+)>")
# Qwen3-Coder XML argument: <parameter=KEY>VALUE</parameter>
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def _make_tool_call(name: str, arguments) -> dict:
    """Build one OpenAI-format tool call with a fresh id.

    ``arguments`` is stored as a JSON string, as OpenAI clients expect; a value
    that is already a string is passed through instead of being encoded twice.
    """
    if not isinstance(arguments, str):
        arguments = json.dumps(arguments)
    return {
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {"name": name, "arguments": arguments},
    }


def _parse_hermes_call(body: str) -> Optional[dict]:
    """Parse a Hermes JSON body ``{"name": ..., "arguments": {...}}``; ``None`` if it isn't one."""
    try:
        call = json.loads(body)
    except json.JSONDecodeError:
        return None
    if not isinstance(call, dict):
        return None
    return _make_tool_call(
        call.get("name", ""),
        call.get("arguments", call.get("parameters", {})),
    )


def _parse_xml_call(body: str) -> Optional[dict]:
    """Parse a Qwen3-Coder XML body (``<function=..>`` + ``<parameter=..>``); ``None`` if absent."""
    fn_match = _FUNCTION_RE.search(body)
    if not fn_match:
        return None
    args = {}
    for param in _PARAMETER_RE.finditer(body):
        key = param.group(1).strip()
        val = param.group(2).strip()
        # Try to coerce JSON literals (numbers, booleans, null, lists, objects);
        # anything that is not valid JSON stays a plain string.
        try:
            val = json.loads(val)
        except ValueError:
            pass
        args[key] = val
    return _make_tool_call(fn_match.group(1).strip(), args)


def _parse_tool_calls(text: str) -> tuple[Optional[list], Optional[str]]:
    """
    Detect and extract tool calls from model output.

    Every call must be wrapped in ``<tool_call>...</tool_call>``. Two body formats are handled:

        Format A — Hermes JSON:
            <tool_call>{"name": "fn", "arguments": {...}}</tool_call>

        Format B — XML parameters (Qwen3-Coder):
            <tool_call><function=fn><parameter=p1>value1</parameter></function></tool_call>

    Returns (tool_calls, remaining_content) in OpenAI format, where remaining_content
    is the text outside the tool-call blocks (None if empty), or (None, text) if no
    tool call could be parsed.
    """
    tool_calls = []
    for match in _TOOL_CALL_RE.finditer(text):
        body = match.group(1).strip()
        call = _parse_hermes_call(body) or _parse_xml_call(body)
        if call is not None:
            tool_calls.append(call)

    if not tool_calls:
        return None, text

    remaining = _TOOL_CALL_RE.sub("", text).strip()
    return tool_calls, remaining or None


def _normalise_tool_calls(tool_calls: list) -> list:
    """Convert OpenAI-format tool calls into the shape ``apply_chat_template`` needs.

    OpenAI (and this API's own output) stores ``function.arguments`` as a JSON
    string, but the chat template needs a dict. Missing/empty or unparseable
    argument strings become ``{}``.
    """
    normalised = []
    for tc in tool_calls:
        fn = tc.get("function", {})
        raw_args = fn.get("arguments") or {}
        if isinstance(raw_args, str):
            try:
                parsed_args = json.loads(raw_args)
            except json.JSONDecodeError:
                parsed_args = {}
        else:
            parsed_args = raw_args  # already a dict
        normalised.append({
            "id": tc.get("id", ""),
            "type": "function",
            "function": {
                "name": fn.get("name", ""),
                "arguments": parsed_args,  # dict, not string
            },
        })
    return normalised


def _to_hf_messages(messages: list) -> list:
    """Translate OpenAI chat messages into the message dicts the chat template consumes."""
    hf_messages = []
    for m in messages:
        role = m["role"]
        content = _content_to_text(m.get("content"))
        if role == "tool":
            hf_messages.append({
                "role": "tool",
                "content": content,
                "tool_call_id": m.get("tool_call_id", ""),
            })
        elif role == "assistant" and m.get("tool_calls"):
            hf_messages.append({
                "role": "assistant",
                "content": content,
                "tool_calls": _normalise_tool_calls(m["tool_calls"]),
            })
        else:
            hf_messages.append({"role": role, "content": content})
    return hf_messages


def chat_completions(
    messages_json: str,
    max_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    tools_json: str = "",
) -> str:
    """
    Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    messages_json: JSON array of {role, content} objects. ``content`` may be a string
        or a list of OpenAI text parts; assistant messages may carry ``tool_calls``
        and tool messages a ``tool_call_id``.
    max_tokens: maximum number of new tokens; coerced to int (Gradio's Number
        component may deliver floats). ``None`` means the default.
    temperature, top_p: sampling parameters (``None`` means the default). Sampling
        is always enabled, so temperature is clamped to at least 0.01.
    tools_json: JSON array of OpenAI-format tool definitions (optional).

    Invalid input, prompt-building failures and generation failures are reported
    as ``{"error": "..."}`` JSON instead of raising.

    NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.
    """
    try:
        messages = json.loads(messages_json)
    except (TypeError, json.JSONDecodeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps({"error": "Invalid messages_json: expected a JSON array of messages"})

    tools = None
    if tools_json:
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            return json.dumps({"error": f"Invalid tools_json: {e}"})
        if tools is not None and not isinstance(tools, list):
            return json.dumps({"error": "Invalid tools_json: expected a JSON array of tool definitions"})

    try:
        max_new_tokens = int(DEFAULT_MAX_NEW_TOKENS if max_tokens is None else max_tokens)
        temperature = float(DEFAULT_TEMPERATURE if temperature is None else temperature)
        top_p = float(DEFAULT_TOP_P if top_p is None else top_p)
    except (TypeError, ValueError) as e:
        return json.dumps({"error": f"Invalid sampling parameters: {e}"})
    if max_new_tokens < 1:
        return json.dumps({"error": "Invalid sampling parameters: max_tokens must be >= 1"})

    try:
        template_kwargs = dict(tokenize=False, add_generation_prompt=True)
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(_to_hf_messages(messages), **template_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # do_sample=True needs a strictly positive temperature; 0.01 is near-greedy.
        temperature=max(temperature, 0.01),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)

    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        message = {"role": "assistant", "content": raw}
        finish_reason = "stop"

    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        # Token counts are not tracked; -1 marks them as unavailable.
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)


def health() -> str:
    """Returns a JSON health-check string."""
    return json.dumps({"status": "ok", "model": MODEL_ID})


# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------

with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API

Endpoints (via Gradio built-in API):

| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |

Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`) or use the Gradio Python client.

You can also chat directly below.
""")

    gr.ChatInterface(fn=gradio_chat)

    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")

        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        _cc_messages = gr.Textbox(label="messages_json")
        _cc_max_tokens = gr.Number(label="max_tokens", value=DEFAULT_MAX_NEW_TOKENS)
        _cc_temp = gr.Number(label="temperature", value=DEFAULT_TEMPERATURE)
        _cc_top_p = gr.Number(label="top_p", value=DEFAULT_TOP_P)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )


# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )