| import json |
| import re |
| import time |
| import uuid |
| from typing import Optional |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
| |
| |
| |
|
|
# Hugging Face Hub repository that is loaded, and the short model name this app
# reports in its JSON responses (list_models "id", chat_completions "model").
MODEL_ID = "Qwen/Qwen3-30B-A3B"
MODEL_ALIAS = "qwen3-30b-a3b"


# Tokenizer and model are loaded once at import time and then shared, as module
# globals, by the request handlers defined below.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


print(f"Loading model {MODEL_ID} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    # Let transformers/accelerate decide weight placement on available devices.
    device_map="auto",
)
model.eval()  # inference only: switch off training-mode layers such as dropout
print("Model ready.")
|
|
|
|
| |
| |
| |
|
|
|
|
def gradio_chat(message: str, history: list) -> str:
    """Reply to one turn of the built-in Gradio chat UI.

    ``history`` is accepted in either of Gradio ChatInterface's formats:

    * "tuples": a list of ``[user_text, assistant_text]`` pairs, or
    * "messages": a list of ``{"role": ..., "content": ...}`` dicts.

    Entries whose content is not a plain string (e.g. ``None`` for a reply
    that is still pending, or a file tuple) are skipped, so the chat template
    only ever sees text. Thinking is disabled and generation uses up to 512
    new tokens with sampling (temperature 0.7, top-p 0.9).

    No ``@spaces.GPU`` here: only ``_generate_response`` needs the GPU, and
    delegating to it (as ``chat_completions`` does) keeps one generation path.
    """
    hf_messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # "messages" format: entries are already role-tagged.
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant", "system") and isinstance(content, str):
                hf_messages.append({"role": role, "content": content})
        else:
            # "tuples" format: roles come from the position inside each pair,
            # so a missing reply cannot shift later turns to the wrong role.
            for role, content in zip(("user", "assistant"), entry):
                if isinstance(content, str):
                    hf_messages.append({"role": role, "content": content})
    hf_messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        hf_messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    return _generate_response(prompt, gen_kwargs)
|
|
|
|
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Generate a completion for an already-templated *prompt*.

    The prompt is tokenized and moved to ``model.device``, and ``gen_kwargs``
    is forwarded unchanged to ``model.generate``. The first sequence's leading
    positions are dropped, up to the prompt length, so only new tokens are
    decoded. Special tokens are skipped when decoding.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
|
|
|
|
| |
| |
| |
|
|
|
|
def list_models() -> str:
    """Return a model-list response (shaped like OpenAI's) as a JSON string.

    The list holds one entry, for ``MODEL_ALIAS``. Its ``created`` field is the
    current Unix time in whole seconds, taken when this function is called.
    """
    entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [entry]})
|
|
|
|
| |
| |
| |
|
|
def _parse_tool_calls(text: str):
    """
    Detect and extract Hermes-style tool calls from model output.

    Each ``<tool_call>...</tool_call>`` span should contain a JSON object with
    a ``name`` and an ``arguments`` (or ``parameters``) field. Spans inside
    ``<think>...</think>`` reasoning are ignored: a call the model only
    considered while thinking is not a call it actually made.

    Returns ``(tool_calls, remaining_content)``:

    * ``tool_calls`` is a list in OpenAI format, where ``function.arguments``
      is always a JSON-encoded string.
    * ``remaining_content`` is the text left after removing think and
      tool-call spans, or ``None`` if nothing is left.

    Returns ``(None, text)``, with *text* unchanged, when no well-formed tool
    call is found.
    """
    tool_pattern = r'<tool_call>(.*?)</tool_call>'
    # Remove reasoning first so tool calls mentioned inside <think> are ignored.
    visible = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    matches = re.findall(tool_pattern, visible, re.DOTALL)
    if not matches:
        return None, text

    tool_calls = []
    for match in matches:
        try:
            call = json.loads(match.strip())
        except json.JSONDecodeError:
            continue  # malformed span: skip it but keep any other valid calls
        # Valid JSON that is not an object (list, string, number) has no .get();
        # treat it as malformed instead of raising AttributeError.
        if not isinstance(call, dict):
            continue
        arguments = call.get("arguments", call.get("parameters", {}))
        # OpenAI expects `arguments` as a JSON string. If the model already
        # emitted one, pass it through rather than double-encoding it.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments)
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": arguments,
            },
        })

    if not tool_calls:
        return None, text

    remaining = re.sub(tool_pattern, '', visible, flags=re.DOTALL).strip()
    return tool_calls, remaining or None
