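# NOTE: The lines below are Hugging Face Hub page chrome that was captured
# together with the file; they are not part of the program.
#   fomext's picture
#   Upload app.py
#   62d5c00 verified
#   Raw
#   History Blame Contribute Delete
#   9.85 kB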
import json
import re
import time
import uuid
from typing import Optional
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face repository the tokenizer and the weights are loaded from.
MODEL_ID = "Qwen/Qwen3-14B"
# Name reported to clients: used as the model id in list_models(), as the
# "model" field of chat_completions() responses, and in the UI title.
MODEL_ALIAS = "qwen3-14b-4bit"
# Tokenizer and model are loaded once at import time and shared by every
# request handler defined below.
print(f"Loading tokenizer for {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loading model {MODEL_ID} in 4-bit …")
# bitsandbytes quantization settings: 4-bit NF4 weights with double
# quantization enabled and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    # Device placement is delegated to the loader rather than pinned here.
    device_map="auto",
)
# Inference only: switch the module to evaluation mode.
model.eval()
print("Model ready.")
# ---------------------------------------------------------------------------
# GPU generation functions — ZeroGPU anchors
# ---------------------------------------------------------------------------
def _history_to_messages(history: list, message: str) -> list:
    """Convert Gradio chat history plus the new user message to chat messages.

    Both history layouts that a ``gr.ChatInterface`` callback can receive are
    accepted:

    * pair format: ``[[user_text, assistant_text], ...]`` where either
      side may be ``None``;
    * messages format: ``[{"role": ..., "content": ...}, ...]``.

    Only plain-string contents are forwarded; ``None`` and non-text entries
    (for example file attachments) are skipped because the model is text-only.

    Returns:
        A list of ``{"role", "content"}`` dicts ending with the new user turn.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)):
            # Pair format: position 0 is the user, position 1 the assistant.
            # Roles come from the position inside the pair, not from a running
            # index, so a missing side cannot shift the roles of later turns.
            for role, content in zip(("user", "assistant"), turn):
                if isinstance(content, str):
                    messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU
def gradio_chat(message: str, history: list) -> str:
    """Answer one chat-UI turn and return the assistant's reply text.

    Args:
        message: The new user message.
        history: Previous turns as supplied by ``gr.ChatInterface`` (pair or
            messages format, see ``_history_to_messages``).

    Returns:
        The decoded completion (at most 512 new tokens, sampled with
        temperature 0.7 and top_p 0.9) with any ``<think>`` block removed.
    """
    hf_messages = _history_to_messages(history, message)
    prompt = tokenizer.apply_chat_template(
        hf_messages,
        tokenize=False,
        add_generation_prompt=True,
        # Qwen3-14B supports thinking mode; it is disabled here, as in
        # chat_completions(), so the 512-token budget is spent on the answer
        # rather than on reasoning text.
        enable_thinking=False,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + completion; keep only the newly generated ids.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    reply = tokenizer.decode(new_ids, skip_special_tokens=True)
    # Strip any stray think tags, mirroring the plain-response path of
    # chat_completions().
    return re.sub(r'<think>.*?</think>', '', reply, flags=re.DOTALL).strip()
@spaces.GPU
def _generate_response(prompt: str, gen_kwargs: dict) -> str:
    """Run generation for an already-templated prompt and return the new text.

    Args:
        prompt: Prompt string, tokenized here and moved to the model's device.
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        The decoded completion only (prompt tokens are cut off), with special
        tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)
    # The first returned sequence holds prompt + completion; drop the prompt.
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API functions
# ---------------------------------------------------------------------------
def list_models() -> str:
    """Return a JSON string describing the single model served by this app.

    The payload is an OpenAI-style list object whose only entry uses
    ``MODEL_ALIAS`` as its id and the current Unix time as ``created``.
    """
    model_entry = {
        "id": MODEL_ALIAS,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "qwen",
    }
    return json.dumps({"object": "list", "data": [model_entry]})
# ---------------------------------------------------------------------------
# Tool call parsing (Hermes-style: <tool_call>{...}</tool_call>)
# ---------------------------------------------------------------------------
def _parse_tool_calls(text: str):
"""
Detect and extract Hermes-style tool calls from model output.
Returns (tool_calls, remaining_content) where tool_calls is a list in
OpenAI format, or (None, text) if no tool calls are found.
"""
pattern = r'<tool_call>(.*?)</tool_call>'
matches = re.findall(pattern, text, re.DOTALL)
if not matches:
return None, text
tool_calls = []
for match in matches:
try:
call = json.loads(match.strip())
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
"name": call.get("name", ""),
"arguments": json.dumps(call.get("arguments", call.get("parameters", {}))),
},
})
except json.JSONDecodeError:
continue
if not tool_calls:
return None, text
# Strip tool call blocks and think tags from remaining content
remaining = re.sub(pattern, '', text, flags=re.DOTALL)
remaining = re.sub(r'<think>.*?</think>', '', remaining, flags=re.DOTALL).strip()
return tool_calls, remaining or None
def _content_to_text(content) -> str:
    """Flatten an OpenAI-style message ``content`` value to plain text."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Content given as a list of parts: keep only the textual parts.
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def _to_hf_messages(messages: list) -> list:
    """Map OpenAI-format messages to the dicts passed to the chat template."""
    hf_messages = []
    for m in messages:
        role = m["role"]
        entry = {"role": role, "content": _content_to_text(m.get("content"))}
        if role == "tool":
            # Tool results keep role=tool and carry their tool_call_id along.
            entry["tool_call_id"] = m.get("tool_call_id", "")
        elif role == "assistant" and m.get("tool_calls"):
            entry["tool_calls"] = m["tool_calls"]
        hf_messages.append(entry)
    return hf_messages


def chat_completions(
    messages_json: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tools_json: str = "",
) -> str:
    """Non-streaming chat completions; returns an OpenAI-compatible JSON string.

    Args:
        messages_json: JSON array of ``{role, content}`` objects. ``content``
            may be a string, ``null``, or a list of text parts.
        max_tokens: Maximum number of new tokens; coerced to ``int`` because
            the value may arrive from the UI/API as a float.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01 since sampling is always enabled.
        top_p: Nucleus sampling threshold.
        tools_json: Optional JSON array of OpenAI-format tool definitions.

    Returns:
        A JSON string: either a ``chat.completion`` object or ``{"error": ...}``
        when an input is invalid or prompt building / generation fails. This
        function never raises for bad input; errors are always reported in the
        returned JSON. Token usage is not tracked and is reported as -1.

    NOTE: Qwen3-14B supports thinking mode. enable_thinking is set to False
    here for reliable tool call formatting.
    """
    try:
        messages = json.loads(messages_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid messages_json: {e}"})
    if not isinstance(messages, list):
        return json.dumps(
            {"error": "Invalid messages_json: expected a JSON array of messages"}
        )

    tools = None
    if tools_json:
        try:
            tools = json.loads(tools_json)
        except (json.JSONDecodeError, TypeError) as e:
            # Report the problem instead of silently answering without tools.
            return json.dumps({"error": f"Invalid tools_json: {e}"})
        if tools is not None and not isinstance(tools, list):
            return json.dumps(
                {"error": "Invalid tools_json: expected a JSON array of tools"}
            )

    try:
        max_new_tokens = int(max_tokens)
        temperature = float(temperature)
        top_p = float(top_p)
    except (TypeError, ValueError) as e:
        return json.dumps({"error": f"Invalid sampling parameter: {e}"})
    if max_new_tokens < 1:
        return json.dumps({"error": "Invalid sampling parameter: max_tokens must be >= 1"})

    # API boundary: any failure while building the prompt is reported as JSON.
    try:
        hf_messages = _to_hf_messages(messages)
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        if tools:
            template_kwargs["tools"] = tools
        prompt = tokenizer.apply_chat_template(hf_messages, **template_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Prompt build failed: {e}"})

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=max(temperature, 0.01),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    try:
        raw = _generate_response(prompt, gen_kwargs)
    except Exception as e:
        return json.dumps({"error": f"Generation failed: {e}"})

    cid = f"chatcmpl-{uuid.uuid4().hex}"
    tool_calls, content = _parse_tool_calls(raw)
    if tool_calls:
        message = {"role": "assistant", "content": content, "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # Strip any stray think tags from plain responses
        content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"

    result = {
        "id": cid,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ALIAS,
        "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}],
        "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
    }
    return json.dumps(result)
def health() -> str:
    """Return a JSON health-check string with a fixed "ok" status and MODEL_ID."""
    payload = {"status": "ok", "model": MODEL_ID}
    return json.dumps(payload)
# ---------------------------------------------------------------------------
# Gradio UI + API
# ---------------------------------------------------------------------------
# Build the app: a visible chat UI plus hidden components whose click handlers
# expose health(), list_models() and chat_completions() under fixed api_name
# values. The components are only there to give each API function its inputs
# and outputs; their rows are created with visible=False.
with gr.Blocks(title=f"{MODEL_ALIAS} API") as demo:
    # NOTE: the text below is a runtime f-string; doubled braces render as
    # literal braces in the displayed Markdown.
    gr.Markdown(f"""
# {MODEL_ALIAS} — Gradio API
Endpoints (via Gradio built-in API):
| api_name | Description |
|----------|-------------|
| `list_models` | List available models → JSON string |
| `chat_completions` | Chat completions → JSON string |
| `health` | Health check → JSON string |
Call them at `/gradio_api/call/<api_name>` (POST with `{{"data": [...]}}`)
or use the Gradio Python client.
You can also chat directly below.
""")
    # Interactive chat backed by gradio_chat(message, history).
    gr.ChatInterface(fn=gradio_chat)
    with gr.Row(visible=False):
        # -- health ------------------------------------------------------
        _health_btn = gr.Button("health")
        _health_out = gr.Textbox()
        _health_btn.click(fn=health, inputs=[], outputs=[_health_out], api_name="health")
        # -- list_models -------------------------------------------------
        _models_btn = gr.Button("list_models")
        _models_out = gr.Textbox()
        _models_btn.click(fn=list_models, inputs=[], outputs=[_models_out], api_name="list_models")
    with gr.Row(visible=False):
        # -- chat_completions --------------------------------------------
        # The order of the inputs list below must match the positional
        # parameters of chat_completions(): messages_json, max_tokens,
        # temperature, top_p, tools_json.
        _cc_messages = gr.Textbox(label="messages_json")
        _cc_max_tokens = gr.Number(label="max_tokens", value=512)
        _cc_temp = gr.Number(label="temperature", value=0.7)
        _cc_top_p = gr.Number(label="top_p", value=0.9)
        _cc_tools = gr.Textbox(label="tools_json", value="")
        _cc_out = gr.Textbox(label="result")
        _cc_btn = gr.Button("chat_completions")
        _cc_btn.click(
            fn=chat_completions,
            inputs=[_cc_messages, _cc_max_tokens, _cc_temp, _cc_top_p, _cc_tools],
            outputs=[_cc_out],
            api_name="chat_completions",
        )
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Enable request queuing, then serve on all interfaces at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue()
    demo.launch(**launch_options)