Spaces:
Running
Running
| import json | |
| import threading | |
| import gc | |
| import re | |
| import base64 | |
| import asyncio | |
| import queue | |
| from io import BytesIO | |
| from typing import List, Optional, Union, Dict, Any, Literal, AsyncGenerator | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field, field_validator | |
| from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer | |
| import torch | |
| import requests | |
| from PIL import Image | |
# ----------------------------------------------------------------------
# App setup
# ----------------------------------------------------------------------
# The FastAPI application; interactive API docs are exposed at /docs and /redoc.
app = FastAPI(
    title="Qwen3.5 Multimodal API",
    version="1.0.0",
    description="OpenAI‑compatible chat completions API with vision, tool calling, and JSON mode.",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS guidance says wildcard origins should not be
# combined with allow_credentials=True; confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# ----------------------------------------------------------------------
# Model management
# ----------------------------------------------------------------------
# Model served when a request does not name one.
DEFAULT_MODEL = "Qwen/Qwen3.5-2B"
# Model ids this server is meant to serve (default first).
ALLOWED_MODELS = [DEFAULT_MODEL, "Qwen/Qwen3.5-0.8B"]

# Process-wide state rebound by load_model_global(); all None until a model is loaded.
current_model_id: Optional[str] = None
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None
def load_model_global(model_id: str):
    """Make ``model_id`` the active model, replacing any previously loaded one.

    No-op when ``model_id`` is already loaded. Otherwise the current model and
    processor are released first, then the new processor and model are loaded
    on CPU in float16 (with ``trust_remote_code=True``) and published to the
    module globals together.

    If loading raises, the exception propagates and the globals are left as
    ``model=None``, ``processor=None``, ``current_model_id=None``. The next call
    therefore retries the load instead of assuming an old model is still
    present.

    Args:
        model_id: Hugging Face repo id to load.
    """
    global model, processor, current_model_id
    if current_model_id == model_id and model is not None:
        return
    print(f"Loading model {model_id}...")
    if model is not None or processor is not None:
        # Rebind (not `del`) so the globals keep existing. The old weights
        # become unreachable and can be freed before the new ones are read.
        model = None
        processor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # Cleared before loading so a failed load never leaves a stale id behind.
    current_model_id = None

    new_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    # NOTE(review): float16 on CPU is slow and some ops may lack Half kernels on
    # older torch builds — confirm this dtype is intended.
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="cpu",
        trust_remote_code=True,
    )
    new_model.eval()

    # Publish all three together only after both loads succeeded.
    processor, model, current_model_id = new_processor, new_model, model_id
    print(f"Model {model_id} ready.")
