MB-IDK commited on
Commit
1cd3ed6
Β·
verified Β·
1 Parent(s): e2eee06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +839 -188
app.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
  Multi-Model AI API β€” HuggingFace Spaces Edition
4
- Unified API gateway for multiple AI models via Hugging Face Spaces.
5
  """
6
 
7
  import re, os, json, uuid, time, random, string, logging, threading
@@ -23,7 +23,7 @@ except ImportError:
23
  # CONFIG & CONSTANTS
24
  # ═══════════════════════════════════════════════════════════════
25
 
26
- VERSION = "2.2.0-hf"
27
  APP_NAME = "Multi-Model-AI-API"
28
  DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly AI assistant."
29
  DEFAULT_MODEL = "gpt-oss-120b"
@@ -34,6 +34,8 @@ log = logging.getLogger(APP_NAME)
34
  USER_AGENTS = [
35
  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/144.0.0.0 Safari/537.36",
36
  "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_5) AppleWebKit/605.1.15 Safari/605.1.15",
 
 
37
  ]
38
 
39
  # ═══════════════════════════════════════════════════════════════
@@ -61,12 +63,17 @@ class ModelDef:
61
  api_name: Optional[str] = None
62
  extra_params: Dict[str, Any] = field(default_factory=dict)
63
  clean_analysis: bool = False
 
 
 
64
 
65
  MODEL_REGISTRY: Dict[str, ModelDef] = {}
66
 
 
67
  def register_model(m: ModelDef):
68
  MODEL_REGISTRY[m.model_id] = m
69
 
 
70
  def _init_registry():
71
  register_model(ModelDef(
72
  model_id="gpt-oss-120b", display_name="AMD GPT-OSS-120B",
@@ -74,6 +81,7 @@ def _init_registry():
74
  owned_by="amd", description="AMD open-source 120B model",
75
  fn_index=8, clean_analysis=True, default_temperature=0.0,
76
  supports_vision=False, supports_thinking=False,
 
77
  ))
78
  register_model(ModelDef(
79
  model_id="command-a-vision", display_name="Cohere Command-A Vision",
@@ -83,6 +91,7 @@ def _init_registry():
83
  supports_temperature=False, supports_streaming=False, supports_history=False,
84
  supports_thinking=False, max_tokens_default=700,
85
  extra_params={"max_new_tokens": 700},
 
86
  ))
87
  register_model(ModelDef(
88
  model_id="command-a-translate", display_name="Cohere Command-A Translate",
@@ -92,6 +101,7 @@ def _init_registry():
92
  supports_temperature=False, supports_streaming=False, supports_history=False,
93
  supports_thinking=False, max_tokens_default=700,
94
  extra_params={"max_new_tokens": 700},
 
95
  ))
96
  register_model(ModelDef(
97
  model_id="minimax-vl-01", display_name="MiniMax VL-01",
@@ -101,6 +111,7 @@ def _init_registry():
101
  supports_temperature=True, supports_streaming=False, supports_history=False,
102
  supports_thinking=False, max_tokens_default=12800, default_temperature=0.1,
103
  extra_params={"max_tokens": 12800, "top_p": 0.9},
 
104
  ))
105
  register_model(ModelDef(
106
  model_id="glm-4.5", display_name="GLM-4.5 (ZhipuAI)",
@@ -110,6 +121,7 @@ def _init_registry():
110
  supports_temperature=True, supports_streaming=False, supports_history=False,
111
  supports_thinking=True, thinking_default=True, default_temperature=1.0,
112
  extra_params={"thinking_enabled": True},
 
113
  ))
114
  register_model(ModelDef(
115
  model_id="chatgpt", display_name="ChatGPT (Community)",
@@ -119,6 +131,7 @@ def _init_registry():
119
  supports_temperature=True, supports_streaming=False, supports_history=True,
120
  supports_thinking=False, default_temperature=1.0,
121
  extra_params={"top_p": 1.0},
 
122
  ))
123
  register_model(ModelDef(
124
  model_id="qwen3-vl", display_name="Qwen3-VL (Alibaba)",
@@ -127,8 +140,10 @@ def _init_registry():
127
  api_name="/add_message", supports_vision=True, supports_system_prompt=False,
128
  supports_temperature=False, supports_streaming=False, supports_history=False,
129
  supports_thinking=False, max_tokens_default=4096,
 
130
  ))
131
 
 
132
  _init_registry()
133
 
134
  # ═══════════════════════════════════════════════════════════════
@@ -143,8 +158,8 @@ class Config:
143
  max_retries: int = 3
144
  retry_backoff_base: float = 1.5
145
  retry_jitter: float = 0.5
146
- rate_limit_rpm: int = 10
147
- rate_limit_burst: int = 3
148
  pool_size: int = 2
149
  max_history_messages: int = 50
150
  max_message_length: int = 10000
@@ -158,12 +173,14 @@ class Config:
158
  env_map = {
159
  "MMAI_TIMEOUT": ("timeout_stream", int),
160
  "MMAI_MAX_RETRIES": ("max_retries", int),
161
- "MMAI_RATE_LIMIT": ("rate_limit_rpm", int),
 
162
  "MMAI_POOL_SIZE": ("pool_size", int),
163
  "MMAI_SYSTEM_PROMPT": ("default_system_prompt", str),
164
  "MMAI_TEMPERATURE": ("default_temperature", float),
165
  "MMAI_DEFAULT_MODEL": ("default_model", str),
166
- "MMAI_INCLUDE_THINKING": ("include_thinking", lambda x: x.lower() in ("1", "true")),
 
167
  }
168
  for env_key, (attr, conv) in env_map.items():
169
  val = os.environ.get(env_key)
@@ -183,12 +200,17 @@ class APIError(Exception):
183
  super().__init__(message)
184
  self.code = code
185
  self.status = status
 
186
  def to_dict(self):
187
  return {"error": str(self), "code": self.code}
188
 
 
189
  class ModelNotFoundError(APIError):
190
  def __init__(self, model_id: str):
191
- super().__init__(f"Model '{model_id}' not found. Available: {list(MODEL_REGISTRY.keys())}", "MODEL_NOT_FOUND", 404)
 
 
 
192
 
193
  # ═══════════════════════════════════════════════════════════════
194
  # RESPONSE CLEANER
@@ -232,7 +254,8 @@ class ResponseCleaner:
232
  }
233
  for entity, char in entities.items():
234
  text = text.replace(entity, char)
235
- text = re.sub(r'&#x([0-9a-fA-F]+);', lambda m: chr(int(m.group(1), 16)), text)
 
236
  text = re.sub(r'&#(\d+);', lambda m: chr(int(m.group(1))), text)
237
  return text
238
 
@@ -249,12 +272,25 @@ class ResponseCleaner:
249
  if '<details' not in text and '<div' not in text:
250
  return text.strip()
251
  thinking_text = ""
252
- thinking_match = re.search(r'<details[^>]*>.*?<div[^>]*>(.*?)</div>\s*</details>', text, re.DOTALL | re.IGNORECASE)
 
 
 
253
  if thinking_match:
254
  thinking_text = cls._strip_html(thinking_match.group(1)).strip()
255
- text_without_details = re.sub(r'<details[^>]*>.*?</details>', '', text, flags=re.DOTALL | re.IGNORECASE).strip()
256
- div_match = re.search(r"<div[^>]*>\s*(.*?)\s*</div>", text_without_details, re.DOTALL | re.IGNORECASE)
257
- response_text = cls._strip_html(div_match.group(1)).strip() if div_match else cls._strip_html(text_without_details).strip()
 
 
 
 
 
 
 
 
 
 
258
  if thinking_text and include_thinking:
259
  return f"<thinking>\n{thinking_text}\n</thinking>\n{response_text}"
260
  return response_text
@@ -307,7 +343,8 @@ class ResponseCleaner:
307
  return str(result)
308
 
309
  @classmethod
310
- def clean(cls, text: str, model_id: str = "", include_thinking: bool = True) -> str:
 
311
  if not text:
312
  return text
313
  text = text.strip()
@@ -326,7 +363,10 @@ class ResponseCleaner:
326
  class ThinkingParser:
327
  @staticmethod
328
  def split(text: str) -> Tuple[Optional[str], str]:
329
- match = re.match(r'\s*<thinking>\s*\n?(.*?)\n?\s*</thinking>\s*\n?(.*)', text, re.DOTALL | re.IGNORECASE)
 
 
 
330
  if match:
331
  thinking = match.group(1).strip()
332
  response = match.group(2).strip()
@@ -351,6 +391,7 @@ class Message:
351
  timestamp: float = field(default_factory=time.time)
352
  message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
353
 
 
354
  @dataclass
355
  class Conversation:
356
  conversation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -361,7 +402,9 @@ class Conversation:
361
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
362
  model_id: str = DEFAULT_MODEL
363
 
364
- def add_message(self, role: str, content: str, max_messages: int = 50, thinking: Optional[str] = None) -> Message:
 
 
365
  msg = Message(role=role, content=content, thinking=thinking)
366
  self.messages.append(msg)
367
  self.updated_at = time.time()
@@ -378,7 +421,9 @@ class Conversation:
378
  non_system = [m for m in self.messages if m.role != "system"]
379
  i = 0
380
  while i < len(non_system) - 1:
381
- if non_system[i].role == "user" and i + 1 < len(non_system) and non_system[i + 1].role == "assistant":
 
 
382
  history.append([non_system[i].content, non_system[i + 1].content])
383
  i += 2
384
  else:
@@ -390,13 +435,16 @@ class Conversation:
390
 
391
  def to_dict(self) -> Dict:
392
  return {
393
- "conversation_id": self.conversation_id, "title": self.title,
394
- "model": self.model_id, "message_count": len(self.messages),
395
- "created_at": self.created_at, "updated_at": self.updated_at,
 
 
 
396
  }
397
 
398
  # ═══════════════════════════════════════════════════════════════
399
- # METRICS & RATE LIMITER
400
  # ═══════════════════════════════════════════════════════════════
401
 
402
  @dataclass
@@ -411,8 +459,12 @@ class Metrics:
411
  requests_per_model: Dict[str, int] = field(default_factory=dict)
412
  _latencies: deque = field(default_factory=lambda: deque(maxlen=1000), repr=False)
413
  started_at: float = field(default_factory=time.time)
 
 
 
414
 
415
- def record_request(self, success: bool, duration_ms: float, chars: int = 0, model: str = ""):
 
416
  with self._lock:
417
  self.total_requests += 1
418
  if success:
@@ -422,48 +474,85 @@ class Metrics:
422
  self.failed_requests += 1
423
  self._latencies.append(duration_ms)
424
  if model:
425
- self.requests_per_model[model] = self.requests_per_model.get(model, 0) + 1
 
 
426
 
427
  def record_retry(self):
428
  with self._lock:
429
  self.total_retries += 1
430
 
 
 
 
 
 
 
431
  def to_dict(self) -> Dict:
432
  with self._lock:
433
- avg = sum(self._latencies) / len(self._latencies) if self._latencies else 0
434
- rate = self.successful_requests / self.total_requests if self.total_requests else 1
 
 
435
  return {
436
- "total_requests": self.total_requests, "successful": self.successful_requests,
437
- "failed": self.failed_requests, "success_rate": round(rate, 4),
438
- "retries": self.total_retries, "chars_received": self.total_chars_received,
439
- "avg_latency_ms": round(avg, 1), "active_streams": self.active_streams,
 
 
 
 
440
  "uptime_s": round(time.time() - self.started_at, 1),
441
  "per_model": dict(self.requests_per_model),
 
 
 
 
442
  }
443
 
 
444
  metrics = Metrics()
445
 
 
 
 
 
446
  class RateLimiter:
447
- def __init__(self, rpm: int = 10, burst: int = 3):
448
- self.rate = rpm / 60.0
 
 
449
  self.max_tokens = float(burst)
450
  self.tokens = float(burst)
451
  self.last_refill = time.monotonic()
452
  self._lock = threading.Lock()
453
 
454
- def acquire(self, timeout: float = 30.0) -> bool:
455
  deadline = time.monotonic() + timeout
456
  while True:
457
  with self._lock:
458
  now = time.monotonic()
459
- self.tokens = min(self.max_tokens, self.tokens + (now - self.last_refill) * self.rate)
 
 
 
 
460
  self.last_refill = now
461
  if self.tokens >= 1.0:
462
  self.tokens -= 1.0
463
  return True
464
  if time.monotonic() >= deadline:
465
  return False
466
- time.sleep(0.1)
 
 
 
 
 
 
 
 
467
 
468
  # ═══════════════════════════════════════════════════════════════
469
  # CIRCUIT BREAKER
@@ -514,7 +603,8 @@ class CircuitBreaker:
514
 
515
  class GradioSSEParser:
516
  @staticmethod
517
- def parse_sse(response: requests.Response, log_raw: bool = False) -> Generator[Dict, None, None]:
 
518
  buffer = ""
519
  for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
520
  if chunk is None:
@@ -554,11 +644,19 @@ class GradioSSEParser:
554
  # ═══════════════════════════════════════════════════════════════
555
 
556
  class ModelProvider(ABC):
557
- def __init__(self, model_def: ModelDef, config: Config):
558
  self.model_def = model_def
559
  self.config = config
 
560
  self.ready = False
561
  self._lock = threading.Lock()
 
 
 
 
 
 
 
562
 
563
  @abstractmethod
564
  def initialize(self) -> bool: ...
@@ -570,10 +668,58 @@ class ModelProvider(ABC):
570
  def generate_stream(self, message: str, **kwargs) -> Generator[str, None, None]:
571
  yield self.generate(message, **kwargs)
572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  class GptOssProvider(ModelProvider):
575
- def __init__(self, model_def, config):
576
- super().__init__(model_def, config)
577
  self._session = requests.Session()
578
  self._rotate()
579
 
@@ -594,31 +740,48 @@ class GptOssProvider(ModelProvider):
594
  return True
595
  self._rotate()
596
  try:
597
- r = self._session.get(f"{self.model_def.space_id}/gradio_api/info", timeout=15)
 
 
598
  self.ready = r.status_code == 200
599
  return self.ready
600
- except:
601
  return False
602
 
603
- def generate(self, message, history=None, system_prompt=None, temperature=None, max_tokens=None, **kw):
 
604
  if not self.ready:
605
  self.initialize()
606
  sys_p = system_prompt or self.config.default_system_prompt
607
- temp = temperature if temperature is not None else self.model_def.default_temperature
 
608
  h = self._hash()
609
- payload = {"data": [message, history or [], sys_p, temp], "event_data": None,
610
- "fn_index": self.model_def.fn_index, "trigger_id": None, "session_hash": h}
611
- r = self._session.post(f"{self.model_def.space_id}/gradio_api/queue/join?",
612
- json=payload, headers={"Content-Type": "application/json"}, timeout=30)
 
 
 
 
 
 
 
 
 
613
  if r.status_code != 200:
614
  raise APIError(f"Queue join failed: {r.status_code}")
615
  data = r.json()
616
  if not data.get("event_id"):
617
- raise APIError(f"No event_id")
618
-
619
- resp = self._session.get(f"{self.model_def.space_id}/gradio_api/queue/data",
620
- params={"session_hash": h}, headers={"Accept": "text/event-stream"},
621
- timeout=self.config.timeout_stream, stream=True)
 
 
 
 
622
  full = ""
623
  for d in GradioSSEParser.parse_sse(resp):
624
  msg = d.get("msg", "")
@@ -635,21 +798,37 @@ class GptOssProvider(ModelProvider):
635
  break
636
  if not full.strip():
637
  raise APIError("Empty response", "EMPTY")
638
- return ResponseCleaner.clean_analysis(full) if self.model_def.clean_analysis else full
 
639
 
640
- def generate_stream(self, message, history=None, system_prompt=None, temperature=None, max_tokens=None, **kw):
 
641
  if not self.ready:
642
  self.initialize()
643
  sys_p = system_prompt or self.config.default_system_prompt
644
- temp = temperature if temperature is not None else self.model_def.default_temperature
 
645
  h = self._hash()
646
- payload = {"data": [message, history or [], sys_p, temp], "event_data": None,
647
- "fn_index": self.model_def.fn_index, "trigger_id": None, "session_hash": h}
648
- self._session.post(f"{self.model_def.space_id}/gradio_api/queue/join?",
649
- json=payload, headers={"Content-Type": "application/json"}, timeout=30)
650
- resp = self._session.get(f"{self.model_def.space_id}/gradio_api/queue/data",
651
- params={"session_hash": h}, headers={"Accept": "text/event-stream"},
652
- timeout=self.config.timeout_stream, stream=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  metrics.active_streams += 1
654
  last = ""
655
  try:
@@ -658,7 +837,7 @@ class GptOssProvider(ModelProvider):
658
  if msg in ("process_generating", "process_completed"):
659
  output = d.get("output", {})
660
  if not output.get("success", True):
661
- raise APIError(f"Gradio error")
662
  raw = GradioSSEParser.extract_text(output)
663
  if raw:
664
  if self.model_def.clean_analysis:
@@ -680,27 +859,35 @@ class GptOssProvider(ModelProvider):
680
 
681
  class GradioClientProvider(ModelProvider):
682
  """Generic provider for all gradio_client based models."""
683
- def __init__(self, model_def, config):
684
- super().__init__(model_def, config)
 
685
  self._client = None
686
  self._chat_counter = 0
687
 
688
  def initialize(self) -> bool:
689
  if not HAS_GRADIO_CLIENT:
690
- raise APIError(f"gradio_client not installed", "MISSING_DEP")
691
  with self._lock:
692
  if self.ready:
693
  return True
694
  try:
695
- log.info(f"Connecting to {self.model_def.space_id}...")
 
 
 
696
  self._client = GradioClient(self.model_def.space_id)
697
  self.ready = True
698
  return True
699
  except Exception as e:
700
- log.error(f"Init failed for {self.model_def.model_id}: {e}")
 
 
 
701
  return False
702
 
703
- def generate(self, message, history=None, system_prompt=None, temperature=None, max_tokens=None, **kw):
 
704
  if not self.ready:
705
  self.initialize()
706
  if not self._client:
@@ -709,50 +896,73 @@ class GradioClientProvider(ModelProvider):
709
  mid = self.model_def.model_id
710
  try:
711
  if mid == "command-a-vision":
712
- max_new = max_tokens or self.model_def.extra_params.get("max_new_tokens", 700)
713
- result = self._client.predict(message={"text": message, "files": []},
714
- max_new_tokens=max_new, api_name=self.model_def.api_name)
 
 
 
 
715
  elif mid == "command-a-translate":
716
- max_new = max_tokens or self.model_def.extra_params.get("max_new_tokens", 700)
717
- result = self._client.predict(message=message, max_new_tokens=max_new,
718
- api_name=self.model_def.api_name)
 
 
 
 
719
  elif mid == "minimax-vl-01":
720
- temp = temperature if temperature is not None else self.model_def.default_temperature
721
- max_tok = max_tokens or self.model_def.extra_params.get("max_tokens", 12800)
722
- top_p = kw.get("top_p", self.model_def.extra_params.get("top_p", 0.9))
723
- result = self._client.predict(message={"text": message, "files": []},
724
- max_tokens=max_tok, temperature=temp, top_p=top_p,
725
- api_name=self.model_def.api_name)
 
 
 
 
 
726
  elif mid == "glm-4.5":
727
  sys_p = system_prompt or self.config.default_system_prompt
728
- temp = temperature if temperature is not None else self.model_def.default_temperature
729
- thinking = kw.get("thinking_enabled", self.model_def.thinking_default)
730
- include = kw.get("include_thinking", self.config.include_thinking)
731
- result = self._client.predict(msg=message, sys_prompt=sys_p,
732
- thinking_enabled=thinking, temperature=temp,
733
- api_name=self.model_def.api_name)
 
 
 
 
 
734
  return self._extract_glm(result, include)
735
  elif mid == "chatgpt":
736
- temp = temperature if temperature is not None else self.model_def.default_temperature
737
- top_p = kw.get("top_p", self.model_def.extra_params.get("top_p", 1.0))
 
 
738
  chat_hist = []
739
  if history:
740
  for pair in history:
741
  if isinstance(pair, (list, tuple)) and len(pair) == 2:
742
  chat_hist.append([str(pair[0]), str(pair[1])])
743
- result = self._client.predict(inputs=message, top_p=top_p, temperature=temp,
744
- chat_counter=self._chat_counter, chatbot=chat_hist,
745
- api_name=self.model_def.api_name)
 
 
746
  self._chat_counter += 1
747
  return ResponseCleaner.extract_chatgpt_text(result)
748
  elif mid == "qwen3-vl":
749
- result = self._client.predict(input_value={"files": None, "text": message},
750
- api_name="/add_message")
 
 
751
  return ResponseCleaner.extract_qwen_text(result)
752
  else:
753
  raise APIError(f"Unknown model handler: {mid}")
754
 
755
- # Default extraction for simple results
756
  if isinstance(result, str):
757
  return result.strip()
758
  if isinstance(result, dict):
@@ -786,28 +996,268 @@ class GradioClientProvider(ModelProvider):
786
  return ResponseCleaner.clean_glm(str(result), include_thinking)
787
 
788
 
789
- # Factory
790
- def create_provider(model_id: str, config: Config) -> ModelProvider:
 
791
  if model_id not in MODEL_REGISTRY:
792
  raise ModelNotFoundError(model_id)
793
  mdef = MODEL_REGISTRY[model_id]
794
  if model_id == "gpt-oss-120b":
795
- return GptOssProvider(mdef, config)
796
- return GradioClientProvider(mdef, config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
 
798
  # ═══════════════════════════════════════════════════════════════
799
- # MULTI-MODEL CLIENT
800
  # ═══════════════════════════════════════════════════════════════
801
 
802
  class MultiModelClient:
803
  def __init__(self, config: Config):
804
  self.config = config
805
- self._providers: Dict[str, ModelProvider] = {}
806
  self._lock = threading.Lock()
807
  self._conversations: Dict[str, Conversation] = {}
808
  self._active_conv_id: Optional[str] = None
809
  self._current_model = config.default_model
810
- self.rate_limiter = RateLimiter(config.rate_limit_rpm, config.rate_limit_burst)
811
  self.circuit_breaker = CircuitBreaker()
812
 
813
  @property
@@ -820,45 +1270,74 @@ class MultiModelClient:
820
  raise ModelNotFoundError(m)
821
  self._current_model = m
822
 
823
- def _get_provider(self, model_id: str) -> ModelProvider:
824
- if model_id not in self._providers:
825
  with self._lock:
826
- if model_id not in self._providers:
827
- self._providers[model_id] = create_provider(model_id, self.config)
828
- return self._providers[model_id]
829
-
830
- def _ensure_ready(self, model_id: str) -> ModelProvider:
831
- p = self._get_provider(model_id)
832
- if not p.ready:
833
- if not p.initialize():
834
- raise APIError(f"Cannot init {model_id}", "INIT_FAILED")
835
- return p
 
 
 
 
 
836
 
837
  @property
838
  def active_conversation(self) -> Conversation:
839
  if self._active_conv_id not in self._conversations:
840
- conv = Conversation(system_prompt=self.config.default_system_prompt, model_id=self._current_model)
 
 
 
841
  self._conversations[conv.conversation_id] = conv
842
  self._active_conv_id = conv.conversation_id
843
  return self._conversations[self._active_conv_id]
844
 
845
- def new_conversation(self, system_prompt=None, model_id=None) -> Conversation:
846
- conv = Conversation(system_prompt=system_prompt or self.config.default_system_prompt,
847
- model_id=model_id or self._current_model)
 
 
 
848
  self._conversations[conv.conversation_id] = conv
849
  self._active_conv_id = conv.conversation_id
850
  return conv
851
 
852
  def init_model(self, model_id: str) -> bool:
853
  try:
854
- return self._get_provider(model_id).initialize()
855
- except:
 
856
  return False
857
 
858
- def send_message(self, message: str, *, stream: bool = False, model: Optional[str] = None,
859
- conversation_id: Optional[str] = None, system_prompt: Optional[str] = None,
860
- temperature: Optional[float] = None, max_tokens: Optional[int] = None,
861
- include_thinking: Optional[bool] = None, **kwargs) -> Union[str, Generator]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  model_id = model or self._current_model
863
  if model_id not in MODEL_REGISTRY:
864
  raise ModelNotFoundError(model_id)
@@ -871,9 +1350,10 @@ class MultiModelClient:
871
  if not self.circuit_breaker.can_execute():
872
  raise APIError("Circuit breaker open", "CIRCUIT_OPEN", 503)
873
  if not self.rate_limiter.acquire(timeout=10.0):
874
- raise APIError("Rate limited", "RATE_LIMITED", 429)
875
 
876
- conv = self._conversations.get(conversation_id, self.active_conversation) if conversation_id else self.active_conversation
 
877
  conv.model_id = model_id
878
  if system_prompt:
879
  conv.system_prompt = system_prompt
@@ -881,9 +1361,11 @@ class MultiModelClient:
881
  history = conv.build_gradio_history() if mdef.supports_history else None
882
  conv.add_message("user", message, self.config.max_history_messages)
883
 
884
- eff_temp = temperature if temperature is not None else mdef.default_temperature
 
885
  eff_sys = conv.system_prompt if mdef.supports_system_prompt else None
886
- eff_thinking = include_thinking if include_thinking is not None else self.config.include_thinking
 
887
 
888
  extra = dict(kwargs)
889
  if mdef.supports_thinking:
@@ -894,21 +1376,39 @@ class MultiModelClient:
894
  for attempt in range(self.config.max_retries + 1):
895
  try:
896
  if attempt > 0:
897
- time.sleep(self.config.retry_backoff_base ** attempt + random.uniform(0, self.config.retry_jitter))
 
 
 
898
  metrics.record_retry()
899
 
900
- provider = self._ensure_ready(model_id)
901
 
902
  if stream and mdef.supports_streaming:
903
- gen = provider.generate_stream(message, history=history, system_prompt=eff_sys,
904
- temperature=eff_temp, max_tokens=max_tokens, **extra)
 
 
 
 
 
 
905
  return self._wrap_stream(gen, conv, start, model_id)
906
 
907
- result = provider.generate(message, history=history, system_prompt=eff_sys,
908
- temperature=eff_temp, max_tokens=max_tokens, **extra)
 
 
 
 
 
 
 
909
  dur = (time.monotonic() - start) * 1000
910
  thinking, response = ThinkingParser.split(result)
911
- conv.add_message("assistant", response, self.config.max_history_messages, thinking=thinking)
 
 
912
  metrics.record_request(True, dur, len(result), model_id)
913
  self.circuit_breaker.record_success()
914
  return result
@@ -933,31 +1433,46 @@ class MultiModelClient:
933
  full += chunk
934
  yield chunk
935
  thinking, response = ThinkingParser.split(full)
936
- conv.add_message("assistant", response, self.config.max_history_messages, thinking=thinking)
937
- metrics.record_request(True, (time.monotonic() - start) * 1000, len(full), model_id)
 
 
 
 
 
938
  self.circuit_breaker.record_success()
939
  except Exception:
940
- metrics.record_request(False, (time.monotonic() - start) * 1000, model=model_id)
 
 
941
  self.circuit_breaker.record_failure()
942
  raise
943
 
944
  def get_status(self) -> Dict:
 
 
 
 
945
  return {
946
- "version": VERSION, "current_model": self._current_model,
 
947
  "models": list(MODEL_REGISTRY.keys()),
948
- "providers": {m: "READY" if p.ready else "NOT_READY" for m, p in self._providers.items()},
949
  "conversations": len(self._conversations),
950
  "circuit_breaker": self.circuit_breaker.state,
 
951
  }
952
 
953
  # ═══════════════════════════════════════════════════════════════
954
- # SESSION POOL
955
  # ═══════════════════════════════════════════════════════════════
956
 
957
  class SessionPool:
958
  def __init__(self, config: Config):
959
  self.config = config
960
- self._clients = [MultiModelClient(config) for _ in range(config.pool_size)]
 
 
961
  self._idx = 0
962
  self._lock = threading.Lock()
963
 
@@ -966,7 +1481,10 @@ class SessionPool:
966
  c.init_model(self.config.default_model)
967
 
968
  def init_model(self, model_id: str) -> int:
969
- return sum(1 for c in self._clients if c.init_model(model_id))
 
 
 
970
 
971
  def acquire(self) -> MultiModelClient:
972
  with self._lock:
@@ -980,14 +1498,17 @@ class SessionPool:
980
 
981
  ALIASES = {
982
  "gpt-oss": "gpt-oss-120b", "gptoss": "gpt-oss-120b", "amd": "gpt-oss-120b",
983
- "command-a": "command-a-vision", "command-vision": "command-a-vision", "cohere-vision": "command-a-vision",
984
- "command-translate": "command-a-translate", "cohere-translate": "command-a-translate", "translate": "command-a-translate",
 
 
985
  "minimax": "minimax-vl-01", "minimax-vl": "minimax-vl-01",
986
  "glm": "glm-4.5", "glm4": "glm-4.5", "glm-4": "glm-4.5", "zhipu": "glm-4.5",
987
  "gpt": "chatgpt", "gpt-3.5": "chatgpt", "gpt3": "chatgpt", "openai": "chatgpt",
988
  "qwen": "qwen3-vl", "qwen3": "qwen3-vl", "qwen-vl": "qwen3-vl",
989
  }
990
 
 
991
  def resolve_alias(model_id: str) -> str:
992
  return ALIASES.get(model_id.lower(), model_id)
993
 
@@ -1001,6 +1522,7 @@ pool.init_default()
1001
 
1002
  app = Flask(APP_NAME)
1003
 
 
1004
  @app.after_request
1005
  def cors(response):
1006
  response.headers["Access-Control-Allow-Origin"] = "*"
@@ -1008,15 +1530,19 @@ def cors(response):
1008
  response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
1009
  return response
1010
 
 
1011
  @app.errorhandler(APIError)
1012
  def handle_api_error(e: APIError):
1013
  return jsonify({"ok": False, **e.to_dict()}), e.status
1014
 
 
1015
  @app.route("/")
1016
  def index():
1017
  return jsonify({
1018
- "name": APP_NAME, "version": VERSION,
 
1019
  "default_model": config.default_model,
 
1020
  "models": list(MODEL_REGISTRY.keys()),
1021
  "endpoints": {
1022
  "POST /chat": "Chat with any model",
@@ -1024,11 +1550,13 @@ def index():
1024
  "POST /v1/chat/completions": "OpenAI-compatible",
1025
  "GET /v1/models": "List models",
1026
  "POST /models/init": "Init a model",
1027
- "GET /health": "Health check",
1028
  "GET /metrics": "Metrics",
 
1029
  },
1030
  })
1031
 
 
1032
  @app.route("/chat", methods=["POST"])
1033
  def chat():
1034
  data = freq.get_json(force=True, silent=True) or {}
@@ -1040,17 +1568,26 @@ def chat():
1040
  client = pool.acquire()
1041
  if data.get("new_conversation"):
1042
  client.new_conversation(data.get("system_prompt"), model_id)
1043
- result = client.send_message(message, model=model_id, system_prompt=data.get("system_prompt"),
1044
- temperature=data.get("temperature"), max_tokens=data.get("max_tokens"),
1045
- include_thinking=include_thinking)
 
 
 
 
1046
  thinking, clean = ThinkingParser.split(result)
1047
- resp = {"ok": True, "response": clean, "model": model_id,
1048
- "conversation_id": client.active_conversation.conversation_id,
1049
- "history_size": len(client.active_conversation.messages)}
 
 
 
 
1050
  if thinking:
1051
  resp["thinking"] = thinking
1052
  return jsonify(resp)
1053
 
 
1054
  @app.route("/chat/stream", methods=["POST"])
1055
  def chat_stream():
1056
  data = freq.get_json(force=True, silent=True) or {}
@@ -1068,40 +1605,57 @@ def chat_stream():
1068
  def generate():
1069
  try:
1070
  if use_stream:
1071
- for chunk in client.send_message(message, stream=True, model=model_id,
1072
- system_prompt=data.get("system_prompt"),
1073
- temperature=data.get("temperature"),
1074
- max_tokens=data.get("max_tokens"),
1075
- include_thinking=include_thinking):
 
 
1076
  yield f"data: {json.dumps({'chunk': chunk})}\n\n"
1077
  else:
1078
- result = client.send_message(message, model=model_id,
1079
- system_prompt=data.get("system_prompt"),
1080
- temperature=data.get("temperature"),
1081
- max_tokens=data.get("max_tokens"),
1082
- include_thinking=include_thinking)
 
 
1083
  yield f"data: {json.dumps({'chunk': result})}\n\n"
1084
  yield "data: [DONE]\n\n"
1085
  except APIError as e:
1086
  yield f"data: {json.dumps(e.to_dict())}\n\n"
1087
 
1088
- return Response(stream_with_context(generate()), content_type="text/event-stream")
 
 
1089
 
1090
  @app.route("/v1/models", methods=["GET"])
1091
  def list_models():
1092
  models = []
1093
  for mid, mdef in MODEL_REGISTRY.items():
1094
  models.append({
1095
- "id": mid, "object": "model", "owned_by": mdef.owned_by, "created": 0,
 
 
 
1096
  "description": mdef.description,
1097
  "capabilities": {
1098
- "vision": mdef.supports_vision, "streaming": mdef.supports_streaming,
1099
- "system_prompt": mdef.supports_system_prompt, "temperature": mdef.supports_temperature,
1100
- "history": mdef.supports_history, "thinking": mdef.supports_thinking,
 
 
 
 
 
 
 
1101
  },
1102
  })
1103
  return jsonify({"object": "list", "data": models})
1104
 
 
1105
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
1106
  def openai_compat():
1107
  if freq.method == "OPTIONS":
@@ -1115,7 +1669,12 @@ def openai_compat():
1115
  include_thinking = data.get("include_thinking", config.include_thinking)
1116
 
1117
  if model_id not in MODEL_REGISTRY:
1118
- return jsonify({"error": {"message": f"Model '{model_id}' not found", "type": "invalid_request_error"}}), 404
 
 
 
 
 
1119
  if not messages:
1120
  return jsonify({"error": {"message": "messages required"}}), 400
1121
 
@@ -1145,68 +1704,160 @@ def openai_compat():
1145
  if do_stream:
1146
  def generate():
1147
  try:
1148
- yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
 
 
 
 
 
1149
  if mdef.supports_streaming:
1150
- for chunk in client.send_message(user_msg, stream=True, model=model_id,
1151
- temperature=temperature, max_tokens=max_tokens,
1152
- include_thinking=include_thinking):
1153
- yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'content': chunk}, 'finish_reason': None}]})}\n\n"
 
 
 
 
 
 
 
 
 
1154
  else:
1155
- result = client.send_message(user_msg, model=model_id, temperature=temperature,
1156
- max_tokens=max_tokens, include_thinking=include_thinking)
1157
- yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {'content': result}, 'finish_reason': None}]})}\n\n"
1158
- yield f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', 'created': created, 'model': model_id, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1159
  yield "data: [DONE]\n\n"
1160
  except Exception as e:
1161
  yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
1162
- return Response(stream_with_context(generate()), content_type="text/event-stream")
1163
 
1164
- result = client.send_message(user_msg, model=model_id, temperature=temperature,
1165
- max_tokens=max_tokens, include_thinking=include_thinking)
 
 
 
 
 
1166
  return jsonify({
1167
- "id": rid, "object": "chat.completion", "created": created, "model": model_id,
1168
- "choices": [{"index": 0, "message": {"role": "assistant", "content": result}, "finish_reason": "stop"}],
1169
- "usage": {"prompt_tokens": len(user_msg) // 4, "completion_tokens": len(result) // 4,
1170
- "total_tokens": (len(user_msg) + len(result)) // 4},
 
 
 
 
 
 
 
 
 
 
1171
  })
1172
 
 
1173
  @app.route("/new", methods=["POST"])
1174
  def new_conv():
1175
  data = freq.get_json(force=True, silent=True) or {}
1176
  model_id = resolve_alias(data.get("model", config.default_model))
1177
  client = pool.acquire()
1178
  conv = client.new_conversation(data.get("system_prompt"), model_id)
1179
- return jsonify({"ok": True, "conversation_id": conv.conversation_id, "model": model_id})
 
 
 
 
 
1180
 
1181
  @app.route("/health", methods=["GET"])
1182
  def health():
1183
  client = pool.acquire()
1184
  return jsonify(client.get_status())
1185
 
 
1186
  @app.route("/metrics", methods=["GET"])
1187
  def metrics_endpoint():
1188
  return jsonify(metrics.to_dict())
1189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1190
  @app.route("/conversations", methods=["GET"])
1191
  def conversations():
1192
  client = pool.acquire()
1193
- return jsonify({"conversations": [c.to_dict() for c in client._conversations.values()]})
 
 
 
1194
 
1195
  @app.route("/models/init", methods=["POST"])
1196
  def init_model_ep():
1197
  data = freq.get_json(force=True, silent=True) or {}
1198
  model_id = resolve_alias(data.get("model", ""))
1199
  if not model_id or model_id not in MODEL_REGISTRY:
1200
- return jsonify({"ok": False, "error": f"Unknown model. Available: {list(MODEL_REGISTRY.keys())}"}), 400
 
 
 
1201
  count = pool.init_model(model_id)
1202
- return jsonify({"ok": True, "model": model_id, "initialized_clients": count})
 
 
 
 
 
 
 
 
1203
 
1204
  # ═══════════════════════════════════════════════════════════════
1205
- # ENTRY POINT (for HuggingFace Spaces)
1206
  # ═══════════════════════════════════════════════════════════════
1207
 
1208
  if __name__ == "__main__":
1209
  port = int(os.environ.get("PORT", 7860))
1210
  log.info(f"Starting Multi-Model AI API v{VERSION} on port {port}")
1211
  log.info(f"Models: {list(MODEL_REGISTRY.keys())}")
 
 
 
 
 
 
 
 
1212
  app.run(host="0.0.0.0", port=port, threaded=True)
 
1
  #!/usr/bin/env python3
2
  """
