# qwen3.5api / app.py
# sidmaz666 — "Update app.py" (commit 8b3ded7, verified)
import asyncio
import base64
import gc
import json
import queue
import re
import threading
import time
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union

import requests
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field, field_validator
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# ----------------------------------------------------------------------
# App setup
# ----------------------------------------------------------------------
# The FastAPI application; interactive API docs live at /docs and /redoc.
app = FastAPI(
    title="Qwen3.5 Multimodal API",
    version="1.0.0",
    description="OpenAI‑compatible chat completions API with vision, tool calling, and JSON mode.",
    redoc_url="/redoc",
    docs_url="/docs",
)

# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# website issue credentialed requests; narrow allow_origins if auth is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# ----------------------------------------------------------------------
# Model management
# ----------------------------------------------------------------------
# Checkpoint ids this server is willing to load. ChatCompletionRequest rejects
# any other id; the default is loaded at startup.
DEFAULT_MODEL = "Qwen/Qwen3.5-2B"
ALLOWED_MODELS = [DEFAULT_MODEL, "Qwen/Qwen3.5-0.8B"]

# Process-wide state for the checkpoint currently in memory. All three stay
# None until load_model_global() has loaded something.
current_model_id: Optional[str] = None
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None
def load_model_global(model_id: str):
    """Make *model_id* the process-wide ``model`` / ``processor`` pair.

    Does nothing when *model_id* is already loaded. Otherwise the current pair
    is released first, then the checkpoint is loaded on CPU in float16 with
    ``trust_remote_code=True``. That executes code from the repository, so
    callers should only pass ids from ``ALLOWED_MODELS``.

    If loading raises, the globals are left as ``None`` and the exception
    propagates. Because ``current_model_id`` is cleared before the load, the
    next call retries instead of short-circuiting on a stale id.
    """
    global model, processor, current_model_id
    if current_model_id == model_id:
        return
    print(f"Loading model {model_id}...")
    # Mark nothing as loaded *before* the load that may fail, so a failure can
    # never leave current_model_id pointing at a released model.
    current_model_id = None
    if model is not None:
        # Rebind to None instead of `del`: `del` on a global unbinds the name,
        # so a failed load below would leave `model` undefined for other code.
        model = None
        processor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    new_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="cpu",
        trust_remote_code=True,
    )
    new_model.eval()
    # Publish the pair only once both loads have succeeded.
    processor, model = new_processor, new_model
    current_model_id = model_id
    print(f"Model {model_id} ready.")