|
|
|
|
def _content_to_text(content) -> str:
    """Flatten an OpenAI message ``content`` value into plain text.

    Accepts a string, ``None`` (treated as empty), or a list of OpenAI content
    parts. Text parts (``{"type": "text", "text": ...}``) or bare strings are
    kept and joined with newlines; other parts (images, audio, ...) are dropped
    because only text reaches the chat template.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def _to_hf_messages(messages: list) -> list:
    """Convert OpenAI-format chat messages into dicts for ``apply_chat_template``.

    Tool results keep their ``tool_call_id``. Assistant turns that made tool
    calls keep their ``tool_calls`` list unchanged. Every other message is
    reduced to ``role`` and text ``content``. A message without ``role``
    raises ``KeyError``.
    """
    hf_messages = []
    for m in messages:
        role = m["role"]
        content = _content_to_text(m.get("content"))
        if role == "tool":
            hf_messages.append({
                "role": "tool",
                "content": content,
                "tool_call_id": m.get("tool_call_id", ""),
            })
        elif role == "assistant" and m.get("tool_calls"):
            hf_messages.append({
                "role": "assistant",
                "content": content,
                "tool_calls": m["tool_calls"],
            })
        else:
            hf_messages.append({"role": role, "content": content})
    return hf_messages


def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    enable_thinking: bool = False,
    tools_json: str = "",
) -> str:
    """
    Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    messages_json: JSON array of {role, content} objects. content may be a
        string, null, or a list of OpenAI content parts (text parts are kept).
    max_tokens: maximum number of new tokens. Coerced to int because Gradio
        Number inputs can arrive as floats; must be >= 1.
    temperature: sampling temperature, floored at 0.01 (sampling is always on).
    top_p: nucleus-sampling threshold, passed through to generation.
    enable_thinking: enable Qwen3 chain-of-thought reasoning (default False for tool use)
    tools_json: JSON array of OpenAI-format tool definitions (optional)

    On any failure the result is a JSON object with a single "error" key
    instead of a completion.
    """
    try:
        messages = json.loads(messages_json)
    except (TypeError, ValueError) as e:  # JSONDecodeError, or non-string input
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps({"error": "Invalid messages_json: expected a JSON array"})

    tools = None
    if tools_json:
        try:
            tools = json.loads(tools_json)
        except (TypeError, ValueError) as e:
            # Reject rather than silently drop the tools: a tool-less answer
            # would be indistinguishable from a real one to the caller.
            return json.dumps({"error": f"Invalid tools_json: {e}"})

    try:
        max_new_tokens = int(max_tokens)
        temperature = max(float(temperature), 0.01)
    except (TypeError, ValueError, OverflowError) as e:
        return json.dumps({"error": f"Invalid generation parameters: {e}"})
    if max_new_tokens < 1:
        return json.dumps({"error": "Invalid generation parameters: max_tokens must be >= 1"})

    try:
        hf_messages = _to_hf_messages(messages)
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:  # API boundary: report any template failure as JSON
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:  # API boundary: report any generation failure as JSON
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)

    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # Plain answer: hide any <think> reasoning from the returned content.
        content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"

    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        # Token counts are not tracked; -1 marks them as unavailable.
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
|
|
|
|
def health() -> str:
    """Return a JSON health-check payload naming ``MODEL_ID``.

    The status is always ``"ok"``; this does not exercise the model itself.
    """
    payload = {"status": "ok", "model": MODEL_ID}
    return json.dumps(payload)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    # Help text shown at the top of the page.
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API

Endpoints (via Gradio built-in API):

| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |

Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.

You can also chat directly below.
""")

    # Interactive chat widget backed by gradio_chat.
    gr.ChatInterface(fn=gradio_chat)

    # Invisible components: they exist only so that each function is exposed
    # as a named API endpoint via `api_name`; users never see them.
    with gr.Row(visible=False):
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")

        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    with gr.Row(visible=False):
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes Gradio deliver an int, which max_new_tokens expects.
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_thinking = gr.Checkbox(label="enable_thinking", value=False)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            # Order must match chat_completions' positional parameters.
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_thinking, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    # Blocks.queue() returns the same Blocks object, so launch() can be chained.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|
|