#!/usr/bin/env python3
"""
Multi-Model AI API — HuggingFace Spaces Edition.

With load balancing, 10 req/s rate limiting, vision support, and
multimodal fixes.
"""
import re, os, json, uuid, time, random, string, logging, threading, base64
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from io import BytesIO
import requests
from flask import Flask, request as freq, jsonify, Response, stream_with_context

# gradio_client is an optional dependency; providers that require it check
# HAS_GRADIO_CLIENT at runtime and raise a clear error when it is absent.
try:
    from gradio_client import Client as GradioClient, handle_file
    HAS_GRADIO_CLIENT = True
except ImportError:
    HAS_GRADIO_CLIENT = False

# ═══════════════════════════════════════════════════════════════
# CONFIG & CONSTANTS
# ═══════════════════════════════════════════════════════════════
VERSION = "3.0.0-hf-lb"
APP_NAME = "Multi-Model-AI-API"
DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly AI assistant."
DEFAULT_MODEL = "gpt-oss-120b"
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(APP_NAME)

# Browser User-Agent strings; GptOssProvider picks one at random per session
# so upstream Spaces see ordinary-looking clients.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/144.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_5) AppleWebKit/605.1.15 Safari/605.1.15",
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 Chrome/143.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0",
]

# ═══════════════════════════════════════════════════════════════
# MULTIMODAL HELPERS
# ═══════════════════════════════════════════════════════════════
def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
    """
    Parse OpenAI-style multimodal message content.

    Returns (text, [base64_or_url, ...]).
    Handles: a plain str, or a list of {type, text/image_url} blocks
    (both "image_url" and Anthropic-like "image"/"source" shapes).
    """
    if content is None:
        return "", []
    if isinstance(content, str):
        return content.strip(), []
    texts: List[str] = []
    images: List[str] = []
    if isinstance(content, list):
        for block in content:
            # Non-dict blocks are stringified and kept as text.
            if not isinstance(block, dict):
                texts.append(str(block))
                continue
            btype = block.get("type", "")
            if btype == "text":
                t = block.get("text", "")
                if t:
                    texts.append(t)
            elif btype == "image_url":
                img = block.get("image_url", {})
                url = img.get("url", "") if isinstance(img, dict) else str(img)
                if url:
                    images.append(url)
            elif btype == "image":
                # Alternative format: {"source": {"data": ..., "media_type": ...}}
                # is normalized into a data: URI.
                src = block.get("source", {})
                if isinstance(src, dict):
                    data = src.get("data", "")
                    if data:
                        media = src.get("media_type", "image/jpeg")
                        images.append(f"data:{media};base64,{data}")
    # NOTE(review): text blocks are joined with a single space, so original
    # newlines between blocks are lost — confirm this is intended.
    return " ".join(texts).strip(), images


def decode_image_to_bytes(image_url: str) -> Optional[Tuple[bytes, str]]:
    """Convert image URL or data URI to (bytes, media_type).

    Returns None (after logging a warning) on any decode/download failure.
    """
    try:
        if image_url.startswith("data:"):
            # data:image/jpeg;base64,/9j/...
            header, data = image_url.split(",", 1)
            media_type = header.split(";")[0].split(":")[1]
            return base64.b64decode(data), media_type
        else:
            # Remote URL
            r = requests.get(image_url, timeout=15)
            r.raise_for_status()
            ct = r.headers.get("content-type", "image/jpeg").split(";")[0]
            return r.content, ct
    except Exception as e:
        log.warning(f"Failed to decode image: {e}")
        return None


def save_image_temp(image_url: str) -> Optional[str]:
    """Save image to a temp file and return path (for gradio_client).

    Uses delete=False so the file survives for the upstream upload;
    NOTE(review): nothing visible here deletes these temp files later —
    confirm cleanup happens elsewhere or accept the leak on long runs.
    """
    import tempfile
    result = decode_image_to_bytes(image_url)
    if not result:
        return None
    data, media_type = result
    # "jpeg" → "jpg" so the suffix matches common file-extension expectations.
    ext = media_type.split("/")[-1].replace("jpeg", "jpg")
    with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as f:
        f.write(data)
        return f.name


# ═══════════════════════════════════════════════════════════════
# MODEL REGISTRY
# ═══════════════════════════════════════════════════════════════
@dataclass
class ModelDef:
    """Static capability/config descriptor for one upstream model Space."""
    model_id: str
    display_name: str
    provider_type: str          # "gradio_sse" or "gradio_client"
    space_id: str               # HF Space id or full base URL
    owned_by: str
    description: str = ""
    supports_system_prompt: bool = True
    supports_temperature: bool = True
    supports_streaming: bool = True
    supports_history: bool = True
    supports_vision: bool = False
    supports_thinking: bool = False
    thinking_default: bool = True
    max_tokens_default: int = 4096
    default_temperature: float = 0.7
    fn_index: Optional[int] = None      # used by the raw SSE provider
    api_name: Optional[str] = None      # used by gradio_client providers
    extra_params: Dict[str, Any] = field(default_factory=dict)
    clean_analysis: bool = False        # strip "analysis…" preamble from output
    lb_pool_size: int = 2               # instances per model for load balancing
    lb_enabled: bool = True
    is_beta: bool = False


MODEL_REGISTRY: Dict[str, ModelDef] = {}


def register_model(m: ModelDef):
    """Insert (or replace) a model definition keyed by its model_id."""
    MODEL_REGISTRY[m.model_id] = m


def _init_registry():
    """Populate MODEL_REGISTRY with every supported upstream model."""
    register_model(ModelDef(
        model_id="gpt-oss-120b",
        display_name="AMD GPT-OSS-120B",
        provider_type="gradio_sse",
        space_id="https://amd-gpt-oss-120b-chatbot.hf.space",
        owned_by="amd",
        description="AMD open-source 120B model",
        fn_index=8,
        clean_analysis=True,
        default_temperature=0.0,
        supports_vision=False,
        supports_thinking=False,
        lb_pool_size=3,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="command-a-vision",
        display_name="Cohere Command-A Vision",
        provider_type="gradio_client",
        space_id="CohereLabs/command-a-vision",
        owned_by="cohere",
        description="Cohere multimodal command model",
        api_name="/chat",
        supports_vision=True,
        supports_system_prompt=False,
        supports_temperature=False,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=False,
        max_tokens_default=700,
        extra_params={"max_new_tokens": 700},
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="command-a-translate",
        display_name="Cohere Command-A Translate",
        provider_type="gradio_client",
        space_id="CohereLabs/command-a-translate",
        owned_by="cohere",
        description="Cohere translation model",
        api_name="/chat",
        supports_vision=False,
        supports_system_prompt=False,
        supports_temperature=False,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=False,
        max_tokens_default=700,
        extra_params={"max_new_tokens": 700},
        # Translation is stateless per request; load balancing disabled.
        lb_pool_size=1,
        lb_enabled=False,
    ))
    register_model(ModelDef(
        model_id="command-a-reasoning",
        display_name="Cohere Command-A Reasoning",
        provider_type="gradio_client",
        space_id="CohereLabs/command-a-reasoning",
        owned_by="cohere",
        description="Cohere reasoning model with thinking budget",
        api_name="/chat",
        supports_vision=False,
        supports_system_prompt=False,
        supports_temperature=False,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=True,
        thinking_default=True,
        max_tokens_default=4096,
        extra_params={"thinking_budget": 500},
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="minimax-vl-01",
        display_name="MiniMax VL-01",
        provider_type="gradio_client",
        space_id="MiniMaxAI/MiniMax-VL-01",
        owned_by="minimax",
        description="MiniMax vision-language model",
        api_name="/chat",
        supports_vision=True,
        supports_system_prompt=False,
        supports_temperature=True,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=False,
        max_tokens_default=12800,
        default_temperature=0.1,
        extra_params={"max_tokens": 12800, "top_p": 0.9},
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="glm-4.5",
        display_name="GLM-4.5 (ZhipuAI)",
        provider_type="gradio_client",
        space_id="zai-org/GLM-4.5-Space",
        owned_by="zhipuai",
        description="ZhipuAI GLM-4.5 with thinking mode",
        api_name="/chat_wrapper",
        supports_vision=False,
        supports_system_prompt=True,
        supports_temperature=True,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=True,
        thinking_default=True,
        default_temperature=1.0,
        extra_params={"thinking_enabled": True},
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="chatgpt",
        display_name="ChatGPT (Community)",
        provider_type="gradio_client",
        space_id="yuntian-deng/ChatGPT",
        owned_by="community",
        description="ChatGPT via community Space",
        api_name="/predict",
        supports_vision=False,
        supports_system_prompt=False,
        supports_temperature=True,
        supports_streaming=False,
        supports_history=True,
        supports_thinking=False,
        default_temperature=1.0,
        extra_params={"top_p": 1.0},
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="qwen3-vl",
        display_name="Qwen3-VL (Alibaba)",
        provider_type="gradio_client",
        space_id="Qwen/Qwen3-VL-Demo",
        owned_by="alibaba",
        description="Alibaba Qwen3 Vision-Language model",
        api_name="/add_message",
        supports_vision=True,
        supports_system_prompt=False,
        supports_temperature=False,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=False,
        max_tokens_default=4096,
        lb_pool_size=2,
        lb_enabled=True,
    ))
    register_model(ModelDef(
        model_id="qwen2.5-coder",
        display_name="Qwen2.5-Coder Artifacts (BETA)",
        provider_type="gradio_client",
        space_id="Qwen/Qwen2.5-Coder-Artifacts",
        owned_by="alibaba",
        description="Alibaba Qwen2.5 Coder — code generation model (BETA)",
        api_name="/generation_code",
        supports_vision=False,
        supports_system_prompt=True,
        supports_temperature=False,
        supports_streaming=False,
        supports_history=False,
        supports_thinking=False,
        max_tokens_default=4096,
        extra_params={
            # Sent once via /lambda_1 to steer the Space away from its default
            # HTML-artifact output mode.
            "system_prompt_override": (
                "You are a helpful assistant. You are a skilled programming assistant. "
                "You help users write, debug, and understand code across all languages. "
                "Respond with clear explanations and clean code. "
                "Do NOT generate HTML artifacts or web page previews. "
                "Do NOT wrap everything in a single HTML file. "
                "Just provide the code the user asks for with explanations."
            ),
        },
        lb_pool_size=2,
        lb_enabled=True,
        is_beta=True,
    ))


