Spaces:
Paused
Paused
| import json | |
| import re | |
| import time | |
| import uuid | |
| from typing import Optional | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"  # Hub repo the weights come from
MODEL_ALIAS = "qwen3-coder-30b-a3b-instruct-fp8"  # model name reported in API responses

# Tokenizer and weights are loaded once, at import time; every request
# handler below reuses these module-level objects.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype="auto", device_map="auto"
).eval()  # nn.Module.eval() returns the module itself, so the chain is safe
print("Model ready.")
| # --------------------------------------------------------------------------- | |
| # GPU generation functions — ZeroGPU anchors | |
| # --------------------------------------------------------------------------- | |
def _content_to_text(content) -> Optional[str]:
    """Return the plain text carried by one chat-history content value, or None.

    Strings pass through unchanged. A list of content parts is flattened by
    concatenating its bare strings and ``{"type": "text", "text": ...}`` parts.
    Anything else (None for a pending reply, a file-upload tuple, ...) yields
    None so the caller can skip that turn.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "".join(texts) if texts else None
    return None


def _history_to_messages(history, message: str) -> list:
    """Build chat-template messages from ChatInterface history plus the new user turn.

    Two history shapes are accepted:

    * pairs ``[[user, assistant], ...]`` (Gradio "tuples" format);
    * dicts ``[{"role": ..., "content": ...}, ...]`` (Gradio "messages" format).

    Turns without usable text, or with a role other than system/user/assistant,
    are skipped rather than forwarded to the chat template.
    """
    hf_messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            turns = zip(("user", "assistant"), entry)
        for role, content in turns:
            text = _content_to_text(content)
            if role in ("system", "user", "assistant") and text is not None:
                hf_messages.append({"role": role, "content": text})
    hf_messages.append({"role": "user", "content": message})
    return hf_messages


# Requests a GPU for the duration of each call on ZeroGPU hardware; per the
# `spaces` package docs the decorator has no effect elsewhere.
@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Generate one assistant reply for ``gr.ChatInterface``.

    Args:
        message: The new user message.
        history: Prior conversation as supplied by ``gr.ChatInterface``, in
            either pair or role/content-dict form (see ``_history_to_messages``).

    Returns:
        The decoded completion text, with special tokens removed.
    """
    # NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported.
    prompt = tokenizer.apply_chat_template(
        _history_to_messages(history, message),
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens so only newly generated ids are decoded.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
# Requests a GPU for the duration of each call on ZeroGPU hardware; per the
# `spaces` package docs the decorator has no effect elsewhere.
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Run one non-streaming generation for an already-rendered prompt.

    Args:
        prompt: Prompt text, expected to be the output of
            ``tokenizer.apply_chat_template(..., add_generation_prompt=True)``.
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        The decoded completion only: prompt tokens are sliced off and special
        tokens are skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**encoded, **gen_kwargs)
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
| # --------------------------------------------------------------------------- | |
| # API functions | |
| # --------------------------------------------------------------------------- | |
def list_models() -> str:
    """Describe the single served model, shaped like OpenAI's model-list response.

    Returns:
        A JSON string ``{"object": "list", "data": [<model entry>]}`` where the
        entry's ``created`` field is the current Unix time in whole seconds.
    """
    entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [entry]})
