import json
import threading
import gc
import re
import base64
import asyncio
import queue
from io import BytesIO
from typing import List, Optional, Union, Dict, Any, Literal, AsyncGenerator

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import torch
import requests
from PIL import Image

# ----------------------------------------------------------------------
# App setup
# ----------------------------------------------------------------------
app = FastAPI(
    title="Qwen3.5 Multimodal API",
    description="OpenAI‑compatible chat completions API with vision, tool calling, and JSON mode.",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; restrict allow_origins before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ----------------------------------------------------------------------
# Model management
# ----------------------------------------------------------------------
ALLOWED_MODELS = ["Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-0.8B"]
DEFAULT_MODEL = "Qwen/Qwen3.5-2B"

# The single model/processor pair currently held in memory.
current_model_id: Optional[str] = None
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None


def load_model_global(model_id: str) -> None:
    """Make ``model_id`` the active model, replacing any previously loaded one.

    Does nothing if ``model_id`` is already loaded. Otherwise the previous model
    and processor are released first, then the new pair is loaded onto the CPU.

    The three globals are updated together, and only after both loads succeed.
    If loading raises, they are left as ``None``, so the next call retries
    rather than short-circuiting on a stale ``current_model_id`` with no model
    in memory.
    """
    global model, processor, current_model_id
    if current_model_id == model_id:
        return

    print(f"Loading model {model_id}...")
    if model is not None or processor is not None:
        # Drop references (rather than `del` the globals) so the names stay
        # defined even if the load below fails.
        model = None
        processor = None
        current_model_id = None
        gc.collect()
        torch.cuda.empty_cache()

    new_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    # NOTE(review): float16 weights on a CPU device map can be slow or hit
    # unsupported kernels on some builds — confirm this dtype is intended.
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="cpu",
        trust_remote_code=True,
    )
    new_model.eval()

    processor = new_processor
    model = new_model
    current_model_id = model_id
    print(f"Model {model_id} ready.")
# ----------------------------------------------------------------------
# Helper: image loading (URL or base64)
# ----------------------------------------------------------------------
# Splits a data URL into its image subtype and base64 payload. DOTALL lets the
# payload contain line breaks, which some base64 encoders insert.
_DATA_URL_RE = re.compile(r"data:image/(?P<fmt>.+?);base64,(?P<data>.+)", re.DOTALL)


def load_image_from_input(image_input: Union[str, Dict]) -> Image.Image:
    """Load an RGB PIL image from a ``data:image/...;base64,`` URL or a remote URL.

    ``image_input`` is either the URL string itself or a dict with a ``"url"`` key.

    Raises:
        ValueError: the data URL is malformed, the base64 payload is invalid, or
            the bytes cannot be decoded as an image.
        requests.RequestException: the remote image could not be fetched.
    """
    url = image_input if isinstance(image_input, str) else image_input.get("url", "")
    if url.startswith("data:image"):
        match = _DATA_URL_RE.match(url)
        if not match:
            raise ValueError("Invalid base64 image format")
        # binascii.Error (bad base64) is a ValueError subclass.
        img_data = base64.b64decode(match.group("data"))
    else:
        # NOTE(security): fetches an arbitrary client-supplied URL (SSRF risk);
        # consider restricting schemes/hosts before exposing this publicly.
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        img_data = resp.content
    try:
        return Image.open(BytesIO(img_data)).convert("RGB")
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        raise ValueError("Could not decode image data") from exc


# ----------------------------------------------------------------------
# OpenAI‑compatible schemas
# ----------------------------------------------------------------------
def _check_allowed_model(model_id: str) -> str:
    """Return ``model_id`` unchanged, or raise ValueError if it is not allowed.

    Every request model that reaches ``load_model_global`` must be validated:
    that function loads with ``trust_remote_code=True``, so an unchecked id
    would download and execute code from an arbitrary repository.
    """
    if model_id not in ALLOWED_MODELS:
        raise ValueError(f"Model must be one of {ALLOWED_MODELS}")
    return model_id