3
  Multi-Model AI API β€” HuggingFace Spaces Edition
4
+ With load balancing (multiple provider instances per model) and 10 req/s rate limiting.
5
  """
6
 
7
  import re, os, json, uuid, time, random, string, logging, threading
 
23
  # CONFIG & CONSTANTS
24
  # ═══════════════════════════════════════════════════════════════
25
 
26
+ VERSION = "2.3.0-hf-lb"
27
  APP_NAME = "Multi-Model-AI-API"
28
  DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly AI assistant."
29
  DEFAULT_MODEL = "gpt-oss-120b"
 
34
  USER_AGENTS = [
35
  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/144.0.0.0 Safari/537.36",
36
  "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_5) AppleWebKit/605.1.15 Safari/605.1.15",
37
+ "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 Chrome/143.0.0.0 Safari/537.36",
38
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0",
39
  ]
40
 
41
  # ═══════════════════════════════════════════════════════════════
 
63
  api_name: Optional[str] = None
64
  extra_params: Dict[str, Any] = field(default_factory=dict)
65
  clean_analysis: bool = False
66
+ # Load balancing config per model
67
+ lb_pool_size: int = 2 # number of provider instances for load balancing
68
+ lb_enabled: bool = True # whether load balancing is enabled
69
 
70
  MODEL_REGISTRY: Dict[str, ModelDef] = {}
71
 
72
+
73
  def register_model(m: ModelDef):
74
  MODEL_REGISTRY[m.model_id] = m
75
 
76
+
77
  def _init_registry():
78
  register_model(ModelDef(
79
  model_id="gpt-oss-120b", display_name="AMD GPT-OSS-120B",
 
81
  owned_by="amd", description="AMD open-source 120B model",
82
  fn_index=8, clean_analysis=True, default_temperature=0.0,
83
  supports_vision=False, supports_thinking=False,
84
+ lb_pool_size=3, lb_enabled=True,
85
  ))
86
  register_model(ModelDef(
87
  model_id="command-a-vision", display_name="Cohere Command-A Vision",
 
91
  supports_temperature=False, supports_streaming=False, supports_history=False,
92
  supports_thinking=False, max_tokens_default=700,
93
  extra_params={"max_new_tokens": 700},
94
+ lb_pool_size=2, lb_enabled=True,
95
  ))
96
  register_model(ModelDef(
97
  model_id="command-a-translate", display_name="Cohere Command-A Translate",
 
101
  supports_temperature=False, supports_streaming=False, supports_history=False,
102
  supports_thinking=False, max_tokens_default=700,
103
  extra_params={"max_new_tokens": 700},
104
+ lb_pool_size=1, lb_enabled=False, # NO load balancing for translate
105
  ))
106
  register_model(ModelDef(
107
  model_id="minimax-vl-01", display_name="MiniMax VL-01",
 
111
  supports_temperature=True, supports_streaming=False, supports_history=False,
112
  supports_thinking=False, max_tokens_default=12800, default_temperature=0.1,
113
  extra_params={"max_tokens": 12800, "top_p": 0.9},
114
+ lb_pool_size=2, lb_enabled=True,
115
  ))
116
  register_model(ModelDef(
117
  model_id="glm-4.5", display_name="GLM-4.5 (ZhipuAI)",
 
121
  supports_temperature=True, supports_streaming=False, supports_history=False,
122
  supports_thinking=True, thinking_default=True, default_temperature=1.0,
123
  extra_params={"thinking_enabled": True},
124
+ lb_pool_size=2, lb_enabled=True,
125
  ))
126
  register_model(ModelDef(
127
  model_id="chatgpt", display_name="ChatGPT (Community)",
 
131
  supports_temperature=True, supports_streaming=False, supports_history=True,
132
  supports_thinking=False, default_temperature=1.0,
133
  extra_params={"top_p": 1.0},
134
+ lb_pool_size=2, lb_enabled=True,
135
  ))
136
  register_model(ModelDef(
137
  model_id="qwen3-vl", display_name="Qwen3-VL (Alibaba)",
 
140
  api_name="/add_message", supports_vision=True, supports_system_prompt=False,
141
  supports_temperature=False, supports_streaming=False, supports_history=False,
142
  supports_thinking=False, max_tokens_default=4096,
143
+ lb_pool_size=2, lb_enabled=True,
144
  ))
145
 
146
+
147
  _init_registry()
148
 
149
  # ═══════════════════════════════════════════════════════════════
 
158
  max_retries: int = 3
159
  retry_backoff_base: float = 1.5
160
  retry_jitter: float = 0.5
161
+ rate_limit_rps: int = 10 # requests per SECOND (changed from RPM)
162
+ rate_limit_burst: int = 15 # burst capacity
163
  pool_size: int = 2
164
  max_history_messages: int = 50
165
  max_message_length: int = 10000
 
173
  env_map = {
174
  "MMAI_TIMEOUT": ("timeout_stream", int),
175
  "MMAI_MAX_RETRIES": ("max_retries", int),
176
+ "MMAI_RATE_LIMIT_RPS": ("rate_limit_rps", int),
177
+ "MMAI_RATE_LIMIT_BURST": ("rate_limit_burst", int),
178
  "MMAI_POOL_SIZE": ("pool_size", int),
179
  "MMAI_SYSTEM_PROMPT": ("default_system_prompt", str),
180
  "MMAI_TEMPERATURE": ("default_temperature", float),
181
  "MMAI_DEFAULT_MODEL": ("default_model", str),
182
+ "MMAI_INCLUDE_THINKING": ("include_thinking",
183
+ lambda x: x.lower() in ("1", "true")),
184
  }
185
  for env_key, (attr, conv) in env_map.items():
186
  val = os.environ.get(env_key)
 
200
  super().__init__(message)
201
  self.code = code
202
  self.status = status
203
+
204
  def to_dict(self):
205
  return {"error": str(self), "code": self.code}
206
 
207
+
208
  class ModelNotFoundError(APIError):
209
  def __init__(self, model_id: str):
210
+ super().__init__(
211
+ f"Model '{model_id}' not found. Available: {list(MODEL_REGISTRY.keys())}",
212
+ "MODEL_NOT_FOUND", 404,
213
+ )
214
 
215
  # ═══════════════════════════════════════════════════════════════
216
  # RESPONSE CLEANER
 
254
  }
255
  for entity, char in entities.items():
256
  text = text.replace(entity, char)
257
+ text = re.sub(r'&#x([0-9a-fA-F]+);',
258
+ lambda m: chr(int(m.group(1), 16)), text)
259
  text = re.sub(r'&#(\d+);', lambda m: chr(int(m.group(1))), text)
260
  return text
261
 
 
272
  if '<details' not in text and '<div' not in text:
273
  return text.strip()
274
  thinking_text = ""
275
+ thinking_match = re.search(
276
+ r'<details[^>]*>.*?<div[^>]*>(.*?)</div>\s*</details>',
277
+ text, re.DOTALL | re.IGNORECASE,
278
+ )
279
  if thinking_match:
280
  thinking_text = cls._strip_html(thinking_match.group(1)).strip()
281
+ text_without_details = re.sub(
282
+ r'<details[^>]*>.*?</details>', '', text,
283
+ flags=re.DOTALL | re.IGNORECASE,
284
+ ).strip()
285
+ div_match = re.search(
286
+ r"<div[^>]*>\s*(.*?)\s*</div>",
287
+ text_without_details, re.DOTALL | re.IGNORECASE,
288
+ )
289
+ response_text = (
290
+ cls._strip_html(div_match.group(1)).strip()
291
+ if div_match
292
+ else cls._strip_html(text_without_details).strip()
293
+ )
294
  if thinking_text and include_thinking:
295
  return f"<thinking>\n{thinking_text}\n</thinking>\n{response_text}"
296
  return response_text
 
343
  return str(result)
344
 
345
  @classmethod
346
+ def clean(cls, text: str, model_id: str = "",
347
+ include_thinking: bool = True) -> str:
348
  if not text:
349
  return text
350
  text = text.strip()
 
363
  class ThinkingParser:
364
  @staticmethod
365
  def split(text: str) -> Tuple[Optional[str], str]:
366
+ match = re.match(
367
+ r'\s*<thinking>\s*\n?(.*?)\n?\s*</thinking>\s*\n?(.*)',
368
+ text, re.DOTALL | re.IGNORECASE,
369
+ )
370
  if match:
371
  thinking = match.group(1).strip()
372
  response = match.group(2).strip()
 
391
  timestamp: float = field(default_factory=time.time)
392
  message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
393
 
394
+
395
  @dataclass
396
  class Conversation:
397
  conversation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
 
402
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
403
  model_id: str = DEFAULT_MODEL
404
 
405
+ def add_message(self, role: str, content: str,
406
+ max_messages: int = 50,
407
+ thinking: Optional[str] = None) -> Message:
408
  msg = Message(role=role, content=content, thinking=thinking)
409
  self.messages.append(msg)
410
  self.updated_at = time.time()
 
421
  non_system = [m for m in self.messages if m.role != "system"]
422
  i = 0
423
  while i < len(non_system) - 1:
424
+ if (non_system[i].role == "user"
425
+ and i + 1 < len(non_system)
426
+ and non_system[i + 1].role == "assistant"):
427
  history.append([non_system[i].content, non_system[i + 1].content])
428
  i += 2
429
  else:
 
435
 
436
  def to_dict(self) -> Dict:
437
  return {
438
+ "conversation_id": self.conversation_id,
439
+ "title": self.title,
440
+ "model": self.model_id,
441
+ "message_count": len(self.messages),
442
+ "created_at": self.created_at,
443
+ "updated_at": self.updated_at,
444
  }
445
 
446
  # ═══════════════════════════════════════════════════════════════
447
+ # METRICS
448
  # ═══════════════════════════════════════════════════════════════
449
 
450
  @dataclass
 
459
  requests_per_model: Dict[str, int] = field(default_factory=dict)
460
  _latencies: deque = field(default_factory=lambda: deque(maxlen=1000), repr=False)
461
  started_at: float = field(default_factory=time.time)
462
+ # Load balancer metrics
463
+ lb_total_dispatches: int = 0
464
+ lb_failovers: int = 0
465
 
466
+ def record_request(self, success: bool, duration_ms: float,
467
+ chars: int = 0, model: str = ""):
468
  with self._lock:
469
  self.total_requests += 1
470
  if success:
 
474
  self.failed_requests += 1
475
  self._latencies.append(duration_ms)
476
  if model:
477
+ self.requests_per_model[model] = (
478
+ self.requests_per_model.get(model, 0) + 1
479
+ )
480
 
481
  def record_retry(self):
482
  with self._lock:
483
  self.total_retries += 1
484
 
485
+ def record_lb_dispatch(self, failover: bool = False):
486
+ with self._lock:
487
+ self.lb_total_dispatches += 1
488
+ if failover:
489
+ self.lb_failovers += 1
490
+
491
  def to_dict(self) -> Dict:
492
  with self._lock:
493
+ avg = (sum(self._latencies) / len(self._latencies)
494
+ if self._latencies else 0)
495
+ rate = (self.successful_requests / self.total_requests
496
+ if self.total_requests else 1)
497
  return {
498
+ "total_requests": self.total_requests,
499
+ "successful": self.successful_requests,
500
+ "failed": self.failed_requests,
501
+ "success_rate": round(rate, 4),
502
+ "retries": self.total_retries,
503
+ "chars_received": self.total_chars_received,
504
+ "avg_latency_ms": round(avg, 1),
505
+ "active_streams": self.active_streams,
506
  "uptime_s": round(time.time() - self.started_at, 1),
507
  "per_model": dict(self.requests_per_model),
508
+ "load_balancer": {
509
+ "total_dispatches": self.lb_total_dispatches,
510
+ "failovers": self.lb_failovers,
511
+ },
512
  }
513
 
514
+
515
  metrics = Metrics()
516
 
517
+ # ═══════════════════════════════════════════════════════════════
518
+ # RATE LIMITER β€” 10 requests per SECOND (token bucket)
519
+ # ═══════════════════════════════════════════════════════════════
520
+
521
  class RateLimiter:
522
+ """Token-bucket rate limiter. Default: 10 requests/second with burst."""
523
+
524
+ def __init__(self, rps: int = 10, burst: int = 15):
525
+ self.rate = float(rps) # tokens per second
526
  self.max_tokens = float(burst)
527
  self.tokens = float(burst)
528
  self.last_refill = time.monotonic()
529
  self._lock = threading.Lock()
530
 
531
+ def acquire(self, timeout: float = 10.0) -> bool:
532
  deadline = time.monotonic() + timeout
533
  while True:
534
  with self._lock:
535
  now = time.monotonic()
536
+ elapsed = now - self.last_refill
537
+ self.tokens = min(
538
+ self.max_tokens,
539
+ self.tokens + elapsed * self.rate,
540
+ )
541
  self.last_refill = now
542
  if self.tokens >= 1.0:
543
  self.tokens -= 1.0
544
  return True
545
  if time.monotonic() >= deadline:
546
  return False
547
+ time.sleep(0.05) # short sleep for per-second limiting
548
+
549
+ def get_info(self) -> Dict:
550
+ with self._lock:
551
+ return {
552
+ "rate_rps": self.rate,
553
+ "burst": self.max_tokens,
554
+ "available_tokens": round(self.tokens, 2),
555
+ }
556
 
557
  # ═══════════════════════════════════════════════════════════════
558
  # CIRCUIT BREAKER
 
603
 
604
  class GradioSSEParser:
605
  @staticmethod
606
+ def parse_sse(response: requests.Response,
607
+ log_raw: bool = False) -> Generator[Dict, None, None]:
608
  buffer = ""
609
  for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
610
  if chunk is None:
 
644
  # ═══════════════════════════════════════════════════════════════
645
 
646
  class ModelProvider(ABC):
647
+ def __init__(self, model_def: ModelDef, config: Config, instance_id: int = 0):
648
  self.model_def = model_def
649
  self.config = config
650
+ self.instance_id = instance_id
651
  self.ready = False
652
  self._lock = threading.Lock()
653
+ # Per-instance health tracking
654
+ self._consecutive_failures = 0
655
+ self._last_success_time = 0.0
656
+ self._last_failure_time = 0.0
657
+ self._total_requests = 0
658
+ self._total_failures = 0
659
+ self._latencies: deque = deque(maxlen=50)
660
 
661
  @abstractmethod
662
  def initialize(self) -> bool: ...
 
668
  def generate_stream(self, message: str, **kwargs) -> Generator[str, None, None]:
669
  yield self.generate(message, **kwargs)
670
 
671
+ def record_success(self, latency_ms: float):
672
+ self._consecutive_failures = 0
673
+ self._last_success_time = time.time()
674
+ self._total_requests += 1
675
+ self._latencies.append(latency_ms)
676
+
677
+ def record_failure(self):
678
+ self._consecutive_failures += 1
679
+ self._last_failure_time = time.time()
680
+ self._total_requests += 1
681
+ self._total_failures += 1
682
+
683
+ @property
684
+ def avg_latency(self) -> float:
685
+ return sum(self._latencies) / len(self._latencies) if self._latencies else 0.0
686
+
687
+ @property
688
+ def health_score(self) -> float:
689
+ """0.0 (worst) to 1.0 (best). Used by load balancer to pick instance."""
690
+ if not self.ready:
691
+ return 0.0
692
+ score = 1.0
693
+ # Penalise consecutive failures
694
+ score -= min(self._consecutive_failures * 0.2, 0.8)
695
+ # Penalise high avg latency (>10s = bad)
696
+ if self._latencies:
697
+ avg = self.avg_latency
698
+ if avg > 10000:
699
+ score -= 0.3
700
+ elif avg > 5000:
701
+ score -= 0.15
702
+ # Penalise high failure rate
703
+ if self._total_requests > 5:
704
+ fail_rate = self._total_failures / self._total_requests
705
+ score -= fail_rate * 0.4
706
+ return max(0.0, min(1.0, score))
707
+
708
+ def get_instance_info(self) -> Dict:
709
+ return {
710
+ "instance_id": self.instance_id,
711
+ "ready": self.ready,
712
+ "health_score": round(self.health_score, 3),
713
+ "consecutive_failures": self._consecutive_failures,
714
+ "total_requests": self._total_requests,
715
+ "total_failures": self._total_failures,
716
+ "avg_latency_ms": round(self.avg_latency, 1),
717
+ }
718
+
719
 
720
  class GptOssProvider(ModelProvider):
721
+ def __init__(self, model_def, config, instance_id=0):
722
+ super().__init__(model_def, config, instance_id)
723
  self._session = requests.Session()
724
  self._rotate()
725
 
 
740
  return True
741
  self._rotate()
742
  try:
743
+ r = self._session.get(
744
+ f"{self.model_def.space_id}/gradio_api/info", timeout=15,
745
+ )
746
  self.ready = r.status_code == 200
747
  return self.ready
748
+ except Exception:
749
  return False
750
 
751
+ def generate(self, message, history=None, system_prompt=None,
752
+ temperature=None, max_tokens=None, **kw):
753
  if not self.ready:
754
  self.initialize()
755
  sys_p = system_prompt or self.config.default_system_prompt
756
+ temp = (temperature if temperature is not None
757
+ else self.model_def.default_temperature)
758
  h = self._hash()
759
+ payload = {
760
+ "data": [message, history or [], sys_p, temp],
761
+ "event_data": None,
762
+ "fn_index": self.model_def.fn_index,
763
+ "trigger_id": None,
764
+ "session_hash": h,
765
+ }
766
+ r = self._session.post(
767
+ f"{self.model_def.space_id}/gradio_api/queue/join?",
768
+ json=payload,
769
+ headers={"Content-Type": "application/json"},
770
+ timeout=30,
771
+ )
772
  if r.status_code != 200:
773
  raise APIError(f"Queue join failed: {r.status_code}")
774
  data = r.json()
775
  if not data.get("event_id"):
776
+ raise APIError("No event_id")
777
+
778
+ resp = self._session.get(
779
+ f"{self.model_def.space_id}/gradio_api/queue/data",
780
+ params={"session_hash": h},
781
+ headers={"Accept": "text/event-stream"},
782
+ timeout=self.config.timeout_stream,
783
+ stream=True,
784
+ )
785
  full = ""
786
  for d in GradioSSEParser.parse_sse(resp):
787
  msg = d.get("msg", "")
 
798
  break
799
  if not full.strip():
800
  raise APIError("Empty response", "EMPTY")
801
+ return (ResponseCleaner.clean_analysis(full)
802
+ if self.model_def.clean_analysis else full)
803
 
804
+ def generate_stream(self, message, history=None, system_prompt=None,
805
+ temperature=None, max_tokens=None, **kw):
806
  if not self.ready:
807
  self.initialize()
808
  sys_p = system_prompt or self.config.default_system_prompt
809
+ temp = (temperature if temperature is not None
810
+ else self.model_def.default_temperature)
811
  h = self._hash()
812
+ payload = {
813
+ "data": [message, history or [], sys_p, temp],
814
+ "event_data": None,
815
+ "fn_index": self.model_def.fn_index,
816
+ "trigger_id": None,
817
+ "session_hash": h,
818
+ }
819
+ self._session.post(
820
+ f"{self.model_def.space_id}/gradio_api/queue/join?",
821
+ json=payload,
822
+ headers={"Content-Type": "application/json"},
823
+ timeout=30,
824
+ )
825
+ resp = self._session.get(
826
+ f"{self.model_def.space_id}/gradio_api/queue/data",
827
+ params={"session_hash": h},
828
+ headers={"Accept": "text/event-stream"},
829
+ timeout=self.config.timeout_stream,
830
+ stream=True,
831
+ )
832
  metrics.active_streams += 1
833
  last = ""
834
  try:
 
837
  if msg in ("process_generating", "process_completed"):
838
  output = d.get("output", {})
839
  if not output.get("success", True):
840
+ raise APIError("Gradio error")
841
  raw = GradioSSEParser.extract_text(output)
842
  if raw:
843
  if self.model_def.clean_analysis:
 
859
 
860
  class GradioClientProvider(ModelProvider):
861
  """Generic provider for all gradio_client based models."""
862
+
863
+ def __init__(self, model_def, config, instance_id=0):
864
+ super().__init__(model_def, config, instance_id)
865
  self._client = None
866
  self._chat_counter = 0
867
 
868
  def initialize(self) -> bool:
869
  if not HAS_GRADIO_CLIENT:
870
+ raise APIError("gradio_client not installed", "MISSING_DEP")
871
  with self._lock:
872
  if self.ready:
873
  return True
874
  try:
875
+ log.info(
876
+ f"[Instance {self.instance_id}] Connecting to "
877
+ f"{self.model_def.space_id}..."
878
+ )
879
  self._client = GradioClient(self.model_def.space_id)
880
  self.ready = True
881
  return True
882
  except Exception as e:
883
+ log.error(
884
+ f"[Instance {self.instance_id}] Init failed for "
885
+ f"{self.model_def.model_id}: {e}"
886
+ )
887
  return False
888
 
889
+ def generate(self, message, history=None, system_prompt=None,
890
+ temperature=None, max_tokens=None, **kw):
891
  if not self.ready:
892
  self.initialize()
893
  if not self._client:
 
896
  mid = self.model_def.model_id
897
  try:
898
  if mid == "command-a-vision":
899
+ max_new = (max_tokens
900
+ or self.model_def.extra_params.get("max_new_tokens", 700))
901
+ result = self._client.predict(
902
+ message={"text": message, "files": []},
903
+ max_new_tokens=max_new,
904
+ api_name=self.model_def.api_name,
905
+ )
906
  elif mid == "command-a-translate":
907
+ max_new = (max_tokens
908
+ or self.model_def.extra_params.get("max_new_tokens", 700))
909
+ result = self._client.predict(
910
+ message=message,
911
+ max_new_tokens=max_new,
912
+ api_name=self.model_def.api_name,
913
+ )
914
  elif mid == "minimax-vl-01":
915
+ temp = (temperature if temperature is not None
916
+ else self.model_def.default_temperature)
917
+ max_tok = (max_tokens
918
+ or self.model_def.extra_params.get("max_tokens", 12800))
919
+ top_p = kw.get("top_p",
920
+ self.model_def.extra_params.get("top_p", 0.9))
921
+ result = self._client.predict(
922
+ message={"text": message, "files": []},
923
+ max_tokens=max_tok, temperature=temp, top_p=top_p,
924
+ api_name=self.model_def.api_name,
925
+ )
926
  elif mid == "glm-4.5":
927
  sys_p = system_prompt or self.config.default_system_prompt
928
+ temp = (temperature if temperature is not None
929
+ else self.model_def.default_temperature)
930
+ thinking = kw.get("thinking_enabled",
931
+ self.model_def.thinking_default)
932
+ include = kw.get("include_thinking",
933
+ self.config.include_thinking)
934
+ result = self._client.predict(
935
+ msg=message, sys_prompt=sys_p,
936
+ thinking_enabled=thinking, temperature=temp,
937
+ api_name=self.model_def.api_name,
938
+ )
939
  return self._extract_glm(result, include)
940
  elif mid == "chatgpt":
941
+ temp = (temperature if temperature is not None
942
+ else self.model_def.default_temperature)
943
+ top_p = kw.get("top_p",
944
+ self.model_def.extra_params.get("top_p", 1.0))
945
  chat_hist = []
946
  if history:
947
  for pair in history:
948
  if isinstance(pair, (list, tuple)) and len(pair) == 2:
949
  chat_hist.append([str(pair[0]), str(pair[1])])
950
+ result = self._client.predict(
951
+ inputs=message, top_p=top_p, temperature=temp,
952
+ chat_counter=self._chat_counter, chatbot=chat_hist,
953
+ api_name=self.model_def.api_name,
954
+ )
955
  self._chat_counter += 1
956
  return ResponseCleaner.extract_chatgpt_text(result)
957
  elif mid == "qwen3-vl":
958
+ result = self._client.predict(
959
+ input_value={"files": None, "text": message},
960
+ api_name="/add_message",
961
+ )
962
  return ResponseCleaner.extract_qwen_text(result)
963
  else:
964
  raise APIError(f"Unknown model handler: {mid}")
965
 
 
966
  if isinstance(result, str):
967
  return result.strip()
968
  if isinstance(result, dict):
 
996
  return ResponseCleaner.clean_glm(str(result), include_thinking)
997
 
998
 
999
+ # Factory β€” creates a single provider instance
1000
+ def create_provider(model_id: str, config: Config,
1001
+ instance_id: int = 0) -> ModelProvider:
1002
  if model_id not in MODEL_REGISTRY:
1003
  raise ModelNotFoundError(model_id)
1004
  mdef = MODEL_REGISTRY[model_id]
1005
  if model_id == "gpt-oss-120b":
1006
+ return GptOssProvider(mdef, config, instance_id)
1007
+ return GradioClientProvider(mdef, config, instance_id)
1008
+
1009
+ # ═══════════════════════════════════════════════════════════════
1010
+ # LOAD BALANCER β€” Per-model provider pool with health-aware
1011
+ # round-robin + failover
1012
+ # ═══════════════════════════════════════════════════════════════
1013
+
1014
+ class LoadBalancedProviderPool:
1015
+ """
1016
+ Manages multiple provider instances for a single model.
1017
+ Selects the best instance based on health score with
1018
+ weighted-random selection (healthier instances chosen more).
1019
+ Falls back through all instances on failure.
1020
+ """
1021
+
1022
+ def __init__(self, model_id: str, config: Config):
1023
+ self.model_id = model_id
1024
+ self.config = config
1025
+ self.mdef = MODEL_REGISTRY[model_id]
1026
+ pool_size = self.mdef.lb_pool_size if self.mdef.lb_enabled else 1
1027
+ self._instances: List[ModelProvider] = []
1028
+ self._rr_index = 0
1029
+ self._lock = threading.Lock()
1030
+
1031
+ for i in range(pool_size):
1032
+ self._instances.append(create_provider(model_id, config, instance_id=i))
1033
+
1034
+ log.info(
1035
+ f"[LB] Created pool for '{model_id}' with {len(self._instances)} "
1036
+ f"instance(s), lb_enabled={self.mdef.lb_enabled}"
1037
+ )
1038
+
1039
+ @property
1040
+ def pool_size(self) -> int:
1041
+ return len(self._instances)
1042
+
1043
+ def initialize_all(self) -> int:
1044
+ """Initialize all instances, return count of successful ones."""
1045
+ ok = 0
1046
+ for inst in self._instances:
1047
+ try:
1048
+ if inst.initialize():
1049
+ ok += 1
1050
+ except Exception as e:
1051
+ log.warning(
1052
+ f"[LB] Failed to init {self.model_id} "
1053
+ f"instance {inst.instance_id}: {e}"
1054
+ )
1055
+ return ok
1056
+
1057
+ def initialize_one(self) -> bool:
1058
+ """Initialize at least one instance."""
1059
+ for inst in self._instances:
1060
+ try:
1061
+ if inst.initialize():
1062
+ return True
1063
+ except Exception:
1064
+ continue
1065
+ return False
1066
+
1067
+ def _select_instance(self) -> ModelProvider:
1068
+ """
1069
+ Select best available instance.
1070
+ Strategy: weighted random by health score.
1071
+ If all have equal scores, falls back to round-robin.
1072
+ """
1073
+ if len(self._instances) == 1:
1074
+ return self._instances[0]
1075
+
1076
+ with self._lock:
1077
+ # Collect health scores
1078
+ scored = []
1079
+ for inst in self._instances:
1080
+ score = inst.health_score
1081
+ # Give a minimum weight so unhealthy instances can still recover
1082
+ scored.append((inst, max(score, 0.05)))
1083
+
1084
+ total_weight = sum(s for _, s in scored)
1085
+ if total_weight <= 0:
1086
+ # All dead, just round-robin
1087
+ inst = self._instances[self._rr_index % len(self._instances)]
1088
+ self._rr_index += 1
1089
+ return inst
1090
+
1091
+ # Weighted random selection
1092
+ r = random.uniform(0, total_weight)
1093
+ cumulative = 0.0
1094
+ for inst, weight in scored:
1095
+ cumulative += weight
1096
+ if r <= cumulative:
1097
+ return inst
1098
+
1099
+ # Fallback
1100
+ return scored[-1][0]
1101
+
1102
+ def _get_ordered_instances(self) -> List[ModelProvider]:
1103
+ """Return instances ordered by health score (best first)."""
1104
+ return sorted(self._instances, key=lambda p: p.health_score, reverse=True)
1105
+
1106
+ def execute(self, fn_name: str, **kwargs) -> Any:
1107
+ """
1108
+ Execute a provider method with automatic failover.
1109
+ Tries the best instance first, fails over to others.
1110
+ """
1111
+ primary = self._select_instance()
1112
+ metrics.record_lb_dispatch()
1113
+
1114
+ # Ensure primary is ready
1115
+ if not primary.ready:
1116
+ try:
1117
+ primary.initialize()
1118
+ except Exception:
1119
+ pass
1120
+
1121
+ # Try primary
1122
+ start = time.monotonic()
1123
+ try:
1124
+ result = self._call_provider(primary, fn_name, **kwargs)
1125
+ latency = (time.monotonic() - start) * 1000
1126
+ primary.record_success(latency)
1127
+ return result
1128
+ except Exception as primary_err:
1129
+ primary.record_failure()
1130
+ log.warning(
1131
+ f"[LB] Primary instance {primary.instance_id} for "
1132
+ f"'{self.model_id}' failed: {primary_err}"
1133
+ )
1134
+
1135
+ # Failover through remaining instances
1136
+ for inst in self._get_ordered_instances():
1137
+ if inst is primary:
1138
+ continue
1139
+ if not inst.ready:
1140
+ try:
1141
+ inst.initialize()
1142
+ except Exception:
1143
+ continue
1144
+
1145
+ metrics.record_lb_dispatch(failover=True)
1146
+ start = time.monotonic()
1147
+ try:
1148
+ result = self._call_provider(inst, fn_name, **kwargs)
1149
+ latency = (time.monotonic() - start) * 1000
1150
+ inst.record_success(latency)
1151
+ log.info(
1152
+ f"[LB] Failover to instance {inst.instance_id} "
1153
+ f"for '{self.model_id}' succeeded"
1154
+ )
1155
+ return result
1156
+ except Exception as e:
1157
+ inst.record_failure()
1158
+ log.warning(
1159
+ f"[LB] Failover instance {inst.instance_id} "
1160
+ f"for '{self.model_id}' also failed: {e}"
1161
+ )
1162
+
1163
+ raise APIError(
1164
+ f"All {len(self._instances)} instances for '{self.model_id}' failed",
1165
+ "ALL_INSTANCES_FAILED",
1166
+ )
1167
+
1168
+ def execute_stream(self, **kwargs) -> Generator[str, None, None]:
1169
+ """
1170
+ Execute streaming with failover.
1171
+ Since generators can't easily be retried mid-stream,
1172
+ we do failover only on initial connection failure.
1173
+ """
1174
+ primary = self._select_instance()
1175
+ metrics.record_lb_dispatch()
1176
+
1177
+ if not primary.ready:
1178
+ try:
1179
+ primary.initialize()
1180
+ except Exception:
1181
+ pass
1182
+
1183
+ # Try primary
1184
+ try:
1185
+ yield from self._call_provider_stream(primary, **kwargs)
1186
+ return
1187
+ except Exception as primary_err:
1188
+ primary.record_failure()
1189
+ log.warning(
1190
+ f"[LB] Stream primary instance {primary.instance_id} "
1191
+ f"for '{self.model_id}' failed: {primary_err}"
1192
+ )
1193
+
1194
+ # Failover
1195
+ for inst in self._get_ordered_instances():
1196
+ if inst is primary:
1197
+ continue
1198
+ if not inst.ready:
1199
+ try:
1200
+ inst.initialize()
1201
+ except Exception:
1202
+ continue
1203
+
1204
+ metrics.record_lb_dispatch(failover=True)
1205
+ try:
1206
+ yield from self._call_provider_stream(inst, **kwargs)
1207
+ return
1208
+ except Exception as e:
1209
+ inst.record_failure()
1210
+ log.warning(
1211
+ f"[LB] Stream failover instance {inst.instance_id} "
1212
+ f"for '{self.model_id}' failed: {e}"
1213
+ )
1214
+
1215
+ raise APIError(
1216
+ f"All streaming instances for '{self.model_id}' failed",
1217
+ "ALL_INSTANCES_FAILED",
1218
+ )
1219
+
1220
+ def _call_provider(self, provider: ModelProvider, fn_name: str,
1221
+ **kwargs) -> Any:
1222
+ if not provider.ready:
1223
+ provider.initialize()
1224
+ fn = getattr(provider, fn_name)
1225
+ return fn(**kwargs)
1226
+
1227
+ def _call_provider_stream(self, provider: ModelProvider,
1228
+ **kwargs) -> Generator[str, None, None]:
1229
+ if not provider.ready:
1230
+ provider.initialize()
1231
+ start = time.monotonic()
1232
+ try:
1233
+ yield from provider.generate_stream(**kwargs)
1234
+ latency = (time.monotonic() - start) * 1000
1235
+ provider.record_success(latency)
1236
+ except Exception:
1237
+ provider.record_failure()
1238
+ raise
1239
+
1240
+ def get_pool_info(self) -> Dict:
1241
+ return {
1242
+ "model_id": self.model_id,
1243
+ "lb_enabled": self.mdef.lb_enabled,
1244
+ "pool_size": len(self._instances),
1245
+ "instances": [inst.get_instance_info() for inst in self._instances],
1246
+ }
1247
 
1248
  # ═══════════════════════════════════════════════════════════════
1249
+ # MULTI-MODEL CLIENT (with load balancing)
1250
  # ═══════════════════════════════════════════════════════════════
1251
 
1252
  class MultiModelClient:
1253
  def __init__(self, config: Config):
1254
  self.config = config
1255
+ self._lb_pools: Dict[str, LoadBalancedProviderPool] = {}
1256
  self._lock = threading.Lock()
1257
  self._conversations: Dict[str, Conversation] = {}
1258
  self._active_conv_id: Optional[str] = None
1259
  self._current_model = config.default_model
1260
+ self.rate_limiter = RateLimiter(config.rate_limit_rps, config.rate_limit_burst)
1261
  self.circuit_breaker = CircuitBreaker()
1262
 
1263
  @property
 
1270
  raise ModelNotFoundError(m)
1271
  self._current_model = m
1272
 
1273
+ def _get_lb_pool(self, model_id: str) -> LoadBalancedProviderPool:
1274
+ if model_id not in self._lb_pools:
1275
  with self._lock:
1276
+ if model_id not in self._lb_pools:
1277
+ self._lb_pools[model_id] = LoadBalancedProviderPool(
1278
+ model_id, self.config
1279
+ )
1280
+ return self._lb_pools[model_id]
1281
+
1282
+ def _ensure_ready(self, model_id: str) -> LoadBalancedProviderPool:
1283
+ pool = self._get_lb_pool(model_id)
1284
+ # Make sure at least one instance is ready
1285
+ has_ready = any(inst.ready for inst in pool._instances)
1286
+ if not has_ready:
1287
+ if not pool.initialize_one():
1288
+ raise APIError(f"Cannot init any instance for {model_id}",
1289
+ "INIT_FAILED")
1290
+ return pool
1291
 
1292
  @property
1293
  def active_conversation(self) -> Conversation:
1294
  if self._active_conv_id not in self._conversations:
1295
+ conv = Conversation(
1296
+ system_prompt=self.config.default_system_prompt,
1297
+ model_id=self._current_model,
1298
+ )
1299
  self._conversations[conv.conversation_id] = conv
1300
  self._active_conv_id = conv.conversation_id
1301
  return self._conversations[self._active_conv_id]
1302
 
1303
+ def new_conversation(self, system_prompt=None,
1304
+ model_id=None) -> Conversation:
1305
+ conv = Conversation(
1306
+ system_prompt=system_prompt or self.config.default_system_prompt,
1307
+ model_id=model_id or self._current_model,
1308
+ )
1309
  self._conversations[conv.conversation_id] = conv
1310
  self._active_conv_id = conv.conversation_id
1311
  return conv
1312
 
1313
  def init_model(self, model_id: str) -> bool:
1314
  try:
1315
+ pool = self._get_lb_pool(model_id)
1316
+ return pool.initialize_one()
1317
+ except Exception:
1318
  return False
1319
 
1320
+ def init_model_all(self, model_id: str) -> int:
1321
+ """Init all instances in the pool, return count of ready ones."""
1322
+ try:
1323
+ pool = self._get_lb_pool(model_id)
1324
+ return pool.initialize_all()
1325
+ except Exception:
1326
+ return 0
1327
+
1328
+ def send_message(
1329
+ self,
1330
+ message: str,
1331
+ *,
1332
+ stream: bool = False,
1333
+ model: Optional[str] = None,
1334
+ conversation_id: Optional[str] = None,
1335
+ system_prompt: Optional[str] = None,
1336
+ temperature: Optional[float] = None,
1337
+ max_tokens: Optional[int] = None,
1338
+ include_thinking: Optional[bool] = None,
1339
+ **kwargs,
1340
+ ) -> Union[str, Generator]:
1341
  model_id = model or self._current_model
1342
  if model_id not in MODEL_REGISTRY:
1343
  raise ModelNotFoundError(model_id)
 
1350
  if not self.circuit_breaker.can_execute():
1351
  raise APIError("Circuit breaker open", "CIRCUIT_OPEN", 503)
1352
  if not self.rate_limiter.acquire(timeout=10.0):
1353
+ raise APIError("Rate limited (10 req/s max)", "RATE_LIMITED", 429)
1354
 
1355
+ conv = (self._conversations.get(conversation_id, self.active_conversation)
1356
+ if conversation_id else self.active_conversation)
1357
  conv.model_id = model_id
1358
  if system_prompt:
1359
  conv.system_prompt = system_prompt
 
1361
  history = conv.build_gradio_history() if mdef.supports_history else None
1362
  conv.add_message("user", message, self.config.max_history_messages)
1363
 
1364
+ eff_temp = (temperature if temperature is not None
1365
+ else mdef.default_temperature)
1366
  eff_sys = conv.system_prompt if mdef.supports_system_prompt else None
1367
+ eff_thinking = (include_thinking if include_thinking is not None
1368
+ else self.config.include_thinking)
1369
 
1370
  extra = dict(kwargs)
1371
  if mdef.supports_thinking:
 
1376
  for attempt in range(self.config.max_retries + 1):
1377
  try:
1378
  if attempt > 0:
1379
+ time.sleep(
1380
+ self.config.retry_backoff_base ** attempt
1381
+ + random.uniform(0, self.config.retry_jitter)
1382
+ )
1383
  metrics.record_retry()
1384
 
1385
+ lb_pool = self._ensure_ready(model_id)
1386
 
1387
  if stream and mdef.supports_streaming:
1388
+ gen = lb_pool.execute_stream(
1389
+ message=message,
1390
+ history=history,
1391
+ system_prompt=eff_sys,
1392
+ temperature=eff_temp,
1393
+ max_tokens=max_tokens,
1394
+ **extra,
1395
+ )
1396
  return self._wrap_stream(gen, conv, start, model_id)
1397
 
1398
+ result = lb_pool.execute(
1399
+ "generate",
1400
+ message=message,
1401
+ history=history,
1402
+ system_prompt=eff_sys,
1403
+ temperature=eff_temp,
1404
+ max_tokens=max_tokens,
1405
+ **extra,
1406
+ )
1407
  dur = (time.monotonic() - start) * 1000
1408
  thinking, response = ThinkingParser.split(result)
1409
+ conv.add_message("assistant", response,
1410
+ self.config.max_history_messages,
1411
+ thinking=thinking)
1412
  metrics.record_request(True, dur, len(result), model_id)
1413
  self.circuit_breaker.record_success()
1414
  return result
 
1433
  full += chunk
1434
  yield chunk
1435
  thinking, response = ThinkingParser.split(full)
1436
+ conv.add_message("assistant", response,
1437
+ self.config.max_history_messages,
1438
+ thinking=thinking)
1439
+ metrics.record_request(
1440
+ True, (time.monotonic() - start) * 1000,
1441
+ len(full), model_id,
1442
+ )
1443
  self.circuit_breaker.record_success()
1444
  except Exception:
1445
+ metrics.record_request(
1446
+ False, (time.monotonic() - start) * 1000, model=model_id,
1447
+ )
1448
  self.circuit_breaker.record_failure()
1449
  raise
1450
 
1451
  def get_status(self) -> Dict:
1452
+ lb_info = {}
1453
+ for model_id, pool in self._lb_pools.items():
1454
+ lb_info[model_id] = pool.get_pool_info()
1455
+
1456
  return {
1457
+ "version": VERSION,
1458
+ "current_model": self._current_model,
1459
  "models": list(MODEL_REGISTRY.keys()),
1460
+ "load_balancer": lb_info,
1461
  "conversations": len(self._conversations),
1462
  "circuit_breaker": self.circuit_breaker.state,
1463
+ "rate_limiter": self.rate_limiter.get_info(),
1464
  }
1465
 
1466
  # ═══════════════════════════════════════════════════════════════
1467
+ # SESSION POOL (top-level pool of MultiModelClients)
1468
  # ═══════════════════════════════════════════════════════════════
1469
 
1470
  class SessionPool:
1471
  def __init__(self, config: Config):
1472
  self.config = config
1473
+ self._clients = [
1474
+ MultiModelClient(config) for _ in range(config.pool_size)
1475
+ ]
1476
  self._idx = 0
1477
  self._lock = threading.Lock()
1478
 
 
1481
  c.init_model(self.config.default_model)
1482
 
1483
  def init_model(self, model_id: str) -> int:
1484
+ total = 0
1485
+ for c in self._clients:
1486
+ total += c.init_model_all(model_id)
1487
+ return total
1488
 
1489
  def acquire(self) -> MultiModelClient:
1490
  with self._lock:
 
1498
 
1499
  ALIASES = {
1500
  "gpt-oss": "gpt-oss-120b", "gptoss": "gpt-oss-120b", "amd": "gpt-oss-120b",
1501
+ "command-a": "command-a-vision", "command-vision": "command-a-vision",
1502
+ "cohere-vision": "command-a-vision",
1503
+ "command-translate": "command-a-translate",
1504
+ "cohere-translate": "command-a-translate", "translate": "command-a-translate",
1505
  "minimax": "minimax-vl-01", "minimax-vl": "minimax-vl-01",
1506
  "glm": "glm-4.5", "glm4": "glm-4.5", "glm-4": "glm-4.5", "zhipu": "glm-4.5",
1507
  "gpt": "chatgpt", "gpt-3.5": "chatgpt", "gpt3": "chatgpt", "openai": "chatgpt",
1508
  "qwen": "qwen3-vl", "qwen3": "qwen3-vl", "qwen-vl": "qwen3-vl",
1509
  }
1510
 
1511
+
1512
  def resolve_alias(model_id: str) -> str:
1513
  return ALIASES.get(model_id.lower(), model_id)
1514
 
 
1522
 
1523
  app = Flask(APP_NAME)
1524
 
1525
+
1526
  @app.after_request
1527
  def cors(response):
1528
  response.headers["Access-Control-Allow-Origin"] = "*"
 
1530
  response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
1531
  return response
1532
 
1533
+
1534
  @app.errorhandler(APIError)
1535
  def handle_api_error(e: APIError):
1536
  return jsonify({"ok": False, **e.to_dict()}), e.status
1537
 
1538
+
1539
  @app.route("/")
1540
  def index():
1541
  return jsonify({
1542
+ "name": APP_NAME,
1543
+ "version": VERSION,
1544
  "default_model": config.default_model,
1545
+ "features": ["load_balancing", "10_req_per_second_limit", "failover"],
1546
  "models": list(MODEL_REGISTRY.keys()),
1547
  "endpoints": {
1548
  "POST /chat": "Chat with any model",
 
1550
  "POST /v1/chat/completions": "OpenAI-compatible",
1551
  "GET /v1/models": "List models",
1552
  "POST /models/init": "Init a model",
1553
+ "GET /health": "Health check (incl. LB status)",
1554
  "GET /metrics": "Metrics",
1555
+ "GET /lb/status": "Load balancer detailed status",
1556
  },
1557
  })
1558
 
1559
+
1560
  @app.route("/chat", methods=["POST"])
1561
  def chat():
1562
  data = freq.get_json(force=True, silent=True) or {}
 
1568
  client = pool.acquire()
1569
  if data.get("new_conversation"):
1570
  client.new_conversation(data.get("system_prompt"), model_id)
1571
+ result = client.send_message(
1572
+ message, model=model_id,
1573
+ system_prompt=data.get("system_prompt"),
1574
+ temperature=data.get("temperature"),
1575
+ max_tokens=data.get("max_tokens"),
1576
+ include_thinking=include_thinking,
1577
+ )
1578
  thinking, clean = ThinkingParser.split(result)
1579
+ resp = {
1580
+ "ok": True,
1581
+ "response": clean,
1582
+ "model": model_id,
1583
+ "conversation_id": client.active_conversation.conversation_id,
1584
+ "history_size": len(client.active_conversation.messages),
1585
+ }
1586
  if thinking:
1587
  resp["thinking"] = thinking
1588
  return jsonify(resp)
1589
 
1590
+
1591
  @app.route("/chat/stream", methods=["POST"])
1592
  def chat_stream():
1593
  data = freq.get_json(force=True, silent=True) or {}
 
1605
  def generate():
1606
  try:
1607
  if use_stream:
1608
+ for chunk in client.send_message(
1609
+ message, stream=True, model=model_id,
1610
+ system_prompt=data.get("system_prompt"),
1611
+ temperature=data.get("temperature"),
1612
+ max_tokens=data.get("max_tokens"),
1613
+ include_thinking=include_thinking,
1614
+ ):
1615
  yield f"data: {json.dumps({'chunk': chunk})}\n\n"
1616
  else:
1617
+ result = client.send_message(
1618
+ message, model=model_id,
1619
+ system_prompt=data.get("system_prompt"),
1620
+ temperature=data.get("temperature"),
1621
+ max_tokens=data.get("max_tokens"),
1622
+ include_thinking=include_thinking,
1623
+ )
1624
  yield f"data: {json.dumps({'chunk': result})}\n\n"
1625
  yield "data: [DONE]\n\n"
1626
  except APIError as e:
1627
  yield f"data: {json.dumps(e.to_dict())}\n\n"
1628
 
1629
+ return Response(stream_with_context(generate()),
1630
+ content_type="text/event-stream")
1631
+
1632
 
1633
  @app.route("/v1/models", methods=["GET"])
1634
  def list_models():
1635
  models = []
1636
  for mid, mdef in MODEL_REGISTRY.items():
1637
  models.append({
1638
+ "id": mid,
1639
+ "object": "model",
1640
+ "owned_by": mdef.owned_by,
1641
+ "created": 0,
1642
  "description": mdef.description,
1643
  "capabilities": {
1644
+ "vision": mdef.supports_vision,
1645
+ "streaming": mdef.supports_streaming,
1646
+ "system_prompt": mdef.supports_system_prompt,
1647
+ "temperature": mdef.supports_temperature,
1648
+ "history": mdef.supports_history,
1649
+ "thinking": mdef.supports_thinking,
1650
+ },
1651
+ "load_balancing": {
1652
+ "enabled": mdef.lb_enabled,
1653
+ "pool_size": mdef.lb_pool_size,
1654
  },
1655
  })
1656
  return jsonify({"object": "list", "data": models})
1657
 
1658
+
1659
  @app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
1660
  def openai_compat():
1661
  if freq.method == "OPTIONS":
 
1669
  include_thinking = data.get("include_thinking", config.include_thinking)
1670
 
1671
  if model_id not in MODEL_REGISTRY:
1672
+ return jsonify({
1673
+ "error": {
1674
+ "message": f"Model '{model_id}' not found",
1675
+ "type": "invalid_request_error",
1676
+ }
1677
+ }), 404
1678
  if not messages:
1679
  return jsonify({"error": {"message": "messages required"}}), 400
1680
 
 
1704
  if do_stream:
1705
  def generate():
1706
  try:
1707
+ yield (
1708
+ f"data: {json.dumps({'id': rid, 'object': 'chat.completion.chunk', "
1709
+ f"'created': created, 'model': model_id, 'choices': ["
1710
+ f"{{'index': 0, 'delta': {{'role': 'assistant'}}, "
1711
+ f"'finish_reason': None}}]})}\n\n"
1712
+ )
1713
  if mdef.supports_streaming:
1714
+ for chunk in client.send_message(
1715
+ user_msg, stream=True, model=model_id,
1716
+ temperature=temperature, max_tokens=max_tokens,
1717
+ include_thinking=include_thinking,
1718
+ ):
1719
+ yield (
1720
+ f"data: {json.dumps({'id': rid, "
1721
+ f"'object': 'chat.completion.chunk', "
1722
+ f"'created': created, 'model': model_id, "
1723
+ f"'choices': [{{'index': 0, "
1724
+ f"'delta': {{'content': chunk}}, "
1725
+ f"'finish_reason': None}}]})}\n\n"
1726
+ )
1727
  else:
1728
+ result = client.send_message(
1729
+ user_msg, model=model_id, temperature=temperature,
1730
+ max_tokens=max_tokens,
1731
+ include_thinking=include_thinking,
1732
+ )
1733
+ yield (
1734
+ f"data: {json.dumps({'id': rid, "
1735
+ f"'object': 'chat.completion.chunk', "
1736
+ f"'created': created, 'model': model_id, "
1737
+ f"'choices': [{{'index': 0, "
1738
+ f"'delta': {{'content': result}}, "
1739
+ f"'finish_reason': None}}]})}\n\n"
1740
+ )
1741
+ yield (
1742
+ f"data: {json.dumps({'id': rid, "
1743
+ f"'object': 'chat.completion.chunk', "
1744
+ f"'created': created, 'model': model_id, "
1745
+ f"'choices': [{{'index': 0, 'delta': {{}}, "
1746
+ f"'finish_reason': 'stop'}}]})}\n\n"
1747
+ )
1748
  yield "data: [DONE]\n\n"
1749
  except Exception as e:
1750
  yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
 
1751
 
1752
+ return Response(stream_with_context(generate()),
1753
+ content_type="text/event-stream")
1754
+
1755
+ result = client.send_message(
1756
+ user_msg, model=model_id, temperature=temperature,
1757
+ max_tokens=max_tokens, include_thinking=include_thinking,
1758
+ )
1759
  return jsonify({
1760
+ "id": rid,
1761
+ "object": "chat.completion",
1762
+ "created": created,
1763
+ "model": model_id,
1764
+ "choices": [{
1765
+ "index": 0,
1766
+ "message": {"role": "assistant", "content": result},
1767
+ "finish_reason": "stop",
1768
+ }],
1769
+ "usage": {
1770
+ "prompt_tokens": len(user_msg) // 4,
1771
+ "completion_tokens": len(result) // 4,
1772
+ "total_tokens": (len(user_msg) + len(result)) // 4,
1773
+ },
1774
  })
1775
 
1776
+
1777
  @app.route("/new", methods=["POST"])
1778
  def new_conv():
1779
  data = freq.get_json(force=True, silent=True) or {}
1780
  model_id = resolve_alias(data.get("model", config.default_model))
1781
  client = pool.acquire()
1782
  conv = client.new_conversation(data.get("system_prompt"), model_id)
1783
+ return jsonify({
1784
+ "ok": True,
1785
+ "conversation_id": conv.conversation_id,
1786
+ "model": model_id,
1787
+ })
1788
+
1789
 
1790
  @app.route("/health", methods=["GET"])
1791
  def health():
1792
  client = pool.acquire()
1793
  return jsonify(client.get_status())
1794
 
1795
+
1796
  @app.route("/metrics", methods=["GET"])
1797
  def metrics_endpoint():
1798
  return jsonify(metrics.to_dict())
1799
 
1800
+
1801
+ @app.route("/lb/status", methods=["GET"])
1802
+ def lb_status():
1803
+ """Detailed load balancer status for all models across all clients."""
1804
+ all_pools = {}
1805
+ for client in pool._clients:
1806
+ for model_id, lb_pool in client._lb_pools.items():
1807
+ key = f"{model_id}"
1808
+ if key not in all_pools:
1809
+ all_pools[key] = []
1810
+ all_pools[key].append(lb_pool.get_pool_info())
1811
+ return jsonify({
1812
+ "ok": True,
1813
+ "version": VERSION,
1814
+ "rate_limit": f"{config.rate_limit_rps} req/s",
1815
+ "models": all_pools,
1816
+ })
1817
+
1818
+
1819
  @app.route("/conversations", methods=["GET"])
1820
  def conversations():
1821
  client = pool.acquire()
1822
+ return jsonify({
1823
+ "conversations": [c.to_dict() for c in client._conversations.values()]
1824
+ })
1825
+
1826
 
1827
  @app.route("/models/init", methods=["POST"])
1828
  def init_model_ep():
1829
  data = freq.get_json(force=True, silent=True) or {}
1830
  model_id = resolve_alias(data.get("model", ""))
1831
  if not model_id or model_id not in MODEL_REGISTRY:
1832
+ return jsonify({
1833
+ "ok": False,
1834
+ "error": f"Unknown model. Available: {list(MODEL_REGISTRY.keys())}",
1835
+ }), 400
1836
  count = pool.init_model(model_id)
1837
+ mdef = MODEL_REGISTRY[model_id]
1838
+ return jsonify({
1839
+ "ok": True,
1840
+ "model": model_id,
1841
+ "initialized_instances": count,
1842
+ "lb_enabled": mdef.lb_enabled,
1843
+ "pool_size_per_client": mdef.lb_pool_size,
1844
+ })
1845
+
1846
 
1847
  # ═══════════════════════════════════════════════════════════════
1848
+ # ENTRY POINT
1849
  # ═══════════════════════════════════════════════════════════════
1850
 
1851
  if __name__ == "__main__":
1852
  port = int(os.environ.get("PORT", 7860))
1853
  log.info(f"Starting Multi-Model AI API v{VERSION} on port {port}")
1854
  log.info(f"Models: {list(MODEL_REGISTRY.keys())}")
1855
+ log.info(f"Rate limit: {config.rate_limit_rps} req/s (burst: {config.rate_limit_burst})")
1856
+ for mid, mdef in MODEL_REGISTRY.items():
1857
+ lb_status_str = (
1858
+ f"LB ON (pool={mdef.lb_pool_size})"
1859
+ if mdef.lb_enabled
1860
+ else "LB OFF (single instance)"
1861
+ )
1862
+ log.info(f" {mid}: {lb_status_str}")
1863
  app.run(host="0.0.0.0", port=port, threaded=True)