# ----------------------------------------------------------------------
# Helper: image loading (URL or base64)
# ----------------------------------------------------------------------
def load_image_from_input(image_input: Union[str, Dict], timeout: float = 10) -> "Image.Image":
    """Decode a base64 ``data:image/...`` URL, or fetch an http(s) URL, as an RGB image.

    Args:
        image_input: The URL string itself, or a dict carrying it under ``"url"``.
        timeout: Seconds passed to ``requests.get`` for remote URLs.

    Returns:
        The image converted to RGB.

    Raises:
        ValueError: Malformed data URL, invalid base64 (``binascii.Error`` is a
            ValueError subclass), or a URL whose scheme is not http/https.
        requests.RequestException: The remote fetch failed or returned an
            error status.
    """
    from urllib.parse import urlparse

    url = image_input if isinstance(image_input, str) else image_input.get("url", "")
    if url.startswith("data:image"):
        # DOTALL: base64 payloads are sometimes line-wrapped. Without it the
        # data group stops at the first newline and the image is truncated.
        match = re.match(r"data:image/(?P<fmt>.+?);base64,(?P<data>.+)", url, re.DOTALL)
        if not match:
            raise ValueError("Invalid base64 image format")
        img_data = base64.b64decode(match.group("data"))
        return Image.open(BytesIO(img_data)).convert("RGB")

    scheme = urlparse(url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Unsupported image URL scheme: {scheme or '(none)'}")
    # NOTE(review): any http(s) host is fetched, including internal addresses
    # (SSRF) — restrict hosts if this server is publicly reachable.
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(BytesIO(resp.content)).convert("RGB")
# ----------------------------------------------------------------------
# OpenAI-compatible schemas
# ----------------------------------------------------------------------
# A function the model may call: required name, optional description and an
# optional free-form parameter mapping (presumably JSON Schema, as in the
# OpenAI API — it is not validated here).
class FunctionDefinition(BaseModel):
    name: str = Field(...)
    description: Union[str, None] = Field(default=None)
    parameters: Union[Dict[str, Any], None] = Field(default=None)
# One tool entry in a request; only the "function" tool type is accepted.
class Tool(BaseModel):
    type: Literal["function"] = Field(default="function")
    function: FunctionDefinition = Field(...)
# Object form of `tool_choice`: names a specific function via a str->str mapping.
class ToolChoice(BaseModel):
    type: Literal["function"] = Field(default="function")
    function: Dict[str, str] = Field(...)
# Requested output format: free text (default) or a JSON object.
class ResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = Field(default="text")
# Image reference inside a content part: an http(s) or data URL plus a detail hint.
class ImageURL(BaseModel):
    url: str = Field(...)
    detail: Union[Literal["auto", "low", "high"], None] = Field(default="auto")
# One element of a multi-part message; `type` decides which of text/image_url is used.
class ContentPart(BaseModel):
    type: str = Field(...)
    text: Union[str, None] = Field(default=None)
    image_url: Union[ImageURL, None] = Field(default=None)
# A conversation turn. Content is a plain string, a list of parts, or absent;
# tool_calls / tool_call_id carry OpenAI-style tool-use bookkeeping.
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"] = Field(...)
    content: Union[str, List[ContentPart], None] = Field(default=None)
    name: Union[str, None] = Field(default=None)
    tool_calls: Union[List[Dict[str, Any]], None] = Field(default=None)
    tool_call_id: Union[str, None] = Field(default=None)
# Request body for chat completions (OpenAI chat.completions field names).
# NOTE: frequency_penalty / presence_penalty are accepted for compatibility.
class ChatCompletionRequest(BaseModel):
    model: str = DEFAULT_MODEL
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: bool = False
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[Literal["none", "auto"], ToolChoice]] = "auto"
    response_format: Optional[ResponseFormat] = ResponseFormat(type="text")

    # Registered as a pydantic validator; as a plain method it was never invoked.
    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        """Reject model ids outside ALLOWED_MODELS.

        The id is handed to load_model_global(), which loads it with
        trust_remote_code=True. An unchecked id would let a client load and
        execute code from an arbitrary repository.

        Raises:
            ValueError: If ``v`` is not in ALLOWED_MODELS (surfaced by
                pydantic as a validation error).
        """
        if v not in ALLOWED_MODELS:
            raise ValueError(f"Model must be one of {ALLOWED_MODELS}")
        return v
# Request body for token counting: the messages and optional tools to measure.
class TokenCountRequest(BaseModel):
    model: str = DEFAULT_MODEL
    messages: List[ChatMessage]
    tools: Optional[List[Tool]] = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        """Reject model ids outside ALLOWED_MODELS.

        The id is passed to load_model_global(), which loads remote code, so
        it must come from the allow-list.

        Raises:
            ValueError: If ``v`` is not in ALLOWED_MODELS.
        """
        if v not in ALLOWED_MODELS:
            raise ValueError(f"Model must be one of {ALLOWED_MODELS}")
        return v
# Token-count result: overall total plus the message and tool shares.
class TokenCountResponse(BaseModel):
    total_tokens: int = Field(...)
    messages_tokens: int = Field(...)
    tools_tokens: int = Field(default=0)
# ----------------------------------------------------------------------
# Helper: token counting
# ----------------------------------------------------------------------
def count_tokens(text: str) -> int:
    """Return how many token ids the active processor's tokenizer emits for ``text``.

    Raises:
        HTTPException: 503 when no processor has been loaded yet.
    """
    active = processor
    if active is None:
        raise HTTPException(503, "Model not loaded")
    token_ids = active.tokenizer.encode(text)
    return len(token_ids)
def count_messages_tokens(messages: List[ChatMessage], tools: Optional[List[Tool]] = None) -> int:
    """Approximate the token cost of ``messages`` (and optionally ``tools``).

    Each text piece is tokenized separately and the counts are summed. The
    pieces are: string content, text parts of list content, the name, the
    JSON-dumped tool_calls and the tool_call_id of every message, plus one
    JSON dump of all tool definitions. Image parts contribute nothing, and
    chat-template overhead is not included.
    """
    def pieces():
        for message in messages:
            body = message.content
            if isinstance(body, str):
                yield body
            elif isinstance(body, list):
                yield from (chunk.text for chunk in body if chunk.text)
            if message.name:
                yield message.name
            if message.tool_calls:
                yield json.dumps(message.tool_calls)
            if message.tool_call_id:
                yield message.tool_call_id
        if tools:
            yield json.dumps([tool.model_dump() for tool in tools])

    return sum(count_tokens(piece) for piece in pieces())