# ----------------------------------------------------------------------
# Helper: image loading (URL or base64)
# ----------------------------------------------------------------------
def load_image_from_input(image_input: Union[str, Dict]) -> "Image.Image":
    """Load an RGB PIL image from a URL string or an ``{"url": ...}`` dict.

    ``data:image/<fmt>;base64,<payload>`` URLs are decoded in-process. Any other
    URL is downloaded with ``requests`` using a 10-second timeout.

    Raises:
        ValueError: if the URL is empty, or a data URL is malformed or not valid
            base64 (``binascii.Error`` is a ``ValueError`` subclass).
        requests.RequestException: if downloading a remote image fails.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    url = image_input if isinstance(image_input, str) else image_input.get("url", "")
    if not url:
        raise ValueError("Image URL is empty")
    if url.startswith("data:image"):
        # DOTALL: some clients wrap long base64 payloads across lines. Without it
        # `.+` stops at the first newline and the image is silently truncated.
        match = re.match(r"data:image/(?P<fmt>.+?);base64,(?P<data>.+)", url, re.DOTALL)
        if not match:
            raise ValueError("Invalid base64 image format")
        # b64decode with validate=False discards the embedded line breaks.
        img_data = base64.b64decode(match.group("data"))
    else:
        # NOTE(review): this fetches any client-supplied URL, including internal
        # addresses (SSRF), and reads the whole body into memory. Restrict hosts
        # and cap the size if the service is exposed publicly.
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        img_data = resp.content
    return Image.open(BytesIO(img_data)).convert("RGB")
# ----------------------------------------------------------------------
# OpenAI‑compatible schemas
# ----------------------------------------------------------------------
class FunctionDefinition(BaseModel):
    # A function the model may call: a required name, an optional description,
    # and an optional `parameters` dict (presumably a JSON Schema, following the
    # OpenAI tools format).
    name: str
    description: Union[str, None] = Field(default=None)
    parameters: Union[Dict[str, Any], None] = Field(default=None)
class Tool(BaseModel):
    # One element of the request's `tools` array. Only function tools are modelled.
    type: Literal["function"] = Field(default="function")
    function: FunctionDefinition
class ToolChoice(BaseModel):
    # Object form of `tool_choice`. `function` is a plain str -> str mapping
    # (presumably {"name": ...}, per the OpenAI format).
    type: Literal["function"] = Field(default="function")
    function: Dict[str, str]
class ResponseFormat(BaseModel):
    # `response_format` option. "json_object" makes the generator inject a
    # JSON-only system instruction; "text" leaves the prompt untouched.
    type: Literal["text", "json_object"] = Field(default="text")
class ImageURL(BaseModel):
    # Image reference: a remote URL or a data:image/...;base64 URL.
    # `detail` is accepted for OpenAI compatibility but is not read anywhere here.
    url: str
    detail: Union[Literal["auto", "low", "high"], None] = Field(default="auto")
class ContentPart(BaseModel):
    # One element of a multi-part message. `type` is free-form; the generator
    # only acts on "text" parts (using `text`) and "image_url" parts (using `image_url`).
    type: str
    text: Union[str, None] = Field(default=None)
    image_url: Union[ImageURL, None] = Field(default=None)
class ChatMessage(BaseModel):
    # A single chat turn. `content` is plain text, a list of parts, or absent;
    # `tool_calls` / `tool_call_id` carry OpenAI-style tool-use metadata.
    role: Literal["system", "user", "assistant", "tool"]
    content: Union[str, List[ContentPart], None] = Field(default=None)
    name: Union[str, None] = Field(default=None)
    tool_calls: Union[List[Dict[str, Any]], None] = Field(default=None)
    tool_call_id: Union[str, None] = Field(default=None)
class ChatCompletionRequest(BaseModel):
    # Body of POST /v1/chat/completions, a subset of the OpenAI schema.
    # frequency_penalty and presence_penalty are accepted but never passed to generation.
    model: str = Field(default=DEFAULT_MODEL)
    messages: List[ChatMessage]
    temperature: Union[float, None] = Field(default=0.7)
    max_tokens: Union[int, None] = Field(default=256)
    top_p: Union[float, None] = Field(default=1.0)
    frequency_penalty: Union[float, None] = Field(default=0.0)
    presence_penalty: Union[float, None] = Field(default=0.0)
    stream: bool = Field(default=False)
    tools: Union[List[Tool], None] = Field(default=None)
    tool_choice: Union[Literal["none", "auto"], ToolChoice, None] = Field(default="auto")
    response_format: Union[ResponseFormat, None] = Field(default=ResponseFormat(type="text"))

    @field_validator("model")
    @classmethod
    def validate_model(cls, value):
        # Only the checkpoints listed in ALLOWED_MODELS may be requested.
        if value in ALLOWED_MODELS:
            return value
        raise ValueError(f"Model must be one of {ALLOWED_MODELS}")
class TokenCountRequest(BaseModel):
    # Body of POST /v1/token_count.
    model: str = DEFAULT_MODEL
    messages: List[ChatMessage]
    tools: Optional[List[Tool]] = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        # Same rule as ChatCompletionRequest. The endpoint passes this id to
        # from_pretrained(..., trust_remote_code=True), so an unchecked id
        # would let a client load and execute code from any repository.
        if v not in ALLOWED_MODELS:
            raise ValueError(f"Model must be one of {ALLOWED_MODELS}")
        return v
class TokenCountResponse(BaseModel):
    # Result of POST /v1/token_count; total_tokens = messages_tokens + tools_tokens.
    total_tokens: int
    messages_tokens: int
    tools_tokens: int = Field(default=0)
# ----------------------------------------------------------------------
# Helper: token counting
# ----------------------------------------------------------------------
def count_tokens(text: str) -> int:
    """Return how many token ids the loaded tokenizer produces for *text*.

    Raises:
        HTTPException: 503 when no processor has been loaded yet.
    """
    active = processor
    if active is None:
        raise HTTPException(503, "Model not loaded")
    token_ids = active.tokenizer.encode(text)
    return len(token_ids)
def count_messages_tokens(messages: List[ChatMessage], tools: Optional[List[Tool]] = None) -> int:
    """Approximate the prompt size of *messages* (and *tools*) in tokens.

    Counts the string content or text parts of each message, plus its name,
    JSON-serialised tool calls, tool-call id and the JSON-serialised tool
    definitions. Image parts and chat-template tokens are not counted.
    """
    def texts_of(message):
        # Yield every countable string of one message, in a fixed order.
        body = message.content
        if isinstance(body, str):
            yield body
        elif isinstance(body, list):
            yield from (part.text for part in body if part.text)
        if message.name:
            yield message.name
        if message.tool_calls:
            yield json.dumps(message.tool_calls)
        if message.tool_call_id:
            yield message.tool_call_id

    running = sum(count_tokens(chunk) for message in messages for chunk in texts_of(message))
    if tools:
        running += count_tokens(json.dumps([tool.model_dump() for tool in tools]))
    return running
# ----------------------------------------------------------------------
# Helper: prepare tools for Qwen
# ----------------------------------------------------------------------
def prepare_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict]]:
    """Serialise request tools to plain dicts for ``apply_chat_template``.

    Returns ``None`` when *tools* is ``None`` or empty.
    """
    return [entry.model_dump() for entry in tools] if tools else None
# ----------------------------------------------------------------------
# Helper: enforce JSON mode
# ----------------------------------------------------------------------
def enforce_json_mode(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Return a copy of *messages* whose system prompt demands JSON-only output.

    The instruction is prepended to the first system message. If that message
    has empty content, the instruction becomes its content. If there is no
    system message, a new one is inserted at the front.

    Neither the input list nor its ChatMessage objects are modified; the edited
    message is replaced by a copy.
    """
    sys_instruction = "You must output only a valid JSON object. Do not include any other text, commentary, or markdown."
    new_messages = list(messages)
    for i, msg in enumerate(new_messages):
        if msg.role != "system":
            continue
        if not msg.content:
            content = sys_instruction
        elif isinstance(msg.content, str):
            content = sys_instruction + "\n\n" + msg.content
        else:
            # Keep the list homogeneous (ContentPart only): the conversation
            # builder reads `part.type` / `part.text`, which a raw dict lacks.
            content = [ContentPart(type="text", text=sys_instruction), *msg.content]
        new_messages[i] = msg.model_copy(update={"content": content})
        return new_messages
    new_messages.insert(0, ChatMessage(role="system", content=sys_instruction))
    return new_messages