class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class Tool(BaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ToolChoice(BaseModel):
    type: Literal["function"] = "function"
    function: Dict[str, str]


class ResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"


class ImageURL(BaseModel):
    url: str
    detail: Optional[Literal["auto", "low", "high"]] = "auto"


class ContentPart(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: Optional[Union[str, List[ContentPart]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = DEFAULT_MODEL
    messages: List[ChatMessage]
    # None or <= 0 selects greedy decoding.
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256
    top_p: Optional[float] = 1.0
    # Accepted for OpenAI compatibility; not applied to generation.
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: bool = False
    tools: Optional[List[Tool]] = None
    # "none" disables tool-call parsing; a specific ToolChoice is not enforced.
    tool_choice: Optional[Union[Literal["none", "auto"], ToolChoice]] = "auto"
    response_format: Optional[ResponseFormat] = Field(default_factory=ResponseFormat)

    @field_validator("model")
    @classmethod
    def validate_model(cls, v: str) -> str:
        return _check_allowed_model(v)


class TokenCountRequest(BaseModel):
    model: str = DEFAULT_MODEL
    messages: List[ChatMessage]
    tools: Optional[List[Tool]] = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v: str) -> str:
        return _check_allowed_model(v)


class TokenCountResponse(BaseModel):
    total_tokens: int
    messages_tokens: int
    tools_tokens: int = 0


# ----------------------------------------------------------------------
# Helper: token counting
# ----------------------------------------------------------------------
def count_tokens(text: str) -> int:
    """Return the number of tokens the loaded tokenizer produces for ``text``.

    Raises:
        HTTPException: 503 if no model/processor is loaded.
    """
    if processor is None:
        raise HTTPException(503, "Model not loaded")
    return len(processor.tokenizer.encode(text))


def count_messages_tokens(messages: List[ChatMessage], tools: Optional[List[Tool]] = None) -> int:
    """Sum token counts over message text, names, tool calls, and tool definitions.

    Image parts are not counted, and chat-template overhead is not included.
    """
    total = 0
    for msg in messages:
        content = msg.content
        if isinstance(content, str):
            total += count_tokens(content)
        elif isinstance(content, list):
            for part in content:
                if part.text:
                    total += count_tokens(part.text)
        if msg.name:
            total += count_tokens(msg.name)
        if msg.tool_calls:
            total += count_tokens(json.dumps(msg.tool_calls))
        if msg.tool_call_id:
            total += count_tokens(msg.tool_call_id)
    if tools:
        total += count_tokens(json.dumps([t.model_dump() for t in tools]))
    return total


# ----------------------------------------------------------------------
# Helper: prepare tools for Qwen
# ----------------------------------------------------------------------
def prepare_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict]]:
    """Convert Tool models to plain dicts for the chat template, or None if absent."""
    if not tools:
        return None
    return [tool.model_dump() for tool in tools]


# ----------------------------------------------------------------------
# Helper: enforce JSON mode
# ----------------------------------------------------------------------
def enforce_json_mode(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Return a new message list whose system prompt demands a bare JSON object.

    The instruction is prepended to the first system message, or inserted as a
    new leading system message if there is none. The input list and its
    message objects are not modified.
    """
    sys_instruction = "You must output only a valid JSON object. Do not include any other text, commentary, or markdown."
    new_messages = list(messages)
    for i, msg in enumerate(new_messages):
        if msg.role != "system":
            continue
        if not msg.content:
            content: Union[str, List[ContentPart]] = sys_instruction
        elif isinstance(msg.content, str):
            content = sys_instruction + "\n\n" + msg.content
        else:
            # Keep the list homogeneous: downstream code reads part.type/part.text.
            content = [ContentPart(type="text", text=sys_instruction)] + list(msg.content)
        new_messages[i] = msg.model_copy(update={"content": content})
        return new_messages
    new_messages.insert(0, ChatMessage(role="system", content=sys_instruction))
    return new_messages


# ----------------------------------------------------------------------
# Main generation logic
# ----------------------------------------------------------------------
# Qwen models wrap each tool call in <tool_call>...</tool_call>.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def _content_to_text(content: Optional[Union[str, List[ContentPart]]]) -> str:
    """Flatten message content to plain text; image parts are dropped, None -> ""."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    return "".join(part.text for part in content if part.text)


def _normalize_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Decode JSON-string ``function.arguments`` into dicts for the chat template.

    OpenAI clients send arguments as a JSON string. The Qwen tool-calling
    examples pass them as a dict, which avoids a double-encoded string in the
    prompt. Calls that do not parse to a dict are passed through unchanged.
    """
    normalized = []
    for call in tool_calls:
        fn = call.get("function")
        if isinstance(fn, dict) and isinstance(fn.get("arguments"), str):
            try:
                args = json.loads(fn["arguments"])
            except json.JSONDecodeError:
                args = None
            if isinstance(args, dict):
                call = {**call, "function": {**fn, "arguments": args}}
        normalized.append(call)
    return normalized


def _build_conversation(messages: List[ChatMessage]) -> tuple:
    """Convert request messages into chat-template dicts plus the images they reference.

    Returns ``(conversation, images)``. Each ``{"type": "image"}`` placeholder in
    ``conversation`` corresponds, in order, to one entry of ``images``.
    """
    conversation: List[Dict[str, Any]] = []
    images: List[Image.Image] = []
    for msg in messages:
        if msg.role == "tool":
            conversation.append({
                "role": "tool",
                "content": _content_to_text(msg.content),
                "tool_call_id": msg.tool_call_id,
            })
            continue

        if isinstance(msg.content, list):
            parts: List[Dict[str, str]] = []
            for part in msg.content:
                if part.type == "text" and part.text:
                    parts.append({"type": "text", "text": part.text})
                elif part.type == "image_url" and part.image_url:
                    images.append(load_image_from_input(part.image_url.url))
                    parts.append({"type": "image"})
            entry: Dict[str, Any] = {"role": msg.role, "content": parts}
        else:
            # None becomes "" so turns such as an assistant message that carries
            # only tool_calls stay in the conversation.
            entry = {"role": msg.role, "content": msg.content or ""}

        if msg.tool_calls:
            entry["tool_calls"] = _normalize_tool_calls(msg.tool_calls)
        conversation.append(entry)
    return conversation, images


def _render_prompt(proc, conversation: List[Dict[str, Any]], tools: Optional[List[Dict]]) -> str:
    """Render the chat template, retrying without tools if the template rejects them."""
    try:
        return proc.apply_chat_template(
            conversation, tools=tools, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:  # deliberate best-effort fallback
        print(f"Warning: apply_chat_template with tools failed: {e}. Falling back without tools.")
        return proc.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )


def _build_generation_kwargs(request: ChatCompletionRequest, pad_token_id: Optional[int]) -> Dict[str, Any]:
    """Map request sampling options to ``generate()`` kwargs.

    A temperature of None or <= 0 selects greedy decoding. In that case
    ``temperature`` and ``top_p`` are omitted, since they only apply when
    sampling. None-valued entries are dropped.
    """
    temperature = request.temperature
    do_sample = temperature is not None and temperature > 0
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": request.max_tokens,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = request.top_p
    return {k: v for k, v in gen_kwargs.items() if v is not None}


def _parse_tool_calls(text: str) -> tuple:
    """Extract OpenAI-style tool calls from generated text.

    Looks for ``<tool_call>{...}</tool_call>`` blocks first. If there are none,
    it tries the whole text as a single ``{"name": ..., "arguments": ...}``
    object.

    Returns ``(tool_calls, content)``:
    - When no calls parse, ``tool_calls`` is None and ``content`` is ``text``
      unchanged.
    - Otherwise ``content`` is the text outside the tool-call blocks, or None
      if that text is empty.
    """
    import uuid

    tagged = _TOOL_CALL_RE.findall(text)
    candidates = tagged or [text.strip()]
    calls = []
    for raw in candidates:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not (isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed):
            continue
        args = parsed["arguments"]
        calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": parsed["name"],
                # OpenAI clients expect arguments as a JSON-encoded string.
                "arguments": args if isinstance(args, str) else json.dumps(args),
            },
        })
    if not calls:
        return None, text
    remaining = _TOOL_CALL_RE.sub("", text).strip() if tagged else ""
    return calls, remaining or None


def _stream_response(mdl, proc, inputs, gen_kwargs: Dict[str, Any], model_id: Optional[str],
                     completion_id: str, created: int, http_request: Request) -> StreamingResponse:
    """Stream generated text to the client as OpenAI ``chat.completion.chunk`` SSE events.

    Generation runs on a daemon thread and is stopped promptly, via a stopping
    criterion, when the client disconnects or the stream is closed. If
    ``generate()`` raises, an ``error`` event is emitted instead of the final
    chunk.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    stop_event = threading.Event()

    class _StopOnEvent(StoppingCriteria):
        """Ends generate() as soon as stop_event is set."""

        def __call__(self, input_ids, scores, **kwargs):
            return torch.full(
                (input_ids.shape[0],), stop_event.is_set(),
                dtype=torch.bool, device=input_ids.device,
            )

    # skip_prompt=True: generate() pushes the prompt ids into the streamer first;
    # without it the rendered prompt would be echoed back to the client.
    # timeout makes next() raise queue.Empty periodically so the consumer can
    # check for client disconnects instead of blocking indefinitely.
    streamer = TextIteratorStreamer(
        proc.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=0.5
    )
    errors: List[BaseException] = []

    def _run_generation() -> None:
        try:
            mdl.generate(
                **inputs,
                **gen_kwargs,
                streamer=streamer,
                stopping_criteria=StoppingCriteriaList([_StopOnEvent()]),
            )
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # push the stop signal so the consumer is not left waiting

    pieces = iter(streamer)
    pending = object()  # sentinel: no text yet, try again

    def _next_piece():
        # StopIteration must not propagate into an executor Future, so map it to None.
        try:
            return next(pieces)
        except StopIteration:
            return None
        except queue.Empty:
            return pending

    def _chunk(delta: Dict[str, Any], finish_reason: Optional[str] = None) -> str:
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(payload)}\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        loop = asyncio.get_running_loop()
        # Started here so the finally clause below always covers the worker.
        threading.Thread(target=_run_generation, daemon=True).start()
        try:
            yield _chunk({"role": "assistant", "content": ""})
            while True:
                if await http_request.is_disconnected():
                    print("Client disconnected, stopping stream.")
                    return
                piece = await loop.run_in_executor(None, _next_piece)
                if piece is pending:
                    continue
                if piece is None:
                    break
                if piece:
                    yield _chunk({"content": piece})
            if errors:
                error = {"error": {"message": str(errors[0]), "type": "server_error"}}
                yield f"data: {json.dumps(error)}\n\n"
            else:
                yield _chunk({}, "stop")
            yield "data: [DONE]\n\n"
        finally:
            # Runs on normal completion, disconnect, and task cancellation alike.
            stop_event.set()

    return StreamingResponse(event_generator(), media_type="text/event-stream")


async def generate_chat_response(request: ChatCompletionRequest, http_request: Request):
    """Run one chat completion and return an OpenAI-compatible response.

    Returns a ``chat.completion`` dict, or a ``StreamingResponse`` of SSE chunks
    when ``request.stream`` is true.

    Raises:
        ValueError: an image in the request could not be parsed or decoded.
        requests.RequestException: a remote image could not be fetched.
    """
    import time
    import uuid

    load_model_global(request.model)
    # Snapshot the globals: another request may swap the model while this one
    # is generating, and inputs must go to the model that tokenized them.
    mdl, proc, model_id = model, processor, current_model_id

    messages = request.messages
    if request.response_format and request.response_format.type == "json_object":
        messages = enforce_json_mode(messages)

    conversation, images = _build_conversation(messages)
    prompt = _render_prompt(proc, conversation, prepare_tools(request.tools))

    inputs = proc(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to(mdl.device)
    inputs.pop("mm_token_type_ids", None)
    gen_kwargs = _build_generation_kwargs(request, proc.tokenizer.eos_token_id)

    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    if request.stream:
        return _stream_response(
            mdl, proc, inputs, gen_kwargs, model_id, completion_id, created, http_request
        )

    def _generate():
        # Grad mode is thread-local, so disable it inside the worker thread.
        with torch.no_grad():
            return mdl.generate(**inputs, **gen_kwargs)

    # generate() is blocking; running it in a worker keeps the event loop
    # responsive for other requests while this one is generating.
    generated_ids = await asyncio.get_running_loop().run_in_executor(None, _generate)

    prompt_tokens = len(inputs["input_ids"][0])
    output_ids = generated_ids[0][prompt_tokens:]
    completion_tokens = len(output_ids)
    response_text = proc.decode(output_ids, skip_special_tokens=True)

    tool_calls = None
    # Only look for tool calls when tools were offered; otherwise a JSON-mode
    # answer with "name"/"arguments" keys would be misread as a call.
    if request.tools and request.tool_choice != "none":
        tool_calls, response_text = _parse_tool_calls(response_text)

    if tool_calls:
        finish_reason = "tool_calls"
    elif request.max_tokens and completion_tokens >= request.max_tokens:
        finish_reason = "length"
    else:
        finish_reason = "stop"

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text,
                "tool_calls": tool_calls,
            },
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


# ----------------------------------------------------------------------
# Endpoints
# ----------------------------------------------------------------------
@app.get("/v1/models", tags=["Models"])
async def list_models():
    """List the models this server accepts, in OpenAI ``/v1/models`` format."""
    return {
        "object": "list",
        "data": [
            {"id": mid, "object": "model", "created": 0, "owned_by": "qwen"}
            for mid in ALLOWED_MODELS
        ],
    }
@app.post("/v1/chat/completions", tags=["Chat"])
async def chat_completions(request: ChatCompletionRequest, http_request: Request):
    """OpenAI-compatible chat completions endpoint (streaming or not).

    Error mapping:
    - HTTPExceptions raised downstream pass through unchanged.
    - Malformed image input (ValueError) and unreachable image URLs
      (requests.RequestException) become 400.
    - Anything else becomes 500.
    """
    try:
        return await generate_chat_response(request, http_request)
    except HTTPException:
        raise
    except (ValueError, requests.RequestException) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/v1/token_count", tags=["Utils"])
async def token_count(request: TokenCountRequest):
    """Count tokens for the given messages and tool definitions.

    ``total_tokens`` is the sum of ``messages_tokens`` and ``tools_tokens``, so
    each part is tokenized only once.
    """
    load_model_global(request.model)
    messages_tokens = count_messages_tokens(request.messages)
    tools_tokens = count_messages_tokens([], request.tools)
    return TokenCountResponse(
        total_tokens=messages_tokens + tools_tokens,
        messages_tokens=messages_tokens,
        tools_tokens=tools_tokens,
    )


@app.get("/", tags=["Health"])
async def health():
    """Report liveness, the currently loaded model, and the allowed model ids."""
    return {
        "status": "ok",
        "current_model": current_model_id,
        "available_models": ALLOWED_MODELS,
    }


# ----------------------------------------------------------------------
# Startup
# ----------------------------------------------------------------------
# NOTE(review): on_event is deprecated in recent FastAPI in favour of a
# lifespan handler passed to FastAPI(...).
@app.on_event("startup")
def startup():
    """Load the default model so the first request does not pay the load cost."""
    load_model_global(DEFAULT_MODEL)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)