| # --------------------------------------------------------------------------- | |
| # Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>) | |
| # --------------------------------------------------------------------------- | |
| def _parse_tool_calls(text: str): | |
| """ | |
| Detect and extract tool calls from model output. | |
| Handles two formats: | |
| Format A — Hermes JSON (14b, 30b): | |
| <tool_call>{"name": "fn", "arguments": {...}}</tool_call> | |
| Format B — XML parameters (Qwen3-Coder): | |
| <tool_call> | |
| <function=fn_name> | |
| <parameter=param1>value1</parameter> | |
| </function> | |
| </tool_call> | |
| Returns (tool_calls, remaining_content) in OpenAI format, | |
| or (None, text) if no tool calls found. | |
| """ | |
| pattern = r'<tool_call>(.*?)</tool_call>' | |
| matches = re.findall(pattern, text, re.DOTALL) | |
| if not matches: | |
| return None, text | |
| tool_calls = [] | |
| for match in matches: | |
| stripped = match.strip() | |
| # ── Format A: JSON inside tool_call ────────────────────────── | |
| try: | |
| call = json.loads(stripped) | |
| tool_calls.append({ | |
| "id": f"call_{uuid.uuid4().hex[:24]}", | |
| "type": "function", | |
| "function": { | |
| "name": call.get("name", ""), | |
| "arguments": json.dumps(call.get("arguments", call.get("parameters", {}))), | |
| }, | |
| }) | |
| continue | |
| except json.JSONDecodeError: | |
| pass | |
| # ── Format B: XML <function=name><parameter=k>v</parameter> ── | |
| fn_match = re.search(r'<function=([^>]+)>', stripped) | |
| if fn_match: | |
| fn_name = fn_match.group(1).strip() | |
| args = {} | |
| for param in re.finditer(r'<parameter=([^>]+)>(.*?)</parameter>', stripped, re.DOTALL): | |
| key = param.group(1).strip() | |
| val = param.group(2).strip() | |
| # Try to coerce to int/float/bool, otherwise keep as string | |
| try: | |
| val = json.loads(val) | |
| except (json.JSONDecodeError, ValueError): | |
| pass | |
| args[key] = val | |
| tool_calls.append({ | |
| "id": f"call_{uuid.uuid4().hex[:24]}", | |
| "type": "function", | |
| "function": { | |
| "name": fn_name, | |
| "arguments": json.dumps(args), | |
| }, | |
| }) | |
| if not tool_calls: | |
| return None, text | |
| remaining = re.sub(pattern, '', text, flags=re.DOTALL).strip() | |
| return tool_calls, remaining or None | |
| def chat_completions( | |
| messages_json: str, | |
| max_tokens: int = 512, | |
| temperature: float = 0.7, | |
| top_p: float = 0.9, | |
| tools_json: str = "", | |
| ) -> str: | |
| """ | |
| Non-streaming chat completions. Returns an OpenAI-compatible JSON string. | |
| messages_json: JSON array of {role, content} objects | |
| tools_json: JSON array of OpenAI-format tool definitions (optional) | |
| NOTE: Qwen3-Coder is non-thinking only; enable_thinking is not supported. | |
| """ | |
| try: | |
| messages = json.loads(messages_json) | |
| except json.JSONDecodeError as e: | |
| return json.dumps({"error": f"Invalid messages_json: {e}"}) | |
| tools = None | |
| if tools_json: | |
| try: | |
| tools = json.loads(tools_json) | |
| except json.JSONDecodeError: | |
| pass | |
| try: | |
| hf_messages = [] | |
| for m in messages: | |
| role = m["role"] | |
| if role == "tool": | |
| hf_messages.append({ | |
| "role": "tool", | |
| "content": m.get("content", ""), | |
| "tool_call_id": m.get("tool_call_id", ""), | |
| }) | |
| elif role == "assistant" and m.get("tool_calls"): | |
| # Normalise tool_calls: apply_chat_template needs arguments as a dict, | |
| # but OpenAI format (and our own output) stores them as a JSON string. | |
| normalised_tool_calls = [] | |
| for tc in m["tool_calls"]: | |
| fn = tc.get("function", {}) | |
| raw_args = fn.get("arguments", "{}") | |
| if isinstance(raw_args, str): | |
| try: | |
| parsed_args = json.loads(raw_args) | |
| except json.JSONDecodeError: | |
| parsed_args = {} | |
| else: | |
| parsed_args = raw_args # already a dict | |
| normalised_tool_calls.append({ | |
| "id": tc.get("id", ""), | |
| "type": "function", | |
| "function": { | |
| "name": fn.get("name", ""), | |
| "arguments": parsed_args, # dict, not string | |
| }, | |
| }) | |
| hf_messages.append({ | |
| "role": "assistant", | |
| "content": m.get("content") or "", | |
| "tool_calls": normalised_tool_calls, | |
| }) | |
| else: | |
| hf_messages.append({"role": role, "content": m.get("content", "")}) | |
| template_kwargs = dict(tokenize=False, add_generation_prompt=True) | |
| if tools: | |
| template_kwargs["tools"] = tools | |
| prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs) | |
| except Exception as e: | |
| return json.dumps({"error": f"Prompt build failed: {e}"}) | |
| gen_kwargs = dict( | |
| max_new_tokens=max_tokens, | |
| temperature=max(temperature, 0.01), | |
| top_p=top_p, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| try: | |
| raw = _generate_response(prompt, gen_kwargs) | |
| except Exception as e: | |
| return json.dumps({"error": f"Generation failed: {e}"}) | |
| cid = f"chatcmpl-{uuid.uuid4().hex}" | |
| tool_calls, content = _parse_tool_calls(raw) | |
| if tool_calls: | |
| message = {"role": "assistant", "content": content, "tool_calls": tool_calls} | |
| finish_reason = "tool_calls" | |
| else: | |
| message = {"role": "assistant", "content": raw} | |
| finish_reason = "stop" | |
| result = { | |
| "id": cid, | |
| "object": "chat.completion", | |
| "created": int(time.time()), | |
| "model": MODEL_ALIAS, | |
| "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}], | |
| "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1}, | |
| } | |
| return json.dumps(result) | |
def health() -> str:
    """Liveness probe: JSON ``{"status": "ok", "model": MODEL_ID}``."""
    payload = dict(status="ok", model=MODEL_ID)
    return json.dumps(payload)
| # --------------------------------------------------------------------------- | |
| # Gradio UI + API | |
| # --------------------------------------------------------------------------- | |
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API

Endpoints (via Gradio built-in API):

| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |

Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.

You can also chat directly below.
""")

    # Interactive chat for humans. The hidden rows below exist only to
    # register named API endpoints via api_name.
    gr.ChatInterface(fn=gradio_chat)

    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")

        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes Gradio deliver an int; with the default (None) a
        # float would be forwarded to model.generate as max_new_tokens.
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
| # --------------------------------------------------------------------------- | |
| # Entry-point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Blocks.queue() returns the Blocks itself, so launch() chains directly;
    # bind all interfaces on the conventional Spaces port.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)