# ----------------------------------------------------------------------
# Helper: prepare tools for Qwen
# ----------------------------------------------------------------------
def prepare_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict]]:
    """Dump Tool models to plain dicts for the chat template; None when there are no tools."""
    if tools:
        dumped = []
        for entry in tools:
            dumped.append(entry.model_dump())
        return dumped
    return None
# ----------------------------------------------------------------------
# Helper: enforce JSON mode
# ----------------------------------------------------------------------
def enforce_json_mode(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Return a copy of ``messages`` whose system prompt demands JSON-only output.

    The instruction is prepended to the first system message. If no system
    message exists, a new one is inserted at index 0. The input list and its
    ChatMessage objects are left unmodified: the edited message is a
    ``model_copy``, so the caller's request object is not changed.
    """
    sys_instruction = "You must output only a valid JSON object. Do not include any other text, commentary, or markdown."
    new_messages = list(messages)
    for i, msg in enumerate(new_messages):
        if msg.role != "system":
            continue
        if not msg.content:
            new_content = sys_instruction
        elif isinstance(msg.content, str):
            new_content = sys_instruction + "\n\n" + msg.content
        else:
            # A real ContentPart rather than a raw dict: downstream code reads
            # `part.type` / `part.text` as attributes.
            new_content = [ContentPart(type="text", text=sys_instruction)] + list(msg.content)
        new_messages[i] = msg.model_copy(update={"content": new_content})
        return new_messages
    new_messages.insert(0, ChatMessage(role="system", content=sys_instruction))
    return new_messages
# ----------------------------------------------------------------------
# Main generation logic
# ----------------------------------------------------------------------
# Qwen-style tool-call markup: <tool_call>{"name": ..., "arguments": ...}</tool_call>.
# NOTE(review): assumes the Qwen3.5 template emits these tags like earlier Qwen releases.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def _template_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Copy OpenAI-format tool_calls, decoding JSON-string ``arguments`` into dicts.

    NOTE(review): OpenAI clients send arguments as a JSON string. Decoding it
    keeps the chat template from re-serializing it as a quoted string —
    confirm against the Qwen3.5 template.
    """
    normalized = []
    for call in tool_calls:
        call = dict(call)
        function = call.get("function")
        if isinstance(function, dict) and isinstance(function.get("arguments"), str):
            try:
                call["function"] = {**function, "arguments": json.loads(function["arguments"])}
            except json.JSONDecodeError:
                pass  # leave non-JSON argument strings untouched
        normalized.append(call)
    return normalized


def _build_conversation(messages: List[ChatMessage]):
    """Convert messages into chat-template dicts; return ``(conversation, images)``.

    Images are loaded in message order via load_image_from_input, and each is
    replaced by an ``{"type": "image"}`` placeholder in its message.
    Assistant tool calls are forwarded, so multi-turn tool use keeps its history.
    """
    conversation = []
    images = []
    for msg in messages:
        if msg.role == "tool":
            content = msg.content
            if isinstance(content, list):
                # Tool results are rendered as text; keep only the text parts.
                content = "".join(part.text or "" for part in content if part.type == "text")
            conversation.append({"role": "tool", "content": content, "tool_call_id": msg.tool_call_id})
            continue
        if isinstance(msg.content, str):
            entry = {"role": msg.role, "content": msg.content}
        elif isinstance(msg.content, list):
            parts = []
            for part in msg.content:
                if part.type == "text" and part.text:
                    parts.append({"type": "text", "text": part.text})
                elif part.type == "image_url" and part.image_url:
                    images.append(load_image_from_input(part.image_url.url))
                    parts.append({"type": "image"})
            entry = {"role": msg.role, "content": parts}
        elif msg.tool_calls:
            # Assistant turn that only calls tools (content is null in OpenAI format).
            entry = {"role": msg.role, "content": ""}
        else:
            continue  # nothing to render
        if msg.tool_calls:
            entry["tool_calls"] = _template_tool_calls(msg.tool_calls)
        conversation.append(entry)
    return conversation, images


def _render_prompt(proc, conversation, tools):
    """Apply the chat template; if it fails with tools, retry once without them."""
    try:
        return proc.apply_chat_template(
            conversation, tools=tools, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"Warning: apply_chat_template with tools failed: {e}. Falling back without tools.")
        return proc.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)


