Spaces:
Paused
Paused
| import json | |
| import re | |
| import time | |
| import uuid | |
| from typing import Optional | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# Hub checkpoint that is loaded, and the short alias reported to API clients.
MODEL_ID = "Qwen/Qwen3-14B"
MODEL_ALIAS = "qwen3-14b-4bit"

print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID} in 4-bit …")
# 4-bit NF4 quantization with double quantization enabled; the compute dtype
# is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# device_map="auto" leaves weight placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
)
model.eval()  # this module only runs inference
print("Model ready.")
| # --------------------------------------------------------------------------- | |
| # GPU generation functions — ZeroGPU anchors | |
| # --------------------------------------------------------------------------- | |
def _message_text(content) -> str:
    """Return the plain text carried by one chat-history content value.

    Strings are returned unchanged. A list is treated as a sequence of
    content parts and the ``text`` of every ``{"type": "text", ...}`` part is
    concatenated. Anything else (``None``, file payloads, ...) yields ``""``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # NOTE(review): newer Gradio releases appear to deliver message
        # content as a list of typed parts; confirm against the pinned version.
        return "".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return ""


def _history_to_messages(history: list, message: str) -> list:
    """Convert a Gradio chat history plus the new user message to chat messages.

    Two history layouts are accepted:

    * pairs -- ``[[user_text, assistant_text], ...]``
    * role/content dicts -- ``[{"role": ..., "content": ...}, ...]``

    Entries without any text (for example the ``None`` half of a pair) are
    dropped instead of being forwarded to the chat template.
    """
    chat = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            text = _message_text(turn.get("content"))
            if role and text:
                chat.append({"role": role, "content": text})
        else:
            # Pair layout: first slot is the user, second the assistant.
            for role, content in zip(("user", "assistant"), turn):
                text = _message_text(content)
                if text:
                    chat.append({"role": role, "content": text})
    chat.append({"role": "user", "content": message})
    return chat


def gradio_chat(message: str, history: list) -> str:
    """Answer one turn of the interactive chat widget.

    Args:
        message: The text the user just submitted.
        history: Previous turns as supplied by ``gr.ChatInterface`` (either
            pair-style or role/content dicts).

    Returns:
        The assistant reply as plain text, with any ``<think>...</think>``
        block removed.
    """
    hf_messages = _history_to_messages(history, message)
    # Thinking mode is switched off, matching chat_completions, so the
    # 512-token budget is spent on the visible answer.
    prompt = tokenizer.apply_chat_template(
        hf_messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the tokens generated after the prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    reply = tokenizer.decode(new_ids, skip_special_tokens=True)
    # Strip any stray think tags, as chat_completions does for plain replies.
    return re.sub(r'<think>.*?</think>', '', reply, flags=re.DOTALL).strip()
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Generate a completion for an already-templated prompt.

    The prompt is tokenized and moved to the model's device, ``model.generate``
    is called with ``gen_kwargs`` under ``torch.no_grad()``, and only the
    tokens that follow the prompt are decoded (special tokens skipped).
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    # The returned sequence starts with the prompt tokens; slice them off.
    prompt_length = encoded["input_ids"].shape[1]
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
| # --------------------------------------------------------------------------- | |
| # API functions | |
| # --------------------------------------------------------------------------- | |
def list_models() -> str:
    """Return a JSON string describing the single model served here.

    The payload is an ``object: "list"`` wrapper whose ``data`` holds one
    model entry; ``created`` is the current Unix time at the moment of the call.
    """
    entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [entry]})
| # --------------------------------------------------------------------------- | |
| # Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>) | |
| # --------------------------------------------------------------------------- | |
def _parse_tool_calls(text: str):
    """Extract Hermes-style ``<tool_call>{...}</tool_call>`` blocks from model output.

    Each block whose payload is a JSON object becomes one OpenAI-format tool
    call: a fresh ``call_...`` id, ``type: "function"``, the object's ``name``
    and its ``arguments`` (falling back to ``parameters``) re-serialized as a
    JSON string. Payloads that are not valid JSON, or that decode to
    something other than an object, are skipped.

    Args:
        text: Raw decoded model output.

    Returns:
        ``(tool_calls, remaining_content)`` when at least one usable call was
        found; ``remaining_content`` is the text with all tool-call and
        ``<think>`` blocks removed, or ``None`` if nothing is left.
        ``(None, text)`` when no usable call was found.
    """
    pattern = r'<tool_call>(.*?)</tool_call>'
    tool_calls = []
    for payload in re.findall(pattern, text, re.DOTALL):
        try:
            call = json.loads(payload.strip())
        except json.JSONDecodeError:
            continue
        if not isinstance(call, dict):
            # Valid JSON but not an object (list, string, number): there is
            # no name/arguments to read, so treat it like a malformed block.
            continue
        arguments = call.get("arguments", call.get("parameters", {}))
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": json.dumps(arguments),
            },
        })
    if not tool_calls:
        return None, text
    # Strip tool call blocks and think tags from remaining content
    remaining = re.sub(pattern, '', text, flags=re.DOTALL)
    remaining = re.sub(r'<think>.*?</think>', '', remaining, flags=re.DOTALL).strip()
    return tool_calls, remaining or None
def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tools_json: str = "",
) -> str:
    """Non-streaming chat completions. Returns an OpenAI-compatible JSON string.

    Args:
        messages_json: JSON array of ``{role, content}`` objects. ``tool``
            messages keep their ``tool_call_id``; assistant messages keep
            their ``tool_calls``.
        max_tokens: Maximum number of new tokens; coerced to ``int`` because
            the UI number field may deliver a float.
        temperature: Sampling temperature, floored at 0.01 since sampling is
            always enabled.
        top_p: Nucleus sampling probability mass.
        tools_json: JSON array of OpenAI-format tool definitions (optional).

    Returns:
        A ``chat.completion`` JSON string, or ``{"error": "..."}`` when the
        inputs are invalid, the prompt cannot be built, or generation fails.
        Token usage is not measured and is reported as ``-1``.

    NOTE: Qwen3-14B supports thinking mode. enable_thinking is set to False
    here for reliable tool call formatting.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps(
            {"error": "Invalid messages_json: expected a JSON array of messages"}
        )

    tools = None
    if tools_json and tools_json.strip():
        # A client that sent tools expects them to be honoured; report bad
        # input instead of silently generating without tools.
        try:
            tools = json.loads(tools_json)
        except json.JSONDecodeError as e:
            return json.dumps({"error": f"Invalid tools_json: {e}"})
        if tools is not None and not isinstance(tools, list):
            return json.dumps(
                {"error": "Invalid tools_json: expected a JSON array of tools"}
            )

    try:
        max_new_tokens = int(max_tokens)
    except (TypeError, ValueError, OverflowError):
        return json.dumps({"error": f"Invalid max_tokens: {max_tokens!r}"})
    if max_new_tokens < 1:
        return json.dumps({"error": "Invalid max_tokens: must be at least 1"})

    try:
        hf_messages = []
        for m in messages:
            role = m["role"]
            # OpenAI clients may send content: null; the template needs text.
            content = m.get("content") or ""
            # tool results come in as role=tool; map to role=tool with tool_call_id
            if role == "tool":
                hf_messages.append({
                    "role": "tool",
                    "content": content,
                    "tool_call_id": m.get("tool_call_id", ""),
                })
            elif role == "assistant" and m.get("tool_calls"):
                hf_messages.append({
                    "role": "assistant",
                    "content": content,
                    "tool_calls": m["tool_calls"],
                })
            else:
                hf_messages.append({"role": role, "content": content})
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:
        # API boundary: malformed message objects and template errors are
        # reported to the caller rather than raised.
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=max(temperature, 0.01),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # Strip any stray think tags from plain responses
        content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"
    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a JSON health-check string naming the loaded checkpoint."""
    status = {"status": "ok", "model": MODEL_ID}
    return json.dumps(status)
| # --------------------------------------------------------------------------- | |
| # Gradio UI + API | |
| # --------------------------------------------------------------------------- | |
# The page shows a short description and an interactive chat. The API
# endpoints are exposed through hidden components whose click handlers carry
# an api_name.
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API
Endpoints (via Gradio built-in API):
| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |
Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.
You can also chat directly below.
""")
    gr.ChatInterface(fn=gradio_chat)

    # Hidden row: endpoints that take no input.
    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")
        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")

    # Hidden row: chat_completions inputs, in the order of the function's
    # parameters.
    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        _cc_messages = gr.Textbox(label="messages_json")
        # precision=0 makes the component hand an integer to the backend;
        # the value is used as a token count.
        _cc_max_tokens = gr.Number(label="max_tokens", value=512, precision=0)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
| # --------------------------------------------------------------------------- | |
| # Entry-point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue()
    demo.launch(**launch_options)