_init_registry()


# ═══════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════
@dataclass
class Config:
    """Runtime configuration, optionally overridden via MMAI_* env vars."""
    default_model: str = DEFAULT_MODEL
    default_system_prompt: str = DEFAULT_SYSTEM_PROMPT
    timeout_stream: int = 300
    max_retries: int = 3
    retry_backoff_base: float = 1.5
    retry_jitter: float = 0.5
    rate_limit_rps: int = 10
    rate_limit_burst: int = 15
    pool_size: int = 2
    max_history_messages: int = 50
    max_message_length: int = 32000
    default_temperature: float = 0.7
    include_thinking: bool = True
    log_sse_raw: bool = False

    @classmethod
    def from_env(cls) -> "Config":
        """Build a Config, overriding fields from environment variables.

        Unparseable values are silently ignored (defaults kept).
        """
        cfg = cls()
        env_map = {
            "MMAI_TIMEOUT": ("timeout_stream", int),
            "MMAI_MAX_RETRIES": ("max_retries", int),
            "MMAI_RATE_LIMIT_RPS": ("rate_limit_rps", int),
            "MMAI_RATE_LIMIT_BURST": ("rate_limit_burst", int),
            "MMAI_POOL_SIZE": ("pool_size", int),
            "MMAI_SYSTEM_PROMPT": ("default_system_prompt", str),
            "MMAI_TEMPERATURE": ("default_temperature", float),
            "MMAI_DEFAULT_MODEL": ("default_model", str),
            "MMAI_INCLUDE_THINKING": ("include_thinking", lambda x: x.lower() in ("1", "true")),
        }
        for env_key, (attr, conv) in env_map.items():
            val = os.environ.get(env_key)
            if val is not None:
                try:
                    setattr(cfg, attr, conv(val))
                except (ValueError, TypeError):
                    pass
        return cfg


# ═══════════════════════════════════════════════════════════════
# EXCEPTIONS
# ═══════════════════════════════════════════════════════════════
class APIError(Exception):
    """Base API error carrying a machine-readable code and HTTP status."""

    def __init__(self, message: str, code: str = "UNKNOWN", status: int = 500):
        super().__init__(message)
        self.code = code
        self.status = status

    def to_dict(self):
        """JSON-serializable form for error responses."""
        return {"error": str(self), "code": self.code}


class ModelNotFoundError(APIError):
    """Raised when a request names a model_id absent from MODEL_REGISTRY."""

    def __init__(self, model_id: str):
        super().__init__(
            f"Model '{model_id}' not found. Available: {list(MODEL_REGISTRY.keys())}",
            "MODEL_NOT_FOUND",
            404,
        )