def _generation_kwargs(request: ChatCompletionRequest, proc) -> Dict[str, Any]:
    """Translate the request's sampling fields into ``generate()`` keyword arguments.

    ``temperature <= 0`` selects greedy decoding. ``temperature=None`` leaves
    the choice to the model's generation config. Keys whose value is None are
    dropped. frequency_penalty / presence_penalty are not applied.
    """
    temperature = request.temperature
    greedy = temperature is not None and temperature <= 0
    kwargs = {
        "max_new_tokens": request.max_tokens,
        "pad_token_id": proc.tokenizer.eos_token_id,
        "do_sample": None if temperature is None else not greedy,
    }
    if not greedy:
        # Sampling knobs only matter when sampling is enabled.
        kwargs["temperature"] = temperature
        kwargs["top_p"] = request.top_p
    return {k: v for k, v in kwargs.items() if v is not None}


def _parse_tool_calls(text: str):
    """Split generated text into ``(content, tool_calls)`` in OpenAI response format.

    Accepts ``<tool_call>``-wrapped JSON objects. When no tags are present, the
    whole text is tried as one bare JSON object. Only objects with both
    ``name`` and ``arguments`` count. Returns ``(text, None)`` when no call is
    found. Otherwise content is any text outside the tags, or None.
    """
    import uuid

    tagged = _TOOL_CALL_RE.findall(text)
    calls = []
    for raw in tagged or [text.strip()]:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed:
            arguments = parsed["arguments"]
            calls.append({
                "id": "call_" + uuid.uuid4().hex[:24],
                "type": "function",
                "function": {
                    "name": parsed["name"],
                    # Arguments must be a JSON string; don't double-encode one.
                    "arguments": arguments if isinstance(arguments, str) else json.dumps(arguments),
                },
            })
    if not calls:
        return text, None
    leftover = _TOOL_CALL_RE.sub("", text).strip() if tagged else ""
    return leftover or None, calls


def _make_stop_criteria(stop_event: threading.Event):
    """Build stopping criteria that end ``generate()`` once ``stop_event`` is set."""
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One flag per batch row (newer transformers expects a bool tensor).
            return torch.full(
                (input_ids.shape[0],), stop_event.is_set(), dtype=torch.bool, device=input_ids.device
            )

    return StoppingCriteriaList([_StopOnEvent()])