# ----------------------------------------------------------------------
# Main generation logic
# ----------------------------------------------------------------------
# Qwen chat templates wrap each tool call as <tool_call>{json}</tool_call>.
# NOTE(review): confirm this is what the Qwen3.5 checkpoints emit; a bare JSON
# object covering the whole reply is still accepted as a fallback.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
# Default returned by next() once the token streamer is exhausted.
_STREAM_END = object()


class _StopOnEvent(StoppingCriteria):
    """Generation stopping criterion that fires once *event* is set."""

    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One "done" flag per batch row.
        return torch.full(
            (input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device
        )


def _new_id(prefix: str) -> str:
    """Return a random OpenAI-style identifier such as ``chatcmpl-…`` or ``call_…``."""
    return prefix + uuid.uuid4().hex[:24]


def _template_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Copy OpenAI ``tool_calls``, decoding JSON-string ``arguments`` into dicts.

    Chat templates serialise ``arguments`` themselves, so passing the OpenAI
    JSON string can get it encoded twice. Strings that are not JSON objects are
    left untouched.
    """
    converted = []
    for call in tool_calls:
        call = dict(call)
        function = call.get("function")
        if isinstance(function, dict) and isinstance(function.get("arguments"), str):
            try:
                decoded = json.loads(function["arguments"])
            except json.JSONDecodeError:
                decoded = None
            if isinstance(decoded, dict):
                call["function"] = {**function, "arguments": decoded}
        converted.append(call)
    return converted


def _build_conversation(messages: List[ChatMessage]):
    """Convert request messages into chat-template dicts and collect their images.

    Returns ``(conversation, images)``. Each image part becomes an
    ``{"type": "image"}`` placeholder, and its decoded image is appended to
    *images* in the same order. Messages with neither content nor tool calls
    are skipped.
    """
    conversation: List[Dict[str, Any]] = []
    images = []
    for msg in messages:
        if msg.role == "tool":
            content = msg.content
            if isinstance(content, list):
                # Tool results are rendered as text by the template.
                content = "".join(part.text or "" for part in content)
            conversation.append({"role": "tool", "content": content, "tool_call_id": msg.tool_call_id})
            continue
        entry: Dict[str, Any] = {"role": msg.role}
        if isinstance(msg.content, str):
            entry["content"] = msg.content
        elif isinstance(msg.content, list):
            parts = []
            for part in msg.content:
                if isinstance(part, dict):  # tolerate raw dict parts mixed into the list
                    part = ContentPart.model_validate(part)
                if part.type == "text" and part.text:
                    parts.append({"type": "text", "text": part.text})
                elif part.type == "image_url" and part.image_url:
                    images.append(load_image_from_input(part.image_url.url))
                    parts.append({"type": "image"})
            entry["content"] = parts
        elif msg.tool_calls:
            entry["content"] = ""  # assistant turn that only calls tools
        else:
            continue
        if msg.tool_calls:
            entry["tool_calls"] = _template_tool_calls(msg.tool_calls)
        conversation.append(entry)
    return conversation, images


def _prepare_inputs(proc, device, messages: List[ChatMessage], tools: Optional[List[Dict]]):
    """Render *messages* with the chat template and tokenize text and images onto *device*.

    If templating with *tools* fails, it is retried once without them.
    Blocking: images may be downloaded here.
    """
    conversation, images = _build_conversation(messages)
    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        prompt = proc.apply_chat_template(conversation, tools=tools, **template_kwargs)
    except Exception as e:
        if not tools:
            raise
        print(f"Warning: apply_chat_template with tools failed: {e}. Falling back without tools.")
        prompt = proc.apply_chat_template(conversation, **template_kwargs)
    inputs = proc(
        text=[prompt],
        images=images or None,
        return_tensors="pt",
        padding=True,
    ).to(device)
    # Drop this processor output before generate(); presumably generate() rejects it.
    inputs.pop("mm_token_type_ids", None)
    return inputs


def _generation_kwargs(request: ChatCompletionRequest, pad_token_id) -> Dict[str, Any]:
    """Map OpenAI sampling options onto ``model.generate`` keyword arguments.

    Sampling is enabled only for a positive temperature. A ``None`` or
    non-positive temperature means greedy decoding, and temperature/top_p are
    then omitted because they do not affect it. ``None`` values are dropped so
    the model's generation config supplies them.
    """
    temperature = request.temperature
    do_sample = temperature is not None and temperature > 0
    kwargs = {
        "max_new_tokens": request.max_tokens,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
    }
    if do_sample:
        kwargs.update(temperature=temperature, top_p=request.top_p)
    return {k: v for k, v in kwargs.items() if v is not None}


def _parse_tool_calls(text: str):
    """Split generated *text* into ``(content, tool_calls)``.

    Each ``<tool_call>`` block, or the whole text if there are none, that
    decodes to a JSON object with ``name`` and ``arguments`` becomes an
    OpenAI-style tool call. Returns ``(text, None)`` when nothing qualifies.
    Otherwise content is the text left outside the tags, or ``None`` if empty.
    """
    blocks = _TOOL_CALL_RE.findall(text)
    calls = []
    for raw in blocks or [text]:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not (isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed):
            continue
        arguments = parsed["arguments"]
        calls.append({
            "id": _new_id("call_"),
            "type": "function",
            "function": {
                "name": parsed["name"],
                # OpenAI clients expect `arguments` as a JSON string.
                "arguments": arguments if isinstance(arguments, str) else json.dumps(arguments),
            },
        })
    if not calls:
        return text, None
    leftover = _TOOL_CALL_RE.sub("", text).strip() if blocks else ""
    return leftover or None, calls


async def _stream_completion(mdl, proc, model_id, inputs, gen_kwargs, http_request, completion_id, created):
    """Yield SSE lines for a streamed completion, generating in a worker thread.

    Generation is stopped early when the client disconnects or this generator
    is closed. A generation error is reported as an ``error`` event. Tool
    calls are not parsed in streaming mode.
    """
    loop = asyncio.get_running_loop()
    stop_event = threading.Event()
    # skip_prompt: emit only new tokens, never the prompt.
    # timeout: next() raises queue.Empty periodically so disconnects get noticed.
    streamer = TextIteratorStreamer(proc.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=0.5)
    errors = []

    def _generate():
        try:
            mdl.generate(
                **inputs,
                **gen_kwargs,
                streamer=streamer,
                stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
            )
        except Exception as exc:
            errors.append(exc)
            # generate() did not reach its own streamer.end(); unblock the consumer.
            streamer.end()

    def _chunk(delta, finish_reason):
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(payload)}\n\n"

    threading.Thread(target=_generate, daemon=True).start()
    try:
        while True:
            if await http_request.is_disconnected():
                print("Client disconnected, stopping stream.")
                return
            try:
                token = await loop.run_in_executor(None, next, streamer, _STREAM_END)
            except queue.Empty:
                continue
            if token is _STREAM_END:
                break
            if token:
                yield _chunk({"content": token}, None)
        if errors:
            error = {"error": {"message": str(errors[0]), "type": "server_error"}}
            yield f"data: {json.dumps(error)}\n\n"
        else:
            yield _chunk({}, "stop")
        yield "data: [DONE]\n\n"
    finally:
        # Finished, failed, or client gone: make generate() stop at its next step.
        stop_event.set()


async def generate_chat_response(request: ChatCompletionRequest, http_request: Request):
    """Run one chat completion and return an OpenAI-shaped result.

    Non-streaming requests return a ``chat.completion`` dict. It contains
    ``tool_calls`` only when tools were offered and the model emitted one.
    Streaming requests return an SSE ``StreamingResponse`` of
    ``chat.completion.chunk`` events ending with ``data: [DONE]``.

    ``tool_choice="none"`` hides the tools from the model. A forced
    ``ToolChoice``, ``frequency_penalty`` and ``presence_penalty`` are not
    applied.

    Image downloads, tokenization and generation run in the default executor,
    so the event loop stays responsive. Model switching via
    ``load_model_global`` still runs on the event loop.
    """
    load_model_global(request.model)
    # Snapshot the globals: another request may switch models while this one runs.
    mdl, proc, model_id = model, processor, current_model_id
    loop = asyncio.get_running_loop()

    messages = request.messages
    if request.response_format and request.response_format.type == "json_object":
        messages = enforce_json_mode(messages)
    tools = None if request.tool_choice == "none" else prepare_tools(request.tools)

    inputs = await loop.run_in_executor(None, _prepare_inputs, proc, mdl.device, messages, tools)
    gen_kwargs = _generation_kwargs(request, proc.tokenizer.eos_token_id)
    completion_id = _new_id("chatcmpl-")
    created = int(time.time())

    if request.stream:
        return StreamingResponse(
            _stream_completion(mdl, proc, model_id, inputs, gen_kwargs, http_request, completion_id, created),
            media_type="text/event-stream",
        )

    def _generate_blocking():
        # no_grad is thread-local, so it must be entered inside the worker thread.
        with torch.no_grad():
            return mdl.generate(**inputs, **gen_kwargs)

    generated_ids = await loop.run_in_executor(None, _generate_blocking)
    prompt_tokens = inputs["input_ids"].shape[1]
    output_ids = generated_ids[0][prompt_tokens:]
    content = proc.decode(output_ids, skip_special_tokens=True)
    tool_calls = None
    if tools:  # only look for tool calls when the model was offered tools
        content, tool_calls = _parse_tool_calls(content)
    completion_tokens = len(output_ids)
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": content,
                "tool_calls": tool_calls,
            },
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
# ----------------------------------------------------------------------
# Endpoints
# ----------------------------------------------------------------------
@app.get("/v1/models", tags=["Models"])
async def list_models():
    # OpenAI-style model listing: one entry per id in ALLOWED_MODELS.
    entries = []
    for model_name in ALLOWED_MODELS:
        entries.append({"id": model_name, "object": "model", "created": 0, "owned_by": "qwen"})
    return {"object": "list", "data": entries}
@app.post("/v1/chat/completions", tags=["Chat"])
async def chat_completions(request: ChatCompletionRequest, http_request: Request):
    """OpenAI-compatible chat completion, streaming or non-streaming.

    Errors are mapped to HTTP statuses:
    - an HTTPException raised downstream keeps its own status;
    - bad client input (ValueError, such as a malformed image data URL, or a
      failed image download) becomes 400;
    - anything else becomes 500.
    """
    try:
        return await generate_chat_response(request, http_request)
    except HTTPException:
        raise
    except (ValueError, requests.RequestException) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/v1/token_count", tags=["Utils"])
async def token_count(request: TokenCountRequest):
    """Count tokens in the given messages (and tools) with the requested model's tokenizer.

    Only the text of messages, names, tool calls, tool-call ids and the tool
    definitions is counted. Images and chat-template overhead are not.
    """
    # Checked here as well as in the schema: load_model_global uses
    # trust_remote_code=True, so an arbitrary id would execute remote code.
    if request.model not in ALLOWED_MODELS:
        raise HTTPException(status_code=422, detail=f"Model must be one of {ALLOWED_MODELS}")
    load_model_global(request.model)
    # The count is additive, so tokenize each part once and sum them.
    messages_tokens = count_messages_tokens(request.messages)
    tools_tokens = count_messages_tokens([], request.tools)
    return TokenCountResponse(
        total_tokens=messages_tokens + tools_tokens,
        messages_tokens=messages_tokens,
        tools_tokens=tools_tokens,
    )
@app.get("/", tags=["Health"])
async def health():
    # Liveness probe: reports the loaded checkpoint (None until one has loaded)
    # and the ids that may be requested.
    payload = dict(
        status="ok",
        current_model=current_model_id,
        available_models=ALLOWED_MODELS,
    )
    return payload
# ----------------------------------------------------------------------
# Startup
# ----------------------------------------------------------------------
@app.on_event("startup")
def startup():
    # Load the default checkpoint eagerly so requests for it skip the load.
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of lifespan handlers.
    default_id = DEFAULT_MODEL
    load_model_global(default_id)
if __name__ == "__main__":
    # Direct `python app.py` launch: serve on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")