# ═══════════════════════════════════════════════════════════════
# RESPONSE CLEANER
# ═══════════════════════════════════════════════════════════════
class ResponseCleaner:
    """Model-specific post-processing of raw upstream output."""

    @classmethod
    def clean_analysis(cls, text: str) -> str:
        """Strip GPT-OSS "analysis…" preamble, keeping only the final answer.

        Tries explicit "**Response:**"-style markers first, then the
        "assistantfinal" channel marker; a pure-analysis blob yields "".
        """
        if not text:
            return text
        original = text.strip()
        for pattern in [
            r'\*\*💬\s*Response:\*\*\s*\n*(.*?)$',
            r'\*\*Response:\*\*\s*\n*(.*?)$',
            r'---+\s*\n*\*\*💬\s*Response:\*\*\s*\n*(.*?)$',
        ]:
            match = re.search(pattern, original, re.DOTALL)
            if match:
                cleaned = match.group(1).strip()
                if cleaned:
                    return cleaned
        for pattern in [r'assistantfinal\s*(.*?)$', r'assistant\s*final\s*(.*?)$']:
            match = re.search(pattern, original, re.DOTALL | re.IGNORECASE)
            if match:
                cleaned = match.group(1).strip()
                if cleaned:
                    return cleaned
        if re.match(r'^analysis', original, re.IGNORECASE):
            return ""
        return original

    @classmethod
    def _decode_html_entities(cls, text: str) -> str:
        """Replace common HTML entities, then numeric &#...; references.

        NOTE(review): this literal table looks corrupted — the entity *keys*
        appear to have already been entity-decoded by some earlier tooling
        (several keys are identical and some equal their values). The
        originals were presumably '&#39;', '&amp;', '&lt;', etc. Tokens are
        preserved as found; recover the real table from VCS history, or
        consider replacing the whole method with stdlib html.unescape().
        """
        entities = {
            ''': "'", ''': "'", ''': "'",
            '"': '"', '"': '"', '"': '"',
            '&': '&', '<': '<', '>': '>',
            ' ': ' ',
            '’': '\u2019', '‘': '\u2018', '”': '\u201d', '“': '\u201c',
            '—': '—', '–': '–', '…': '…',
        }
        for entity, char in entities.items():
            text = text.replace(entity, char)
        text = re.sub(r'&#x([0-9a-fA-F]+);', lambda m: chr(int(m.group(1), 16)), text)
        text = re.sub(r'&#(\d+);', lambda m: chr(int(m.group(1))), text)
        return text

    @classmethod
    def _strip_html(cls, text: str) -> str:
        """Drop HTML tags and decode entities.

        NOTE(review): the first pattern is empty — it was most likely a
        '<br/>'-style pattern whose tag text was stripped by earlier
        tooling. Preserved as found.
        """
        text = re.sub(r'', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<[^>]+>', '', text)
        return cls._decode_html_entities(text).strip()

    @classmethod
    def clean_glm(cls, text: str, include_thinking: bool = True) -> str:
        """Extract thinking + response text from GLM-4.5 HTML-ish output.

        NOTE(review): this method is corrupted — HTML-like tag text inside
        the string literals (apparently '<details…>'/'<summary…>' markers)
        was stripped by an earlier tool, which also consumed the
        'thinking_match = re.search(r"…",' assignment, leaving the dangling
        fragment below and an unbound 'thinking_match'. Tokens are preserved
        exactly as found; reconstruct from VCS history before relying on it.
        """
        if not text:
            return text
        if ']*>.*?]*>(.*?)\s*', text, re.DOTALL | re.IGNORECASE, )
        if thinking_match:
            thinking_text = cls._strip_html(thinking_match.group(1)).strip()
            # Remove the (corrupted-pattern) details block, then pull the
            # remaining response div if present.
            text_without_details = re.sub(
                r']*>.*?', '', text,
                flags=re.DOTALL | re.IGNORECASE,
            ).strip()
            div_match = re.search(
                r"]*>\s*(.*?)\s*",
                text_without_details,
                re.DOTALL | re.IGNORECASE,
            )
            response_text = (
                cls._strip_html(div_match.group(1)).strip()
                if div_match
                else cls._strip_html(text_without_details).strip()
            )
            if thinking_text and include_thinking:
                return f"\n{thinking_text}\n\n{response_text}"
            return response_text

    @classmethod
    def extract_qwen_text(cls, result: Any) -> str:
        """Pull the last assistant message text out of a Qwen chatbot tuple."""
        if result is None:
            return ""
        if isinstance(result, str):
            return result.strip()
        if isinstance(result, tuple):
            for el in result:
                if isinstance(el, dict):
                    value = el.get("value")
                    if isinstance(value, list):
                        # Scan newest-first for the assistant turn.
                        for msg in reversed(value):
                            if isinstance(msg, dict) and msg.get("role") == "assistant":
                                content = msg.get("content", "")
                                if isinstance(content, str):
                                    return content.strip()
                                if isinstance(content, list):
                                    texts = []
                                    for block in content:
                                        if isinstance(block, str):
                                            texts.append(block)
                                        elif isinstance(block, dict) and block.get("type") != "file":
                                            bc = block.get("content", "")
                                            if isinstance(bc, str) and bc.strip():
                                                texts.append(bc)
                                    return "\n".join(t for t in texts if t.strip()).strip()
                                return str(content)
        return str(result) if result else ""

    @classmethod
    def extract_chatgpt_text(cls, result: Any) -> str:
        """Pull the newest bot reply from a (chatbot, …) predict() tuple."""
        if isinstance(result, str):
            return result.strip()
        if isinstance(result, tuple) and len(result) >= 1:
            chatbot = result[0]
            if isinstance(chatbot, (list, tuple)) and chatbot:
                last = chatbot[-1]
                # Each history entry is a [user, bot] pair.
                if isinstance(last, (list, tuple)) and len(last) >= 2:
                    msg = last[1]
                    if isinstance(msg, str):
                        return msg.strip()
                    if isinstance(msg, dict):
                        return str(msg.get("value", msg.get("content", ""))).strip()
                    return str(msg).strip() if msg else ""
            return str(chatbot).strip() if chatbot else ""
        return str(result)

    @classmethod
    def extract_qwen_coder_text(cls, result: Any) -> str:
        """Return the first non-empty string element of a coder result tuple."""
        if result is None:
            return ""
        if isinstance(result, str):
            return result.strip()
        if isinstance(result, tuple):
            if len(result) >= 1 and isinstance(result[0], str):
                text = result[0].strip()
                if text:
                    return text
            if len(result) >= 2 and isinstance(result[1], str):
                return result[1].strip()
        if isinstance(result, (list, dict)):
            return str(result)
        return str(result) if result else ""

    @classmethod
    def clean(cls, text: str, model_id: str = "", include_thinking: bool = True) -> str:
        """Dispatch to the model-specific cleaner, then decode stray entities."""
        if not text:
            return text
        text = text.strip()
        if model_id == "gpt-oss-120b":
            text = cls.clean_analysis(text)
        elif model_id == "glm-4.5":
            text = cls.clean_glm(text, include_thinking=include_thinking)
        # Cheap heuristic: only run entity decoding when both '&' and ';'
        # appear, i.e. an entity could plausibly be present.
        if '&' in text and ';' in text:
            text = cls._decode_html_entities(text)
        return text.strip()


# ═══════════════════════════════════════════════════════════════
# THINKING PARSER
# ═══════════════════════════════════════════════════════════════
class ThinkingParser:
    """Split/merge a response and its leading "thinking" segment.

    NOTE(review): the regex and the format template below look corrupted —
    the delimiter tag text (presumably '<think>'/'</think>') was stripped
    by earlier tooling, leaving only the surrounding whitespace atoms.
    Tokens preserved as found.
    """

    @staticmethod
    def split(text: str) -> Tuple[Optional[str], str]:
        """Return (thinking or None, response) parsed from *text*."""
        match = re.match(
            r'\s*\s*\n?(.*?)\n?\s*\s*\n?(.*)',
            text,
            re.DOTALL | re.IGNORECASE,
        )
        if match:
            thinking = match.group(1).strip()
            response = match.group(2).strip()
            return (thinking if thinking else None, response)
        return (None, text.strip())

    @staticmethod
    def format(thinking: Optional[str], response: str) -> str:
        """Inverse of split(): re-attach thinking ahead of the response."""
        if thinking:
            return f"\n{thinking}\n\n{response}"
        return response


# ═══════════════════════════════════════════════════════════════
# DATA MODELS
# ═══════════════════════════════════════════════════════════════
@dataclass
class Message:
    """One chat turn with server-assigned id/timestamp."""
    role: str
    content: str
    thinking: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class Conversation:
    """A persistent chat session: ordered messages plus metadata."""
    conversation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    messages: List[Message] = field(default_factory=list)
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    title: Optional[str] = None
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
    model_id: str = DEFAULT_MODEL

    def add_message(self, role: str, content: str, max_messages: int = 50,
                    thinking: Optional[str] = None) -> Message:
        """Append a turn, auto-title on first user message, and trim history.

        Trimming keeps all system messages and the newest non-system
        messages so the total stays within max_messages.
        """
        msg = Message(role=role, content=content, thinking=thinking)
        self.messages.append(msg)
        self.updated_at = time.time()
        if self.title is None and role == "user":
            self.title = content[:80]
        if len(self.messages) > max_messages:
            system_msgs = [m for m in self.messages if m.role == "system"]
            other_msgs = [m for m in self.messages if m.role != "system"]
            self.messages = system_msgs + other_msgs[-(max_messages - len(system_msgs)):]
        return msg

    def build_gradio_history(self) -> List[List[str]]:
        """Return [[user, assistant], ...] pairs from adjacent turns.

        Unpaired turns (e.g. two user messages in a row) are skipped.
        """
        history = []
        non_system = [m for m in self.messages if m.role != "system"]
        i = 0
        while i < len(non_system) - 1:
            if (non_system[i].role == "user" and i + 1 < len(non_system)
                    and non_system[i + 1].role == "assistant"):
                history.append([non_system[i].content, non_system[i + 1].content])
                i += 2
            else:
                i += 1
        return history

    def build_chatbot_tuples(self) -> List[List[str]]:
        """Alias for build_gradio_history() (same pair format)."""
        return self.build_gradio_history()

    def to_dict(self) -> Dict:
        """Summary metadata (no message bodies) for listing endpoints."""
        return {
            "conversation_id": self.conversation_id,
            "title": self.title,
            "model": self.model_id,
            "message_count": len(self.messages),
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }


# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════
@dataclass
class Metrics:
    """Process-wide counters, guarded by an internal lock."""
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_retries: int = 0
    total_chars_received: int = 0
    active_streams: int = 0
    requests_per_model: Dict[str, int] = field(default_factory=dict)
    # Rolling window of the last 1000 request latencies (ms).
    _latencies: deque = field(default_factory=lambda: deque(maxlen=1000), repr=False)
    started_at: float = field(default_factory=time.time)
    lb_total_dispatches: int = 0
    lb_failovers: int = 0

    def record_request(self, success: bool, duration_ms: float, chars: int = 0, model: str = ""):
        """Record one completed request (outcome, latency, size, model)."""
        with self._lock:
            self.total_requests += 1
            if success:
                self.successful_requests += 1
                self.total_chars_received += chars
            else:
                self.failed_requests += 1
            self._latencies.append(duration_ms)
            if model:
                self.requests_per_model[model] = (
                    self.requests_per_model.get(model, 0) + 1
                )

    def record_retry(self):
        with self._lock:
            self.total_retries += 1

    def record_lb_dispatch(self, failover: bool = False):
        """Count a load-balancer dispatch; failover=True also counts a failover."""
        with self._lock:
            self.lb_total_dispatches += 1
            if failover:
                self.lb_failovers += 1

    def to_dict(self) -> Dict:
        """Snapshot of all counters for the /metrics-style endpoint."""
        with self._lock:
            avg = (sum(self._latencies) / len(self._latencies)
                   if self._latencies else 0)
            # Success rate defaults to 1 before any traffic has been seen.
            rate = (self.successful_requests / self.total_requests
                    if self.total_requests else 1)
            return {
                "total_requests": self.total_requests,
                "successful": self.successful_requests,
                "failed": self.failed_requests,
                "success_rate": round(rate, 4),
                "retries": self.total_retries,
                "chars_received": self.total_chars_received,
                "avg_latency_ms": round(avg, 1),
                "active_streams": self.active_streams,
                "uptime_s": round(time.time() - self.started_at, 1),
                "per_model": dict(self.requests_per_model),
                "load_balancer": {
                    "total_dispatches": self.lb_total_dispatches,
                    "failovers": self.lb_failovers,
                },
            }


metrics = Metrics()


# ═══════════════════════════════════════════════════════════════
# RATE LIMITER — token bucket (10 req/s)
# ═══════════════════════════════════════════════════════════════
class RateLimiter:
    """Token-bucket limiter: refills at *rps* tokens/s up to *burst*."""

    def __init__(self, rps: int = 10, burst: int = 15):
        self.rate = float(rps)
        self.max_tokens = float(burst)
        self.tokens = float(burst)
        self.last_refill = time.monotonic()
        self._lock = threading.Lock()

    def acquire(self, timeout: float = 10.0) -> bool:
        """Block (polling every 50 ms) until a token is available.

        Returns True when a token was taken, False once *timeout* elapses.
        """
        deadline = time.monotonic() + timeout
        while True:
            with self._lock:
                now = time.monotonic()
                elapsed = now - self.last_refill
                self.tokens = min(
                    self.max_tokens,
                    self.tokens + elapsed * self.rate,
                )
                self.last_refill = now
                if self.tokens >= 1.0:
                    self.tokens -= 1.0
                    return True
            if time.monotonic() >= deadline:
                return False
            time.sleep(0.05)

    def get_info(self) -> Dict:
        """Current limiter settings and available tokens."""
        with self._lock:
            return {
                "rate_rps": self.rate,
                "burst": self.max_tokens,
                "available_tokens": round(self.tokens, 2),
            }


# ═══════════════════════════════════════════════════════════════
# CIRCUIT BREAKER
# ═══════════════════════════════════════════════════════════════
class CircuitBreaker:
    """closed → open after *threshold* failures; half-open after *recovery*s;
    closes again after 2 half-open successes."""

    def __init__(self, threshold: int = 5, recovery: int = 60):
        self.threshold = threshold
        self.recovery = recovery
        self.state = "closed"
        self.failures = 0
        self.successes = 0
        self.last_failure = 0.0
        self._lock = threading.Lock()

    def can_execute(self) -> bool:
        """Whether a request may proceed under the current breaker state."""
        with self._lock:
            if self.state == "closed":
                return True
            if self.state == "open":
                if time.time() - self.last_failure >= self.recovery:
                    self.state = "half_open"
                    return True
                return False
            # half_open: allow up to 2 probe requests.
            return self.successes < 2

    def record_success(self):
        with self._lock:
            if self.state == "half_open":
                self.successes += 1
                if self.successes >= 2:
                    self.state = "closed"
                    self.failures = 0
                    self.successes = 0
            else:
                # In the closed state, successes slowly forgive past failures.
                self.failures = max(0, self.failures - 1)

    def record_failure(self):
        with self._lock:
            self.failures += 1
            self.last_failure = time.time()
            # Any failure while half-open reopens immediately.
            if self.state == "half_open" or self.failures >= self.threshold:
                self.state = "open"


# ═══════════════════════════════════════════════════════════════
# SSE PARSER (for GPT-OSS)
# ═══════════════════════════════════════════════════════════════
class GradioSSEParser:
    """Minimal parser for the Gradio queue's Server-Sent-Events stream."""

    @staticmethod
    def parse_sse(response: requests.Response, log_raw: bool = False) -> Generator[Dict, None, None]:
        """Yield each JSON payload carried in 'data:' SSE lines.

        Non-data lines and malformed JSON are skipped silently.
        NOTE(review): log_raw is accepted but unused here — confirm intent.
        """
        buffer = ""
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if chunk is None:
                continue
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.strip()
                if not line or not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                if not data_str:
                    continue
                try:
                    yield json.loads(data_str)
                except json.JSONDecodeError:
                    continue

    @staticmethod
    def extract_text(output: Dict) -> str:
        """Pull the visible text from a Gradio 'output' payload.

        Handles both a plain string and the nested chatbot-list shape,
        where the newest text is the last element of the first pair.
        """
        data = output.get("data", [])
        if not data:
            return ""
        first = data[0]
        if isinstance(first, str):
            return first
        if isinstance(first, list):
            try:
                if first and isinstance(first[0], list):
                    return str(first[0][-1])
            except (IndexError, TypeError):
                pass
        return ""


# ═══════════════════════════════════════════════════════════════
# MODEL PROVIDERS
# ═══════════════════════════════════════════════════════════════
class ModelProvider(ABC):
    """One upstream connection instance; tracks its own health stats."""

    def __init__(self, model_def: ModelDef, config: Config, instance_id: int = 0):
        self.model_def = model_def
        self.config = config
        self.instance_id = instance_id
        self.ready = False
        self._lock = threading.Lock()
        self._consecutive_failures = 0
        self._last_success_time = 0.0
        self._last_failure_time = 0.0
        self._total_requests = 0
        self._total_failures = 0
        # Rolling window of the last 50 latencies (ms) for health scoring.
        self._latencies: deque = deque(maxlen=50)

    @abstractmethod
    def initialize(self) -> bool:
        """Establish the upstream connection; return readiness."""
        ...

    @abstractmethod
    def generate(self, message: str, history=None, system_prompt=None,
                 temperature=None, max_tokens=None, images=None, **kwargs) -> str:
        """Run one completion and return the cleaned text."""
        ...
    def generate_stream(self, message: str, **kwargs) -> Generator[str, None, None]:
        """Default streaming: yield the whole non-streamed result once."""
        yield self.generate(message, **kwargs)

    def record_success(self, latency_ms: float):
        """Reset the failure streak and record latency for health scoring."""
        self._consecutive_failures = 0
        self._last_success_time = time.time()
        self._total_requests += 1
        self._latencies.append(latency_ms)

    def record_failure(self):
        self._consecutive_failures += 1
        self._last_failure_time = time.time()
        self._total_requests += 1
        self._total_failures += 1

    @property
    def avg_latency(self) -> float:
        return sum(self._latencies) / len(self._latencies) if self._latencies else 0.0

    @property
    def health_score(self) -> float:
        """Heuristic 0.0–1.0 score used by the load balancer.

        Penalties: consecutive failures (up to 0.8), high average latency
        (>10s: 0.3, >5s: 0.15), and overall failure rate (up to 0.4).
        """
        if not self.ready:
            return 0.0
        score = 1.0
        score -= min(self._consecutive_failures * 0.2, 0.8)
        if self._latencies:
            avg = self.avg_latency
            if avg > 10000:
                score -= 0.3
            elif avg > 5000:
                score -= 0.15
        if self._total_requests > 5:
            fail_rate = self._total_failures / self._total_requests
            score -= fail_rate * 0.4
        return max(0.0, min(1.0, score))

    def get_instance_info(self) -> Dict:
        """Diagnostic snapshot of this instance for status endpoints."""
        return {
            "instance_id": self.instance_id,
            "ready": self.ready,
            "health_score": round(self.health_score, 3),
            "consecutive_failures": self._consecutive_failures,
            "total_requests": self._total_requests,
            "total_failures": self._total_failures,
            "avg_latency_ms": round(self.avg_latency, 1),
        }


class GptOssProvider(ModelProvider):
    """Talks to the AMD GPT-OSS Space directly over the Gradio queue SSE API."""

    def __init__(self, model_def, config, instance_id=0):
        super().__init__(model_def, config, instance_id)
        self._session = requests.Session()
        self._rotate()

    def _rotate(self):
        """Refresh session headers with a random browser identity."""
        self._session.headers.update({
            "User-Agent": random.choice(USER_AGENTS),
            "Accept-Language": "fr-FR,fr;q=0.9",
            "Origin": "https://gptunlimited.org",
            "Referer": "https://gptunlimited.org/",
        })

    def _hash(self):
        """Random 12-char Gradio session_hash."""
        return ''.join(random.choices(string.ascii_lowercase + string.digits, k=12))

    def initialize(self) -> bool:
        """Probe the Space's /gradio_api/info endpoint; cache readiness."""
        with self._lock:
            if self.ready:
                return True
            self._rotate()
            try:
                r = self._session.get(
                    f"{self.model_def.space_id}/gradio_api/info",
                    timeout=15,
                )
                self.ready = r.status_code == 200
                return self.ready
            except Exception:
                return False

    def generate(self, message, history=None, system_prompt=None,
                 temperature=None, max_tokens=None, images=None, **kw):
        """Join the Gradio queue, then drain the SSE stream to one final string.

        Raises APIError on queue-join failure, upstream error events, or an
        empty final result. NOTE(review): max_tokens/images are accepted but
        unused by this provider.
        """
        if not self.ready:
            self.initialize()
        sys_p = system_prompt or self.config.default_system_prompt
        temp = (temperature if temperature is not None
                else self.model_def.default_temperature)
        h = self._hash()
        payload = {
            "data": [message, history or [], sys_p, temp],
            "event_data": None,
            "fn_index": self.model_def.fn_index,
            "trigger_id": None,
            "session_hash": h,
        }
        r = self._session.post(
            f"{self.model_def.space_id}/gradio_api/queue/join?",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        if r.status_code != 200:
            raise APIError(f"Queue join failed: {r.status_code}")
        data = r.json()
        if not data.get("event_id"):
            raise APIError("No event_id")
        resp = self._session.get(
            f"{self.model_def.space_id}/gradio_api/queue/data",
            params={"session_hash": h},
            headers={"Accept": "text/event-stream"},
            timeout=self.config.timeout_stream,
            stream=True,
        )
        full = ""
        # Each event carries the full text so far; keep the latest snapshot.
        for d in GradioSSEParser.parse_sse(resp):
            msg = d.get("msg", "")
            if msg in ("process_generating", "process_completed"):
                output = d.get("output", {})
                if not output.get("success", True):
                    raise APIError(f"Gradio error: {output.get('error')}")
                t = GradioSSEParser.extract_text(output)
                if t:
                    full = t
                if msg == "process_completed":
                    break
            elif msg == "close_stream":
                break
        if not full.strip():
            raise APIError("Empty response", "EMPTY")
        return (ResponseCleaner.clean_analysis(full)
                if self.model_def.clean_analysis else full)

    def generate_stream(self, message, history=None, system_prompt=None,
                        temperature=None, max_tokens=None, images=None, **kw):
        """Stream incremental deltas: yield only the new suffix per SSE event.

        NOTE(review): metrics.active_streams is mutated here without taking
        metrics._lock (unlike the other Metrics mutators) — confirm whether
        this racy update is acceptable.
        """
        if not self.ready:
            self.initialize()
        sys_p = system_prompt or self.config.default_system_prompt
        temp = (temperature if temperature is not None
                else self.model_def.default_temperature)
        h = self._hash()
        payload = {
            "data": [message, history or [], sys_p, temp],
            "event_data": None,
            "fn_index": self.model_def.fn_index,
            "trigger_id": None,
            "session_hash": h,
        }
        self._session.post(
            f"{self.model_def.space_id}/gradio_api/queue/join?",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        resp = self._session.get(
            f"{self.model_def.space_id}/gradio_api/queue/data",
            params={"session_hash": h},
            headers={"Accept": "text/event-stream"},
            timeout=self.config.timeout_stream,
            stream=True,
        )
        metrics.active_streams += 1
        last = ""
        try:
            for d in GradioSSEParser.parse_sse(resp):
                msg = d.get("msg", "")
                if msg in ("process_generating", "process_completed"):
                    output = d.get("output", {})
                    if not output.get("success", True):
                        raise APIError("Gradio error")
                    raw = GradioSSEParser.extract_text(output)
                    if raw:
                        if self.model_def.clean_analysis:
                            # Clean the cumulative text, then emit only what
                            # grew beyond the previously yielded prefix.
                            cleaned = ResponseCleaner.clean_analysis(raw)
                            if cleaned and len(cleaned) > len(last):
                                yield cleaned[len(last):]
                                last = cleaned
                        else:
                            if len(raw) > len(last):
                                yield raw[len(last):]
                                last = raw
                    if msg == "process_completed":
                        return
                elif msg == "close_stream":
                    return
        finally:
            metrics.active_streams = max(0, metrics.active_streams - 1)


class GradioClientProvider(ModelProvider):
    """Generic provider for all gradio_client based models."""

    def __init__(self, model_def, config, instance_id=0):
        super().__init__(model_def, config, instance_id)
        self._client = None
        # Per-instance turn counter, required by the community ChatGPT Space.
        self._chat_counter = 0

    def initialize(self) -> bool:
        """Connect a GradioClient to the Space; raises if the lib is missing."""
        if not HAS_GRADIO_CLIENT:
            raise APIError("gradio_client not installed", "MISSING_DEP")
        with self._lock:
            if self.ready:
                return True
            try:
                log.info(
                    f"[Instance {self.instance_id}] Connecting to "
                    f"{self.model_def.space_id}..."
                )
                self._client = GradioClient(self.model_def.space_id)
                self.ready = True
                return True
            except Exception as e:
                log.error(
                    f"[Instance {self.instance_id}] Init failed for "
                    f"{self.model_def.model_id}: {e}"
                )
                return False

    def generate(self, message, history=None, system_prompt=None,
                 temperature=None, max_tokens=None, images=None, **kw):
        """Dispatch one completion to the per-model predict() call shape.

        Each supported model_id has a dedicated branch because every Space
        exposes different parameter names. Non-APIError exceptions are
        wrapped as APIError("PROVIDER_ERROR").
        """
        if not self.ready:
            self.initialize()
        if not self._client:
            raise APIError(f"{self.model_def.model_id} not initialized")
        mid = self.model_def.model_id
        images = images or []
        try:
            if mid == "command-a-vision":
                max_new = (max_tokens
                           or self.model_def.extra_params.get("max_new_tokens", 700))
                # Build multimodal message
                msg_payload: Any
                if images:
                    # Only the first image is forwarded.
                    img_path = save_image_temp(images[0])
                    if img_path:
                        msg_payload = {"text": message, "files": [handle_file(img_path)]}
                    else:
                        msg_payload = {"text": message, "files": []}
                else:
                    msg_payload = {"text": message, "files": []}
                result = self._client.predict(
                    message=msg_payload,
                    max_new_tokens=max_new,
                    api_name=self.model_def.api_name,
                )
            elif mid == "command-a-translate":
                max_new = (max_tokens
                           or self.model_def.extra_params.get("max_new_tokens", 700))
                result = self._client.predict(
                    message=message,
                    max_new_tokens=max_new,
                    api_name=self.model_def.api_name,
                )
            elif mid == "command-a-reasoning":
                thinking_budget = kw.get(
                    "thinking_budget",
                    self.model_def.extra_params.get("thinking_budget", 500),
                )
                result = self._client.predict(
                    message=message,
                    thinking_budget=thinking_budget,
                    api_name=self.model_def.api_name,
                )
                return self._extract_reasoning(result)
            elif mid == "minimax-vl-01":
                temp = (temperature if temperature is not None
                        else self.model_def.default_temperature)
                max_tok = (max_tokens
                           or self.model_def.extra_params.get("max_tokens", 12800))
                top_p = kw.get("top_p", self.model_def.extra_params.get("top_p", 0.9))
                # Vision support
                if images:
                    img_path = save_image_temp(images[0])
                    files = [handle_file(img_path)] if img_path else []
                else:
                    files = []
                result = self._client.predict(
                    message={"text": message, "files":
                             files},
                    max_tokens=max_tok,
                    temperature=temp,
                    top_p=top_p,
                    api_name=self.model_def.api_name,
                )
            elif mid == "glm-4.5":
                sys_p = system_prompt or self.config.default_system_prompt
                temp = (temperature if temperature is not None
                        else self.model_def.default_temperature)
                thinking = kw.get("thinking_enabled", self.model_def.thinking_default)
                include = kw.get("include_thinking", self.config.include_thinking)
                result = self._client.predict(
                    msg=message,
                    sys_prompt=sys_p,
                    thinking_enabled=thinking,
                    temperature=temp,
                    api_name=self.model_def.api_name,
                )
                return self._extract_glm(result, include)
            elif mid == "chatgpt":
                temp = (temperature if temperature is not None
                        else self.model_def.default_temperature)
                top_p = kw.get("top_p", self.model_def.extra_params.get("top_p", 1.0))
                # Normalize history into the [[user, bot], ...] list shape
                # this Space expects.
                chat_hist = []
                if history:
                    for pair in history:
                        if isinstance(pair, (list, tuple)) and len(pair) == 2:
                            chat_hist.append([str(pair[0]), str(pair[1])])
                result = self._client.predict(
                    inputs=message,
                    top_p=top_p,
                    temperature=temp,
                    chat_counter=self._chat_counter,
                    chatbot=chat_hist,
                    api_name=self.model_def.api_name,
                )
                self._chat_counter += 1
                return ResponseCleaner.extract_chatgpt_text(result)
            elif mid == "qwen3-vl":
                # Vision support
                if images:
                    img_path = save_image_temp(images[0])
                    files = [handle_file(img_path)] if img_path else []
                    result = self._client.predict(
                        input_value={"files": files, "text": message},
                        api_name="/add_message",
                    )
                else:
                    result = self._client.predict(
                        input_value={"files": None, "text": message},
                        api_name="/add_message",
                    )
                return ResponseCleaner.extract_qwen_text(result)
            elif mid == "qwen2.5-coder":
                # Best-effort: push the override system prompt first; failure
                # to set it is logged but does not abort the generation.
                sys_override = self.model_def.extra_params.get(
                    "system_prompt_override", ""
                )
                if sys_override:
                    try:
                        self._client.predict(
                            input=sys_override,
                            api_name="/lambda_1",
                        )
                    except Exception as e:
                        log.warning(f"[qwen2.5-coder] Failed to set system prompt: {e}")
                result = self._client.predict(
                    query=message,
                    api_name="/generation_code",
                )
                return ResponseCleaner.extract_qwen_coder_text(result)
            else:
                raise APIError(f"Unknown model handler: {mid}")
            # Fallback normalization for branches that did not return above.
            if isinstance(result, str):
                return result.strip()
            if isinstance(result, dict):
                return json.dumps(result, ensure_ascii=False)
            if isinstance(result, (list, tuple)):
                return str(result[0]).strip() if result else ""
            return str(result)
        except APIError:
            raise
        except Exception as e:
            raise APIError(f"{mid} error: {e}", "PROVIDER_ERROR")

    def _extract_reasoning(self, result: Any) -> str:
        """Normalize a command-a-reasoning result to plain text.

        Prefers well-known response keys, then a thinking+response pair,
        then JSON-dumps unknown shapes.
        NOTE(review): the f-string joining thinking and response appears to
        have lost its delimiter tag text (likely '<think>' markers) to the
        same HTML-stripping corruption seen elsewhere in this file.
        """
        if result is None:
            return ""
        if isinstance(result, str):
            return result.strip()
        if isinstance(result, dict):
            for key in ("response", "output", "answer", "text", "content", "result"):
                if key in result:
                    val = result[key]
                    if isinstance(val, str):
                        return val.strip()
                    return str(val)
            thinking = result.get("thinking", "")
            response = result.get("response", result.get("output", ""))
            if thinking and response:
                return f"\n{thinking}\n\n{response}"
            if response:
                return str(response).strip()
            return json.dumps(result, ensure_ascii=False, indent=2)
        if isinstance(result, (list, tuple)):
            if len(result) == 1:
                return str(result[0]).strip()
            texts = []
            for item in result:
                if isinstance(item, str) and item.strip():
                    texts.append(item.strip())
            if texts:
                return "\n".join(texts)
            return json.dumps(result, ensure_ascii=False)
        if isinstance(result, (int, float, bool)):
            return str(result)
        return str(result)

    def _extract_glm(self, result, include_thinking: bool = True) -> str:
        """Pull the newest assistant message from a GLM predict() tuple and
        clean it; falls back to stringifying the chatbot payload.

        NOTE(review): this method is truncated at the end of the visible
        chunk — the trailing `if isinstance(result,` continues in file
        content beyond this view and is preserved as-is.
        """
        if isinstance(result, tuple) and len(result) >= 1:
            chatbot = result[0]
            if isinstance(chatbot, list) and chatbot:
                for msg in reversed(chatbot):
                    if isinstance(msg, dict) and msg.get("role") == "assistant":
                        content = msg.get("content", "")
                        raw = content if isinstance(content, str) else str(content)
                        return ResponseCleaner.clean_glm(raw, include_thinking)
                last = chatbot[-1]
                if isinstance(last, dict):
                    raw = last.get("content", "")
                    raw = raw if isinstance(raw, str) else str(raw)
                    return ResponseCleaner.clean_glm(raw, include_thinking)
            return ResponseCleaner.clean_glm(str(chatbot), include_thinking)
        if isinstance(result,
str): return ResponseCleaner.clean_glm(result, include_thinking) return ResponseCleaner.clean_glm(str(result), include_thinking) def create_provider(model_id: str, config: Config, instance_id: int = 0) -> ModelProvider: if model_id not in MODEL_REGISTRY: raise ModelNotFoundError(model_id) mdef = MODEL_REGISTRY[model_id] if model_id == "gpt-oss-120b": return GptOssProvider(mdef, config, instance_id) return GradioClientProvider(mdef, config, instance_id) # ═══════════════════════════════════════════════════════════════ # LOAD BALANCER # ═══════════════════════════════════════════════════════════════ class LoadBalancedProviderPool: def __init__(self, model_id: str, config: Config): self.model_id = model_id self.config = config self.mdef = MODEL_REGISTRY[model_id] pool_size = self.mdef.lb_pool_size if self.mdef.lb_enabled else 1 self._instances: List[ModelProvider] = [] self._rr_index = 0 self._lock = threading.Lock() for i in range(pool_size): self._instances.append(create_provider(model_id, config, instance_id=i)) log.info( f"[LB] Created pool for '{model_id}' with {len(self._instances)} " f"instance(s), lb_enabled={self.mdef.lb_enabled}" ) @property def pool_size(self) -> int: return len(self._instances) def initialize_all(self) -> int: ok = 0 for inst in self._instances: try: if inst.initialize(): ok += 1 except Exception as e: log.warning( f"[LB] Failed to init {self.model_id} " f"instance {inst.instance_id}: {e}" ) return ok def initialize_one(self) -> bool: for inst in self._instances: try: if inst.initialize(): return True except Exception: continue return False def _select_instance(self) -> ModelProvider: if len(self._instances) == 1: return self._instances[0] with self._lock: scored = [] for inst in self._instances: score = inst.health_score scored.append((inst, max(score, 0.05))) total_weight = sum(s for _, s in scored) if total_weight <= 0: inst = self._instances[self._rr_index % len(self._instances)] self._rr_index += 1 return inst r = random.uniform(0, 
total_weight) cumulative = 0.0 for inst, weight in scored: cumulative += weight if r <= cumulative: return inst return scored[-1][0] def _get_ordered_instances(self) -> List[ModelProvider]: return sorted(self._instances, key=lambda p: p.health_score, reverse=True) def execute(self, fn_name: str, **kwargs) -> Any: primary = self._select_instance() metrics.record_lb_dispatch() if not primary.ready: try: primary.initialize() except Exception: pass start = time.monotonic() try: result = self._call_provider(primary, fn_name, **kwargs) latency = (time.monotonic() - start) * 1000 primary.record_success(latency) return result except Exception as primary_err: primary.record_failure() log.warning( f"[LB] Primary instance {primary.instance_id} for " f"'{self.model_id}' failed: {primary_err}" ) for inst in self._get_ordered_instances(): if inst is primary: continue if not inst.ready: try: inst.initialize() except Exception: continue metrics.record_lb_dispatch(failover=True) start = time.monotonic() try: result = self._call_provider(inst, fn_name, **kwargs) latency = (time.monotonic() - start) * 1000 inst.record_success(latency) log.info( f"[LB] Failover to instance {inst.instance_id} " f"for '{self.model_id}' succeeded" ) return result except Exception as e: inst.record_failure() log.warning( f"[LB] Failover instance {inst.instance_id} " f"for '{self.model_id}' failed: {e}" ) raise APIError( f"All {len(self._instances)} instances for '{self.model_id}' failed", "ALL_INSTANCES_FAILED", ) def execute_stream(self, **kwargs) -> Generator[str, None, None]: primary = self._select_instance() metrics.record_lb_dispatch() if not primary.ready: try: primary.initialize() except Exception: pass try: yield from self._call_provider_stream(primary, **kwargs) return except Exception as primary_err: primary.record_failure() log.warning( f"[LB] Stream primary instance {primary.instance_id} " f"for '{self.model_id}' failed: {primary_err}" ) for inst in self._get_ordered_instances(): if inst is 
primary: continue if not inst.ready: try: inst.initialize() except Exception: continue metrics.record_lb_dispatch(failover=True) try: yield from self._call_provider_stream(inst, **kwargs) return except Exception as e: inst.record_failure() log.warning( f"[LB] Stream failover instance {inst.instance_id} " f"for '{self.model_id}' failed: {e}" ) raise APIError( f"All streaming instances for '{self.model_id}' failed", "ALL_INSTANCES_FAILED", ) def _call_provider(self, provider: ModelProvider, fn_name: str, **kwargs) -> Any: if not provider.ready: provider.initialize() fn = getattr(provider, fn_name) return fn(**kwargs) def _call_provider_stream(self, provider: ModelProvider, **kwargs) -> Generator[str, None, None]: if not provider.ready: provider.initialize() start = time.monotonic() try: yield from provider.generate_stream(**kwargs) latency = (time.monotonic() - start) * 1000 provider.record_success(latency) except Exception: provider.record_failure() raise def get_pool_info(self) -> Dict: return { "model_id": self.model_id, "lb_enabled": self.mdef.lb_enabled, "pool_size": len(self._instances), "is_beta": self.mdef.is_beta, "instances": [inst.get_instance_info() for inst in self._instances], } # ═══════════════════════════════════════════════════════════════ # MULTI-MODEL CLIENT # ═══════════════════════════════════════════════════════════════ class MultiModelClient: def __init__(self, config: Config): self.config = config self._lb_pools: Dict[str, LoadBalancedProviderPool] = {} self._lock = threading.Lock() self._conversations: Dict[str, Conversation] = {} self._active_conv_id: Optional[str] = None self._current_model = config.default_model self.rate_limiter = RateLimiter(config.rate_limit_rps, config.rate_limit_burst) self.circuit_breaker = CircuitBreaker() @property def current_model(self): return self._current_model @current_model.setter def current_model(self, m): if m not in MODEL_REGISTRY: raise ModelNotFoundError(m) self._current_model = m def 
_get_lb_pool(self, model_id: str) -> LoadBalancedProviderPool: if model_id not in self._lb_pools: with self._lock: if model_id not in self._lb_pools: self._lb_pools[model_id] = LoadBalancedProviderPool( model_id, self.config ) return self._lb_pools[model_id] def _ensure_ready(self, model_id: str) -> LoadBalancedProviderPool: lb_pool = self._get_lb_pool(model_id) has_ready = any(inst.ready for inst in lb_pool._instances) if not has_ready: if not lb_pool.initialize_one(): raise APIError(f"Cannot init any instance for {model_id}", "INIT_FAILED") return lb_pool @property def active_conversation(self) -> Conversation: if self._active_conv_id not in self._conversations: conv = Conversation( system_prompt=self.config.default_system_prompt, model_id=self._current_model, ) self._conversations[conv.conversation_id] = conv self._active_conv_id = conv.conversation_id return self._conversations[self._active_conv_id] def new_conversation(self, system_prompt=None, model_id=None) -> Conversation: conv = Conversation( system_prompt=system_prompt or self.config.default_system_prompt, model_id=model_id or self._current_model, ) self._conversations[conv.conversation_id] = conv self._active_conv_id = conv.conversation_id return conv def init_model(self, model_id: str) -> bool: try: lb_pool = self._get_lb_pool(model_id) return lb_pool.initialize_one() except Exception: return False def init_model_all(self, model_id: str) -> int: try: lb_pool = self._get_lb_pool(model_id) return lb_pool.initialize_all() except Exception: return 0 def send_message( self, message: Any, # str OR list (multimodal) *, stream: bool = False, model: Optional[str] = None, conversation_id: Optional[str] = None, system_prompt: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, include_thinking: Optional[bool] = None, images: Optional[List[str]] = None, **kwargs, ) -> Union[str, Generator]: model_id = model or self._current_model if model_id not in MODEL_REGISTRY: raise 
ModelNotFoundError(model_id) mdef = MODEL_REGISTRY[model_id] # ── Normalise multimodal content ────────────────────── if isinstance(message, list): text, extracted_images = extract_text_and_images(message) if not images: images = extracted_images message = text if isinstance(message, str): message = message.strip() else: message = str(message).strip() if not message and not images: raise APIError("Empty message", "INVALID_INPUT", 400) if len(message) > self.config.max_message_length: raise APIError("Message too long", "INVALID_INPUT", 400) if not self.circuit_breaker.can_execute(): raise APIError("Circuit breaker open", "CIRCUIT_OPEN", 503) if not self.rate_limiter.acquire(timeout=10.0): raise APIError("Rate limited (10 req/s max)", "RATE_LIMITED", 429) conv = (self._conversations.get(conversation_id, self.active_conversation) if conversation_id else self.active_conversation) conv.model_id = model_id if system_prompt: conv.system_prompt = system_prompt history = conv.build_gradio_history() if mdef.supports_history else None conv.add_message("user", message, self.config.max_history_messages) eff_temp = (temperature if temperature is not None else mdef.default_temperature) eff_sys = conv.system_prompt if mdef.supports_system_prompt else None eff_thinking = (include_thinking if include_thinking is not None else self.config.include_thinking) extra = dict(kwargs) if mdef.supports_thinking: extra["include_thinking"] = eff_thinking start = time.monotonic() for attempt in range(self.config.max_retries + 1): try: if attempt > 0: time.sleep( self.config.retry_backoff_base ** attempt + random.uniform(0, self.config.retry_jitter) ) metrics.record_retry() lb_pool = self._ensure_ready(model_id) if stream and mdef.supports_streaming: gen = lb_pool.execute_stream( message=message, history=history, system_prompt=eff_sys, temperature=eff_temp, max_tokens=max_tokens, images=images, **extra, ) return self._wrap_stream(gen, conv, start, model_id) result = lb_pool.execute( "generate", 
message=message, history=history, system_prompt=eff_sys, temperature=eff_temp, max_tokens=max_tokens, images=images, **extra, ) dur = (time.monotonic() - start) * 1000 thinking, response = ThinkingParser.split(result) conv.add_message("assistant", response, self.config.max_history_messages, thinking=thinking) metrics.record_request(True, dur, len(result), model_id) self.circuit_breaker.record_success() return result except APIError: self.circuit_breaker.record_failure() if attempt == self.config.max_retries: dur = (time.monotonic() - start) * 1000 metrics.record_request(False, dur, model=model_id) raise except Exception as e: self.circuit_breaker.record_failure() if attempt == self.config.max_retries: dur = (time.monotonic() - start) * 1000 metrics.record_request(False, dur, model=model_id) raise APIError(str(e)) def _wrap_stream(self, gen, conv, start, model_id): full = "" try: for chunk in gen: full += chunk yield chunk thinking, response = ThinkingParser.split(full) conv.add_message("assistant", response, self.config.max_history_messages, thinking=thinking) metrics.record_request( True, (time.monotonic() - start) * 1000, len(full), model_id, ) self.circuit_breaker.record_success() except Exception: metrics.record_request( False, (time.monotonic() - start) * 1000, model=model_id, ) self.circuit_breaker.record_failure() raise def get_status(self) -> Dict: lb_info = {} for model_id, lb_pool in self._lb_pools.items(): lb_info[model_id] = lb_pool.get_pool_info() return { "version": VERSION, "current_model": self._current_model, "models": list(MODEL_REGISTRY.keys()), "load_balancer": lb_info, "conversations": len(self._conversations), "circuit_breaker": self.circuit_breaker.state, "rate_limiter": self.rate_limiter.get_info(), } # ═══════════════════════════════════════════════════════════════ # SESSION POOL # ═══════════════════════════════════════════════════════════════ class SessionPool: def __init__(self, config: Config): self.config = config self._clients = [ 
MultiModelClient(config) for _ in range(config.pool_size) ] self._idx = 0 self._lock = threading.Lock() def init_default(self): for c in self._clients: c.init_model(self.config.default_model) def init_model(self, model_id: str) -> int: total = 0 for c in self._clients: total += c.init_model_all(model_id) return total def acquire(self) -> MultiModelClient: with self._lock: c = self._clients[self._idx % len(self._clients)] self._idx += 1 return c # ═══════════════════════════════════════════════════════════════ # ALIAS RESOLVER # ═══════════════════════════════════════════════════════════════ ALIASES = { "gpt-oss": "gpt-oss-120b", "gptoss": "gpt-oss-120b", "amd": "gpt-oss-120b", "command-a": "command-a-vision", "command-vision": "command-a-vision", "cohere-vision": "command-a-vision", "command-translate": "command-a-translate", "cohere-translate": "command-a-translate", "translate": "command-a-translate", "command-reasoning": "command-a-reasoning", "reasoning": "command-a-reasoning", "cohere-reasoning": "command-a-reasoning", "command-r": "command-a-reasoning", "minimax": "minimax-vl-01", "minimax-vl": "minimax-vl-01", "glm": "glm-4.5", "glm4": "glm-4.5", "glm-4": "glm-4.5", "zhipu": "glm-4.5", "gpt": "chatgpt", "gpt-3.5": "chatgpt", "gpt3": "chatgpt", "openai": "chatgpt", "qwen": "qwen3-vl", "qwen3": "qwen3-vl", "qwen-vl": "qwen3-vl", "qwen-coder": "qwen2.5-coder", "qwen2.5": "qwen2.5-coder", "qwen25-coder": "qwen2.5-coder", "coder": "qwen2.5-coder", } def resolve_alias(model_id: str) -> str: if not model_id: return config.default_model return ALIASES.get(model_id.lower(), model_id) # ═══════════════════════════════════════════════════════════════ # FLASK APP # ═══════════════════════════════════════════════════════════════ config = Config.from_env() pool = SessionPool(config) pool.init_default() app = Flask(APP_NAME) @app.after_request def cors(response): response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Headers"] = 
"Content-Type, Authorization" response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" return response @app.errorhandler(APIError) def handle_api_error(e: APIError): return jsonify({"ok": False, **e.to_dict()}), e.status @app.route("/") def index(): return jsonify({ "name": APP_NAME, "version": VERSION, "default_model": config.default_model, "features": ["load_balancing", "10_req_per_second_limit", "failover", "vision"], "models": list(MODEL_REGISTRY.keys()), "beta_models": [mid for mid, mdef in MODEL_REGISTRY.items() if mdef.is_beta], "vision_models": [mid for mid, mdef in MODEL_REGISTRY.items() if mdef.supports_vision], "endpoints": { "POST /chat": "Chat with any model", "POST /chat/stream": "Streaming chat", "POST /v1/chat/completions": "OpenAI-compatible (supports vision)", "GET /v1/models": "List models", "POST /models/init": "Init a model", "GET /health": "Health check", "GET /metrics": "Metrics", "GET /lb/status": "Load balancer status", }, }) @app.route("/chat", methods=["POST"]) def chat(): data = freq.get_json(force=True, silent=True) or {} raw_message = data.get("message", "") images = data.get("images", []) # Support multimodal content directly in message field if isinstance(raw_message, list): text, extracted = extract_text_and_images(raw_message) images = images or extracted message = text else: message = str(raw_message).strip() if not message and not images: return jsonify({"ok": False, "error": "'message' required"}), 400 model_id = resolve_alias(data.get("model", config.default_model)) include_thinking = data.get("include_thinking", config.include_thinking) client = pool.acquire() if data.get("new_conversation"): client.new_conversation(data.get("system_prompt"), model_id) extra = {} if model_id == "command-a-reasoning" and "thinking_budget" in data: extra["thinking_budget"] = data["thinking_budget"] result = client.send_message( message, model=model_id, system_prompt=data.get("system_prompt"), temperature=data.get("temperature"), 
max_tokens=data.get("max_tokens"), include_thinking=include_thinking, images=images or None, **extra, ) thinking, clean = ThinkingParser.split(result) mdef = MODEL_REGISTRY.get(model_id) resp = { "ok": True, "response": clean, "model": model_id, "conversation_id": client.active_conversation.conversation_id, "history_size": len(client.active_conversation.messages), } if thinking: resp["thinking"] = thinking if mdef and mdef.is_beta: resp["beta"] = True return jsonify(resp) @app.route("/chat/stream", methods=["POST"]) def chat_stream(): data = freq.get_json(force=True, silent=True) or {} raw_message = data.get("message", "") images = data.get("images", []) if isinstance(raw_message, list): text, extracted = extract_text_and_images(raw_message) images = images or extracted message = text else: message = str(raw_message).strip() if not message and not images: return jsonify({"ok": False, "error": "'message' required"}), 400 model_id = resolve_alias(data.get("model", config.default_model)) include_thinking = data.get("include_thinking", config.include_thinking) client = pool.acquire() if data.get("new_conversation"): client.new_conversation(data.get("system_prompt"), model_id) mdef = MODEL_REGISTRY.get(model_id) use_stream = mdef.supports_streaming if mdef else False extra = {} if model_id == "command-a-reasoning" and "thinking_budget" in data: extra["thinking_budget"] = data["thinking_budget"] def generate(): try: if use_stream: for chunk in client.send_message( message, stream=True, model=model_id, system_prompt=data.get("system_prompt"), temperature=data.get("temperature"), max_tokens=data.get("max_tokens"), include_thinking=include_thinking, images=images or None, **extra, ): yield f"data: {json.dumps({'chunk': chunk})}\n\n" else: result = client.send_message( message, model=model_id, system_prompt=data.get("system_prompt"), temperature=data.get("temperature"), max_tokens=data.get("max_tokens"), include_thinking=include_thinking, images=images or None, **extra, ) 
yield f"data: {json.dumps({'chunk': result})}\n\n" yield "data: [DONE]\n\n" except APIError as e: yield f"data: {json.dumps(e.to_dict())}\n\n" return Response(stream_with_context(generate()), content_type="text/event-stream") @app.route("/v1/models", methods=["GET"]) def list_models(): models = [] for mid, mdef in MODEL_REGISTRY.items(): model_info = { "id": mid, "object": "model", "owned_by": mdef.owned_by, "created": 0, "description": mdef.description, "capabilities": { "vision": mdef.supports_vision, "streaming": mdef.supports_streaming, "system_prompt": mdef.supports_system_prompt, "temperature": mdef.supports_temperature, "history": mdef.supports_history, "thinking": mdef.supports_thinking, }, "load_balancing": { "enabled": mdef.lb_enabled, "pool_size": mdef.lb_pool_size, }, } if mdef.is_beta: model_info["beta"] = True models.append(model_info) return jsonify({"object": "list", "data": models}) @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"]) def openai_compat(): if freq.method == "OPTIONS": return "", 200 data = freq.get_json(force=True, silent=True) or {} messages = data.get("messages", []) do_stream = data.get("stream", False) temperature = data.get("temperature") max_tokens = data.get("max_tokens") model_id = resolve_alias(data.get("model", config.default_model)) include_thinking = data.get("include_thinking", config.include_thinking) if model_id not in MODEL_REGISTRY: return jsonify({ "error": { "message": f"Model '{model_id}' not found. 
Available: {list(MODEL_REGISTRY.keys())}", "type": "invalid_request_error", "available_models": list(MODEL_REGISTRY.keys()), } }), 404 if not messages: return jsonify({"error": {"message": "messages required"}}), 400 # ── Extract user message, system prompt, and images ─────── user_msg: str = "" system_prompt: Optional[str] = None images: List[str] = [] for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role == "system": system_prompt = content if isinstance(content, str) else str(content) if role == "user": if isinstance(content, list): text, imgs = extract_text_and_images(content) user_msg = text images.extend(imgs) elif isinstance(content, str): user_msg = content else: user_msg = str(content) if not user_msg and not images: return jsonify({"error": {"message": "No user message"}}), 400 rid = f"chatcmpl-{uuid.uuid4().hex[:29]}" created = int(time.time()) client = pool.acquire() client.new_conversation(system_prompt, model_id) # Replay history (all but the last user message) for msg in messages[:-1]: role = msg.get("role") content = msg.get("content", "") if role in ("user", "assistant") and content: text = ( extract_text_and_images(content)[0] if isinstance(content, list) else str(content) ) if text: client.active_conversation.add_message(role, text) mdef = MODEL_REGISTRY[model_id] extra = {} if model_id == "command-a-reasoning" and "thinking_budget" in data: extra["thinking_budget"] = data["thinking_budget"] if do_stream: def generate(): try: yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" if mdef.supports_streaming: for chunk in client.send_message( user_msg, stream=True, model=model_id, temperature=temperature, max_tokens=max_tokens, include_thinking=include_thinking, images=images or None, **extra, ): yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 
'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'content': chunk}, 'finish_reason': None}]})}\n\n" else: result = client.send_message( user_msg, model=model_id, temperature=temperature, max_tokens=max_tokens, include_thinking=include_thinking, images=images or None, **extra, ) yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'content': result}, 'finish_reason': None}]})}\n\n" yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" yield "data: [DONE]\n\n" except Exception as e: yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n" return Response(stream_with_context(generate()), content_type="text/event-stream") result = client.send_message( user_msg, model=model_id, temperature=temperature, max_tokens=max_tokens, include_thinking=include_thinking, images=images or None, **extra, ) return jsonify({ "id": rid, "object": "chat.completion", "created": created, "model": model_id, "choices": [{ "index": 0, "message": {"role": "assistant", "content": result}, "finish_reason": "stop", }], "usage": { "prompt_tokens": len(user_msg) // 4, "completion_tokens": len(result) // 4, "total_tokens": (len(user_msg) + len(result)) // 4, }, }) @app.route("/new", methods=["POST"]) def new_conv(): data = freq.get_json(force=True, silent=True) or {} model_id = resolve_alias(data.get("model", config.default_model)) client = pool.acquire() conv = client.new_conversation(data.get("system_prompt"), model_id) return jsonify({ "ok": True, "conversation_id": conv.conversation_id, "model": model_id, }) @app.route("/health", methods=["GET"]) def health(): client = pool.acquire() return jsonify(client.get_status()) @app.route("/metrics", methods=["GET"]) def metrics_endpoint(): return jsonify(metrics.to_dict()) @app.route("/lb/status", 
methods=["GET"]) def lb_status(): all_pools = {} for client in pool._clients: for model_id, lb_pool in client._lb_pools.items(): key = model_id if key not in all_pools: all_pools[key] = [] all_pools[key].append(lb_pool.get_pool_info()) return jsonify({ "ok": True, "version": VERSION, "rate_limit": f"{config.rate_limit_rps} req/s", "models": all_pools, }) @app.route("/conversations", methods=["GET"]) def conversations(): client = pool.acquire() return jsonify({ "conversations": [c.to_dict() for c in client._conversations.values()] }) @app.route("/models/init", methods=["POST"]) def init_model_ep(): data = freq.get_json(force=True, silent=True) or {} model_id = resolve_alias(data.get("model", "")) if not model_id or model_id not in MODEL_REGISTRY: return jsonify({ "ok": False, "error": f"Unknown model. Available: {list(MODEL_REGISTRY.keys())}", }), 400 count = pool.init_model(model_id) mdef = MODEL_REGISTRY[model_id] resp = { "ok": True, "model": model_id, "initialized_instances": count, "lb_enabled": mdef.lb_enabled, "pool_size_per_client": mdef.lb_pool_size, } if mdef.is_beta: resp["beta"] = True return jsonify(resp) # ═══════════════════════════════════════════════════════════════ # ENTRY POINT # ═══════════════════════════════════════════════════════════════ if __name__ == "__main__": port = int(os.environ.get("PORT", 7860)) log.info(f"Starting {APP_NAME} v{VERSION} on port {port}") log.info(f"Models: {list(MODEL_REGISTRY.keys())}") log.info(f"Rate limit: {config.rate_limit_rps} req/s (burst: {config.rate_limit_burst})") for mid, mdef in MODEL_REGISTRY.items(): lb_str = ( f"LB ON (pool={mdef.lb_pool_size})" if mdef.lb_enabled else "LB OFF (single instance)" ) vision_str = " [VISION]" if mdef.supports_vision else "" beta_str = " [BETA]" if mdef.is_beta else "" log.info(f" {mid}: {lb_str}{vision_str}{beta_str}") app.run(host="0.0.0.0", port=port, threaded=True)