async def generate_chat_response(request: ChatCompletionRequest, http_request: Request):
    """Run one chat completion against the model named in ``request.model``.

    The model is switched if needed. The conversation (with any images) is
    rendered through the processor's chat template, then either:

    * ``stream=False``: a ``chat.completion`` dict is returned. Tool calls are
      parsed only when tools were supplied and tool_choice is not "none".
    * ``stream=True``: a ``StreamingResponse`` of SSE ``chat.completion.chunk``
      events is returned, ending with ``data: [DONE]``. Generation is stopped
      when the client disconnects.

    Errors from image loading, templating or generation propagate to the caller.
    """
    import time
    import uuid

    # NOTE(review): model loading still blocks the event loop; it runs here so
    # that concurrent loads cannot race on the module globals.
    load_model_global(request.model)
    # Capture the loaded objects: a later request for another model rebinds the
    # globals, and this request must keep using the ones it started with.
    active_model, proc, model_id = model, processor, current_model_id
    loop = asyncio.get_running_loop()
    completion_id = "chatcmpl-" + uuid.uuid4().hex
    created = int(time.time())

    messages = request.messages
    if request.response_format and request.response_format.type == "json_object":
        messages = enforce_json_mode(messages)

    # Image downloads block, so build the conversation in a worker thread.
    conversation, images = await loop.run_in_executor(None, _build_conversation, messages)
    tools = prepare_tools(request.tools)
    prompt = _render_prompt(proc, conversation, tools)

    inputs = proc(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to(active_model.device)
    inputs.pop("mm_token_type_ids", None)
    prompt_len = len(inputs["input_ids"][0])
    gen_kwargs = _generation_kwargs(request, proc)

    # ---- NON-STREAMING ----
    if not request.stream:
        def run_generate():
            with torch.no_grad():
                return active_model.generate(**inputs, **gen_kwargs)

        # generate() is long and CPU-bound; keep it off the event loop.
        generated_ids = await loop.run_in_executor(None, run_generate)
        output_ids = generated_ids[0][prompt_len:]
        response_text = proc.decode(output_ids, skip_special_tokens=True)
        tool_calls = None
        if tools and request.tool_choice != "none":
            response_text, tool_calls = _parse_tool_calls(response_text)
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model_id,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_text,
                    "tool_calls": tool_calls,
                },
                "finish_reason": "tool_calls" if tool_calls else "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": len(output_ids),
                "total_tokens": prompt_len + len(output_ids),
            },
        }

    # ---- STREAMING ----
    stop_event = threading.Event()
    # skip_prompt=True: otherwise the streamer first yields the decoded prompt.
    # timeout: __next__ raises queue.Empty periodically so disconnects are noticed.
    streamer = TextIteratorStreamer(
        proc.tokenizer, skip_prompt=True, timeout=0.5, skip_special_tokens=True
    )
    stream_kwargs = {
        **inputs,
        **gen_kwargs,
        "streamer": streamer,
        "stopping_criteria": _make_stop_criteria(stop_event),
    }

    def run_generate_streaming():
        try:
            with torch.no_grad():
                active_model.generate(**stream_kwargs)
        except Exception as exc:
            print(f"Streaming generation failed: {exc}")
            # generate() did not reach its own streamer.end(); send the end
            # signal so the consumer stops instead of waiting forever.
            streamer.end()

    threading.Thread(target=run_generate_streaming, daemon=True).start()

    def next_text():
        # next(..., None): StopIteration cannot be raised through a Future.
        return next(streamer, None)

    def sse(delta: Dict[str, Any], finish_reason: Optional[str]) -> str:
        chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(chunk)}\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        try:
            while True:
                if await http_request.is_disconnected():
                    print("Client disconnected, stopping stream.")
                    return
                try:
                    text = await loop.run_in_executor(None, next_text)
                except queue.Empty:
                    continue
                if text is None:
                    break
                if text:  # the streamer may emit empty strings on flush
                    yield sse({"content": text}, None)
            yield sse({}, "stop")
            yield "data: [DONE]\n\n"
        finally:
            # Always halt generation (normal end, disconnect or cancellation).
            stop_event.set()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# ----------------------------------------------------------------------
# Endpoints
# ----------------------------------------------------------------------
# NOTE(review): no route decorators are visible for the handlers below —
# confirm how they are registered on `app`.
async def list_models():
    """Return ALLOWED_MODELS as an OpenAI-style model list."""
    cards = []
    for model_name in ALLOWED_MODELS:
        cards.append({"id": model_name, "object": "model", "created": 0, "owned_by": "qwen"})
    return {"object": "list", "data": cards}
async def chat_completions(request: ChatCompletionRequest, http_request: Request):
    """Chat completions endpoint: run generation and map failures to HTTP errors.

    * HTTPException raised downstream is re-raised with its own status code.
    * ValueError (e.g. a malformed data URL or unsupported image URL) becomes 400.
    * Any other exception becomes 500, with the exception text as detail and
      the original chained as the cause.
    """
    try:
        return await generate_chat_response(request, http_request)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def token_count(request: TokenCountRequest):
    """Count tokens for the request's messages and tool definitions.

    The requested model is loaded first, because counting uses its tokenizer.
    count_messages_tokens sums independently tokenized pieces, so the total is
    exactly messages + tools. Each part is therefore tokenized once instead of
    twice. Image parts are not counted.
    """
    load_model_global(request.model)
    messages_tokens = count_messages_tokens(request.messages)
    tools_tokens = count_messages_tokens([], request.tools)
    return TokenCountResponse(
        total_tokens=messages_tokens + tools_tokens,
        messages_tokens=messages_tokens,
        tools_tokens=tools_tokens,
    )
async def health():
    """Liveness probe: report the loaded model id (None if none) and the allowed ids."""
    payload = {"status": "ok"}
    payload["current_model"] = current_model_id
    payload["available_models"] = ALLOWED_MODELS
    return payload
# ----------------------------------------------------------------------
# Startup
# ----------------------------------------------------------------------
def startup():
    """Load DEFAULT_MODEL ahead of time.

    NOTE(review): nothing in the visible source registers this as an app
    startup hook — confirm it is wired up, otherwise the first request loads
    the model.
    """
    return load_model_global(DEFAULT_MODEL)
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve on every interface, port 7860.
    uvicorn.run(app=app, host="0.0.0.0", port=7860)