github-actions[bot] committed on
Commit
128a79a
·
1 Parent(s): e2968a4

🚀 Auto-deploy backend from GitHub (54956be)

Browse files
requirements.txt CHANGED
@@ -17,5 +17,6 @@ joblib==1.4.2
17
  scipy==1.15.1
18
  numpy==2.2.1
19
  firebase-admin>=6.2.0
 
20
  redis[hiredis]>=5.0.0
21
  PyYAML>=6.0.0
 
17
  scipy==1.15.1
18
  numpy==2.2.1
19
  firebase-admin>=6.2.0
20
+ openai>=1.12.0
21
  redis[hiredis]>=5.0.0
22
  PyYAML>=6.0.0
services/ai_client.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import OpenAI, APIError, RateLimitError, APITimeoutError
3
+ from functools import lru_cache
4
+
5
+ __all__ = [
6
+ "get_deepseek_client",
7
+ "CHAT_MODEL",
8
+ "REASONER_MODEL",
9
+ "DEEPSEEK_BASE_URL",
10
+ "APIError",
11
+ "RateLimitError",
12
+ "APITimeoutError",
13
+ ]
14
+
15
+ DEEPSEEK_BASE_URL = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
16
+ CHAT_MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
17
+ REASONER_MODEL = os.getenv("DEEPSEEK_REASONER_MODEL", "deepseek-reasoner")
18
+
19
+
20
+ @lru_cache(maxsize=1)
21
+ def get_deepseek_client() -> OpenAI:
22
+ api_key = os.getenv("DEEPSEEK_API_KEY")
23
+ if not api_key:
24
+ raise ValueError("DEEPSEEK_API_KEY environment variable not set")
25
+ return OpenAI(
26
+ api_key=api_key,
27
+ base_url=DEEPSEEK_BASE_URL,
28
+ )
services/inference_client.py CHANGED
@@ -10,13 +10,198 @@ from typing import Any, Dict, List, Optional, Tuple
10
 
11
  import requests
12
  import yaml
13
- from huggingface_hub import InferenceClient as HFInferenceClient
14
 
 
15
  from .logging_utils import configure_structured_logging, log_model_call
16
 
17
  LOGGER = configure_structured_logging("mathpulse.inference")
18
  TEMP_CHAT_MODEL_OVERRIDE_ENV = "INFERENCE_CHAT_MODEL_TEMP_OVERRIDE"
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def _normalize_local_space_url(raw_url: str) -> str:
22
  """Accept either hf.space host or huggingface.co/spaces URL for local_space provider."""
@@ -24,8 +209,6 @@ def _normalize_local_space_url(raw_url: str) -> str:
24
  if not cleaned:
25
  return "http://127.0.0.1:7860"
26
 
27
- # Convert page URL format to runtime host format:
28
- # https://huggingface.co/spaces/{owner}/{space} -> https://{owner}-{space}.hf.space
29
  match = re.match(r"^https?://huggingface\.co/spaces/([^/]+)/([^/]+)$", cleaned, re.IGNORECASE)
30
  if match:
31
  owner = match.group(1).strip().lower()
@@ -41,28 +224,31 @@ class InferenceRequest:
41
  model: Optional[str] = None
42
  task_type: str = "default"
43
  request_tag: str = ""
44
- max_new_tokens: int = 512
45
  temperature: float = 0.2
46
  top_p: float = 0.9
47
  repetition_penalty: float = 1.15
48
  timeout_sec: Optional[int] = None
 
49
 
50
 
51
  class InferenceClient:
52
- def __init__(self) -> None:
53
- # Try multiple config paths (HF Space, Docker, local development)
54
- # The deploy script uploads config/ to the space root
 
 
55
  config_paths = [
56
- Path("./config/models.yaml"), # Current working directory (most reliable)
57
- Path("/config/models.yaml"), # HF Space root
58
- Path("/app/config/models.yaml"), # App directory
59
- Path.cwd() / "config" / "models.yaml", # CWD with config subdir
60
- Path(__file__).resolve().parents[2] / "config" / "models.yaml", # Package root
61
  ]
62
-
63
  config: Dict[str, object] = {}
64
  config_path = None
65
-
66
  for path in config_paths:
67
  if path.exists():
68
  config_path = path
@@ -70,7 +256,7 @@ class InferenceClient:
70
  config = yaml.safe_load(fh) or {}
71
  LOGGER.info(f"โœ… Loaded config from {config_path}")
72
  break
73
-
74
  if not config_path:
75
  LOGGER.warning(f"โš ๏ธ Config file not found. Checked: {[str(p) for p in config_paths]}")
76
  LOGGER.warning(f" CWD: {Path.cwd()}")
@@ -84,74 +270,43 @@ class InferenceClient:
84
  if isinstance(primary_cfg, dict):
85
  primary = primary_cfg
86
 
87
- self.provider = os.getenv("INFERENCE_PROVIDER", "hf_inference").strip().lower()
88
- self.pro_provider = os.getenv("INFERENCE_PRO_PROVIDER", "hf_inference").strip().lower()
89
- self.gpu_provider = os.getenv("INFERENCE_GPU_PROVIDER", "hf_inference").strip().lower()
90
- self.cpu_provider = os.getenv("INFERENCE_CPU_PROVIDER", "hf_inference").strip().lower()
91
- self.enable_provider_fallback = os.getenv("INFERENCE_ENABLE_PROVIDER_FALLBACK", "true").strip().lower() in {"1", "true", "yes", "on"}
92
- self.pro_enabled = os.getenv("INFERENCE_PRO_ENABLED", "false").strip().lower() in {"1", "true", "yes", "on"}
93
- self.hf_token = os.getenv(
94
- "HF_TOKEN",
95
- os.getenv("HUGGING_FACE_API_TOKEN", os.getenv("HUGGINGFACE_API_TOKEN", "")),
96
- )
97
- self.hf_base_url = os.getenv("INFERENCE_HF_BASE_URL", "https://router.huggingface.co/hf-inference/models")
98
- self.hf_chat_url = os.getenv("INFERENCE_HF_CHAT_URL", "https://router.huggingface.co/v1/chat/completions")
99
-
100
- # Featherless AI for Qwen math models (used as fallback when HF router fails)
101
- self.featherless_api_key = os.getenv("FEATHERLESS_API_KEY", "")
102
- self.featherless_chat_url = os.getenv("FEATHERLESS_CHAT_URL", "https://api.featherless.ai/openai/v1/chat/completions")
103
-
104
- # DeepSeek API (primary inference provider)
105
- self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY", "")
106
- self.deepseek_base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com").rstrip("/")
107
- self.deepseek_chat_url = f"{self.deepseek_base_url}/v1/chat/completions"
108
-
109
  self.local_space_url = _normalize_local_space_url(
110
  os.getenv("INFERENCE_LOCAL_SPACE_URL", "http://127.0.0.1:7860")
111
  )
112
  self.local_generate_path = os.getenv("INFERENCE_LOCAL_SPACE_GENERATE_PATH", "/gradio_api/call/generate")
113
- self.pro_route_header_name = os.getenv("INFERENCE_PRO_ROUTE_HEADER_NAME", "")
114
- self.pro_route_header_value = os.getenv("INFERENCE_PRO_ROUTE_HEADER_VALUE", "true")
115
 
116
- self.enforce_qwen_only = os.getenv("INFERENCE_ENFORCE_QWEN_ONLY", "false").strip().lower() in {"1", "true", "yes", "on"}
117
- self.qwen_lock_model = os.getenv("INFERENCE_QWEN_LOCK_MODEL", "deepseek-chat").strip() or "deepseek-chat"
118
 
119
- default_model_fallback = str(primary.get("id") or "deepseek-chat")
120
  env_model_id = os.getenv("INFERENCE_MODEL_ID", "").strip()
121
  self.default_model = env_model_id or default_model_fallback
122
-
123
  default_max_tokens = str(primary.get("max_new_tokens") or 512)
124
  self.default_max_new_tokens = int(os.getenv("INFERENCE_MAX_NEW_TOKENS", default_max_tokens))
125
-
126
  default_temp = str(primary.get("temperature") or 0.2)
127
  self.default_temperature = float(os.getenv("INFERENCE_TEMPERATURE", default_temp))
128
-
129
  default_top_p = str(primary.get("top_p") or 0.9)
130
  self.default_top_p = float(os.getenv("INFERENCE_TOP_P", default_top_p))
131
-
132
- # Task-specific model overrides via environment variables
133
  self.chat_model_override = os.getenv("INFERENCE_CHAT_MODEL_ID", "").strip()
134
  self.chat_model_temp_override = os.getenv(TEMP_CHAT_MODEL_OVERRIDE_ENV, "").strip()
135
  self.chat_strict_model_only = os.getenv("INFERENCE_CHAT_STRICT_MODEL_ONLY", "true").strip().lower() in {"1", "true", "yes", "on"}
136
- self.chat_hard_model = os.getenv("INFERENCE_CHAT_HARD_MODEL_ID", "meta-llama/Meta-Llama-3-70B-Instruct").strip()
137
- self.chat_hard_trigger_enabled = os.getenv("INFERENCE_CHAT_HARD_TRIGGER_ENABLED", "false").strip().lower() in {"1", "true", "yes", "on"}
138
- self.chat_hard_prompt_chars = max(256, int(os.getenv("INFERENCE_CHAT_HARD_PROMPT_CHARS", "800")))
139
- self.chat_hard_history_chars = max(
140
- self.chat_hard_prompt_chars,
141
- int(os.getenv("INFERENCE_CHAT_HARD_HISTORY_CHARS", "1800")),
142
- )
143
- hard_keywords_raw = os.getenv(
144
- "INFERENCE_CHAT_HARD_KEYWORDS",
145
- "step-by-step,show all steps,derive,proof,prove,rigorous,multi-step,word problem",
146
- )
147
- self.chat_hard_keywords = [kw.strip().lower() for kw in hard_keywords_raw.split(",") if kw.strip()]
148
 
149
- self.hf_timeout_sec = int(os.getenv("INFERENCE_HF_TIMEOUT_SEC", "90"))
150
  self.local_timeout_sec = int(os.getenv("INFERENCE_LOCAL_SPACE_TIMEOUT_SEC", "90"))
151
  self.max_retries = int(os.getenv("INFERENCE_MAX_RETRIES", "3"))
152
  self.backoff_sec = float(os.getenv("INFERENCE_BACKOFF_SEC", "1.5"))
153
- self.interactive_timeout_sec = int(os.getenv("INFERENCE_INTERACTIVE_TIMEOUT_SEC", str(self.hf_timeout_sec)))
154
- self.background_timeout_sec = int(os.getenv("INFERENCE_BACKGROUND_TIMEOUT_SEC", str(self.hf_timeout_sec)))
155
  self.interactive_max_retries = int(os.getenv("INFERENCE_INTERACTIVE_MAX_RETRIES", str(self.max_retries)))
156
  self.background_max_retries = int(os.getenv("INFERENCE_BACKGROUND_MAX_RETRIES", str(self.max_retries)))
157
  self.interactive_backoff_sec = float(os.getenv("INFERENCE_INTERACTIVE_BACKOFF_SEC", str(self.backoff_sec)))
@@ -172,12 +327,6 @@ class InferenceClient:
172
  )
173
  self.cpu_only_tasks = {v.strip().lower() for v in cpu_tasks_raw.split(",") if v.strip()}
174
 
175
- pro_tasks_raw = os.getenv(
176
- "INFERENCE_PRO_PRIORITY_TASKS",
177
- "chat,quiz_generation,lesson_generation,learning_path,verify_solution",
178
- )
179
- self.pro_priority_tasks = {v.strip().lower() for v in pro_tasks_raw.split(",") if v.strip()}
180
-
181
  interactive_tasks_raw = os.getenv(
182
  "INFERENCE_INTERACTIVE_TASKS",
183
  "chat,verify_solution,daily_insight",
@@ -189,29 +338,20 @@ class InferenceClient:
189
  )
190
 
191
  # Default task-to-model routing.
192
- # Keep all tasks pinned to deepseek-chat when qwen-only lock is active.
193
  self.task_model_map: Dict[str, str] = {
194
- "chat": "deepseek-chat",
195
- "verify_solution": "deepseek-chat",
196
- "lesson_generation": "deepseek-chat",
197
- "quiz_generation": "deepseek-chat",
198
- "learning_path": "deepseek-chat",
199
- "daily_insight": "deepseek-chat",
200
- "risk_classification": "deepseek-chat",
201
- "risk_narrative": "deepseek-chat",
202
  }
203
- # Fallback chains (only to other HF-supported models, no featherless-ai)
204
  self.task_fallback_model_map: Dict[str, List[str]] = {
205
- "chat": [
206
- "meta-llama/Llama-3.1-8B-Instruct",
207
- "google/gemma-2-2b-it",
208
- ],
209
- "verify_solution": [
210
- "meta-llama/Llama-3.1-8B-Instruct",
211
- "google/gemma-2-2b-it",
212
- ],
213
  }
214
- # Model-to-provider mappings (not needed when using model:provider syntax directly)
215
  self.model_provider_map: Dict[str, str] = {}
216
  self.task_provider_map: Dict[str, str] = {}
217
  if isinstance(config, dict):
@@ -224,7 +364,6 @@ class InferenceClient:
224
  for task, model in task_models.items()
225
  if str(task).strip() and str(model).strip()
226
  }
227
- # Merge config models with defaults (config overrides defaults)
228
  self.task_model_map.update(config_task_models)
229
  task_fallback_models = routing_cfg.get("task_fallback_model_map", {})
230
  if isinstance(task_fallback_models, dict):
@@ -265,21 +404,19 @@ class InferenceClient:
265
  else:
266
  env_override_note = ""
267
 
268
- if self.enforce_qwen_only:
269
- qwen_map_before = dict(self.task_model_map)
270
- self.default_model = self.qwen_lock_model
271
  for task_key in list(self.task_model_map.keys()):
272
- self.task_model_map[task_key] = self.qwen_lock_model
273
  self.fallback_models = []
274
  self.task_fallback_model_map = {
275
  task_key: [] for task_key in self.task_model_map.keys()
276
  }
277
- self.chat_hard_trigger_enabled = False
278
- LOGGER.info(f"๐Ÿ”’ INFERENCE_ENFORCE_QWEN_ONLY enabled: locking all inference tasks to {self.qwen_lock_model}")
279
- LOGGER.info(f" Cleared fallback models and hard-escalation path")
280
- LOGGER.info(f" Task model mappings forced from: {qwen_map_before}")
281
 
282
- # Log configuration loaded for debugging
283
  config_status = "from file" if config_path else "hardcoded defaults (no config file found)"
284
  effective_chat_model_for_logs = self.chat_model_override or self.task_model_map.get("chat", self.default_model)
285
  LOGGER.info(f"โœ… InferenceClient initialized {config_status}{env_override_note}")
@@ -287,7 +424,7 @@ class InferenceClient:
287
  LOGGER.info(f" Chat model: {effective_chat_model_for_logs}")
288
  LOGGER.info(f" Chat temp override ({TEMP_CHAT_MODEL_OVERRIDE_ENV}): {self.chat_model_temp_override or 'disabled'}")
289
  LOGGER.info(f" Chat strict model lock: {self.chat_strict_model_only}")
290
- LOGGER.info(f" Global Qwen-only lock: {self.enforce_qwen_only}")
291
  LOGGER.info(f" Verify solution model: {self.task_model_map.get('verify_solution', self.default_model)}")
292
  LOGGER.info(f" Full task_model_map: {self.task_model_map}")
293
 
@@ -299,18 +436,23 @@ class InferenceClient:
299
  "requests_error": 0,
300
  "retries_total": 0,
301
  "fallback_attempts": 0,
 
 
302
  "route_counts": {},
303
  "task_counts": {},
304
  "provider_counts": {},
305
  "status_code_counts": {},
306
  }
307
 
 
 
308
  def _bump_metric(self, key: str, inc: int = 1) -> None:
309
  with self._metrics_lock:
310
  current = self._metrics.get(key) or 0
311
  if not isinstance(current, int):
312
  current = 0
313
  self._metrics[key] = current + inc
 
314
 
315
  def _bump_bucket(self, key: str, bucket: str, inc: int = 1) -> None:
316
  with self._metrics_lock:
@@ -322,6 +464,50 @@ class InferenceClient:
322
  if not isinstance(current, int):
323
  current = 0
324
  mapping[bucket] = current + inc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  def _record_attempt(self, *, task_type: str, provider: str, route: str, fallback_depth: int) -> None:
327
  self._bump_metric("requests_total", 1)
@@ -333,6 +519,10 @@ class InferenceClient:
333
 
334
  def snapshot_metrics(self) -> Dict[str, Any]:
335
  with self._metrics_lock:
 
 
 
 
336
  snapshot = {
337
  "uptime_sec": round(max(0.0, time.time() - self._metrics_started_at), 2),
338
  "requests_total": self._metrics.get("requests_total") or 0,
@@ -340,6 +530,9 @@ class InferenceClient:
340
  "requests_error": self._metrics.get("requests_error") or 0,
341
  "retries_total": self._metrics.get("retries_total") or 0,
342
  "fallback_attempts": self._metrics.get("fallback_attempts") or 0,
 
 
 
343
  "route_counts": dict(self._metrics.get("route_counts") or {}),
344
  "task_counts": dict(self._metrics.get("task_counts") or {}),
345
  "provider_counts": dict(self._metrics.get("provider_counts") or {}),
@@ -351,22 +544,18 @@ class InferenceClient:
351
  effective_task = (req.task_type or "default").strip().lower()
352
  request_tag = req.request_tag.strip() or f"{effective_task}-{int(time.time() * 1000)}"
353
  selected_model, model_selection_source = self._resolve_primary_model(req)
354
-
355
  model_chain = self._model_chain_for_task(effective_task, selected_model)
356
  last_error: Optional[Exception] = None
357
- provider_chain = self._provider_chain_for_task(req.task_type)
358
-
359
- # Normalize model name (remove any provider suffix since we use hf_inference router)
360
- model_base = selected_model.split(":")[0] if ":" in selected_model else selected_model
361
-
362
- # Log model selection for debugging - confirm which model will actually be used
363
  LOGGER.info(
364
- f"๐ŸŽฏ request_tag={request_tag} task={effective_task} source={model_selection_source} "
365
- f"selected_model={model_base} (primary) provider_chain={provider_chain}"
366
  )
367
  LOGGER.info(f" fallback_chain={model_chain[1:] if len(model_chain) > 1 else 'none'}")
368
 
369
-
370
  for fallback_depth, model_name in enumerate(model_chain):
371
  request_for_model = InferenceRequest(
372
  messages=req.messages,
@@ -379,20 +568,19 @@ class InferenceClient:
379
  repetition_penalty=req.repetition_penalty,
380
  timeout_sec=req.timeout_sec,
381
  )
382
-
383
- for provider in provider_chain:
384
- try:
385
- result = self._generate_with_provider(request_for_model, provider, fallback_depth)
386
- if fallback_depth > 0:
387
- LOGGER.info(f"โœ… Fallback succeeded at depth={fallback_depth} model={model_name} provider={provider}")
388
- return result
389
- except Exception as exc:
390
- last_error = exc
391
- fallback_hint = f" (depth {fallback_depth})" if fallback_depth > 0 else ""
392
- LOGGER.warning(
393
- f"โš ๏ธ Attempt failed{fallback_hint}: task={request_for_model.task_type} "
394
- f"provider={provider} model={model_name} error={exc.__class__.__name__}: {str(exc)[:100]}"
395
- )
396
 
397
  if last_error:
398
  raise last_error
@@ -405,10 +593,6 @@ class InferenceClient:
405
  effective_task = (req.task_type or "default").strip().lower()
406
  runtime_chat_override = self._runtime_chat_model_override()
407
 
408
- def _base_model(model_name: str) -> str:
409
- return (model_name or "").split(":", 1)[0].strip()
410
-
411
- # Check explicit request model first, then chat override env, then task map/default.
412
  if effective_task == "chat" and runtime_chat_override:
413
  selected_model = runtime_chat_override
414
  model_selection_source = "chat_temp_override_env"
@@ -422,107 +606,39 @@ class InferenceClient:
422
  selected_model = self.task_model_map.get(effective_task, self.default_model)
423
  model_selection_source = "task_map"
424
 
425
- if self.enforce_qwen_only:
426
- effective_qwen_lock_model = self.qwen_lock_model
427
  if effective_task == "chat":
428
- effective_qwen_lock_model = runtime_chat_override or self.chat_model_override or self.qwen_lock_model
429
 
430
- selected_base = _base_model(selected_model)
431
- lock_base = _base_model(effective_qwen_lock_model)
432
  if selected_base != lock_base:
433
  LOGGER.warning(
434
- f"โš ๏ธ Qwen-only lock replaced requested model {selected_model} with {effective_qwen_lock_model}"
435
  )
436
- selected_model = effective_qwen_lock_model
437
- model_selection_source = f"{model_selection_source}:qwen_only"
438
 
439
  if effective_task == "chat" and self.chat_strict_model_only:
440
  return selected_model, f"{model_selection_source}:chat_strict_model_only"
441
 
442
- if effective_task == "chat" and self.chat_hard_trigger_enabled and self.chat_hard_model:
443
- should_escalate, reason = self._should_escalate_chat_to_hard_model(req.messages)
444
- if should_escalate and selected_model != self.chat_hard_model:
445
- return self.chat_hard_model, f"chat_hard_escalation:{reason}"
446
-
447
  return selected_model, model_selection_source
448
 
449
- def _should_escalate_chat_to_hard_model(self, messages: List[Dict[str, str]]) -> Tuple[bool, str]:
450
- latest_user = self._latest_user_message(messages)
451
- if not latest_user:
452
- return False, "no_user_message"
453
-
454
- latest_norm = latest_user.lower()
455
- prompt_chars = len(latest_user)
456
- history_chars = 0
457
- for msg in messages:
458
- content = (msg.get("content") or "") if isinstance(msg, dict) else ""
459
- history_chars += len(content)
460
-
461
- keyword_hit = ""
462
- for kw in self.chat_hard_keywords:
463
- if kw and kw in latest_norm:
464
- keyword_hit = kw
465
- break
466
-
467
- math_marker_count = len(
468
- re.findall(
469
- r"(=|\bintegral\b|\bderivative\b|\bmatrix\b|\blimit\b|\bproof\b|\bderive\b|\bsolve\b)",
470
- latest_norm,
471
- )
472
- )
473
-
474
- long_prompt = prompt_chars >= self.chat_hard_prompt_chars
475
- long_history = history_chars >= self.chat_hard_history_chars
476
- immediate_hard_request = any(
477
- phrase in latest_norm
478
- for phrase in (
479
- "show all steps",
480
- "step-by-step",
481
- "step by step",
482
- "rigorous proof",
483
- "formal proof",
484
- )
485
- )
486
-
487
- # Escalate immediately for long step-by-step prompts or heavy math density.
488
- escalate = bool(keyword_hit and immediate_hard_request)
489
- if not escalate:
490
- escalate = bool(keyword_hit and (long_prompt or long_history or math_marker_count >= 2))
491
- if not escalate and long_prompt and math_marker_count >= 2:
492
- escalate = True
493
- if not escalate and long_history and math_marker_count >= 2:
494
- escalate = True
495
-
496
- if not escalate:
497
- return False, "normal"
498
-
499
- reasons: List[str] = []
500
- if long_prompt:
501
- reasons.append(f"prompt_chars={prompt_chars}")
502
- if long_history:
503
- reasons.append(f"history_chars={history_chars}")
504
- if keyword_hit:
505
- reasons.append(f"keyword={keyword_hit}")
506
- if immediate_hard_request:
507
- reasons.append("immediate_hard_request")
508
- if math_marker_count >= 2:
509
- reasons.append(f"math_markers={math_marker_count}")
510
- return True, ",".join(reasons) if reasons else "hard_prompt"
511
-
512
  def _model_chain_for_task(self, task_type: str, selected_model: str) -> List[str]:
513
  normalized = (task_type or "default").strip().lower()
514
  runtime_chat_override = self._runtime_chat_model_override() if normalized == "chat" else ""
515
- chat_qwen_lock_model = runtime_chat_override or (self.chat_model_override if normalized == "chat" else "")
516
 
517
- if self.enforce_qwen_only:
518
  if normalized == "chat":
519
- locked_model = (chat_qwen_lock_model or self.qwen_lock_model or "").strip()
520
  else:
521
- locked_model = (self.qwen_lock_model or "").strip()
522
  return [locked_model] if locked_model else []
523
 
524
  if normalized == "chat" and self.chat_strict_model_only:
525
- chat_model = (chat_qwen_lock_model or selected_model or "").strip()
526
  return [chat_model] if chat_model else []
527
 
528
  per_task_candidates = self.task_fallback_model_map.get(task_type, [])
@@ -542,34 +658,6 @@ class InferenceClient:
542
  return deduped[:max_models]
543
  return deduped
544
 
545
- def _provider_chain_for_task(self, task_type: str) -> List[str]:
546
- normalized = (task_type or "default").strip().lower()
547
- forced_provider = self.task_provider_map.get(normalized)
548
- if forced_provider:
549
- return [forced_provider]
550
-
551
- if normalized in self.cpu_only_tasks:
552
- return [self.cpu_provider]
553
-
554
- if self.pro_enabled and normalized in self.pro_priority_tasks:
555
- chain = [self.pro_provider]
556
- if self.enable_provider_fallback and self.gpu_provider not in chain:
557
- chain.append(self.gpu_provider)
558
- if self.enable_provider_fallback and self.provider not in chain:
559
- chain.append(self.provider)
560
- return chain
561
-
562
- if normalized in self.gpu_required_tasks:
563
- chain = [self.gpu_provider]
564
- if self.enable_provider_fallback and self.cpu_provider != self.gpu_provider:
565
- chain.append(self.cpu_provider)
566
- return chain
567
-
568
- chain = [self.provider]
569
- if self.enable_provider_fallback and self.cpu_provider not in chain:
570
- chain.append(self.cpu_provider)
571
- return chain
572
-
573
  def _retry_profile(self, task_type: str) -> Tuple[int, float]:
574
  normalized = (task_type or "default").strip().lower()
575
  if normalized in self.interactive_tasks:
@@ -586,23 +674,6 @@ class InferenceClient:
586
  return self.interactive_timeout_sec
587
  return self.background_timeout_sec
588
 
589
- def _resolve_route_label(self, provider: str, task_type: str) -> str:
590
- normalized = (task_type or "default").strip().lower()
591
- if self.pro_enabled and normalized in self.pro_priority_tasks and provider == self.pro_provider:
592
- return "pro-priority"
593
- return "standard"
594
-
595
- def _generate_with_provider(self, req: InferenceRequest, provider: str, fallback_depth: int) -> str:
596
- route = self._resolve_route_label(provider, req.task_type)
597
- if provider == "local_space":
598
- return self._call_local_space(req, provider=provider, route=route, fallback_depth=fallback_depth)
599
-
600
- if provider == "deepseek":
601
- return self._call_deepseek(req, provider=provider, route=route, fallback_depth=fallback_depth)
602
-
603
- # All other providers use HF inference router
604
- return self._call_hf_inference(req, provider=provider, route=route, fallback_depth=fallback_depth)
605
-
606
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
607
  parts: List[str] = []
608
  for msg in messages:
@@ -615,9 +686,9 @@ class InferenceClient:
615
  prefix = "SYSTEM"
616
  elif role == "assistant":
617
  prefix = "ASSISTANT"
618
- parts.append(f"{prefix}:\\n{content}")
619
  parts.append("ASSISTANT:")
620
- return "\\n\\n".join(parts)
621
 
622
  def _latest_user_message(self, messages: List[Dict[str, str]]) -> str:
623
  for msg in reversed(messages):
@@ -627,160 +698,223 @@ class InferenceClient:
627
  return content
628
  return self._messages_to_prompt(messages)
629
 
630
- def _post_with_retry(
631
- self,
632
- url: str,
633
- *,
634
- headers: Dict[str, str],
635
- payload: Dict[str, object],
636
- timeout: int,
637
- provider: str,
638
- model: str,
639
- task_type: str,
640
- request_tag: str,
641
- fallback_depth: int,
642
- route: str,
643
- ) -> Tuple[requests.Response, float, int]:
644
- self._record_attempt(
645
- task_type=task_type,
646
- provider=provider,
647
- route=route,
648
- fallback_depth=fallback_depth,
649
  )
 
 
650
  max_retries, backoff_sec = self._retry_profile(task_type)
651
- attempt = 0
652
 
653
- def _retry_sleep(retry_attempt: int) -> None:
654
- # Small jitter reduces synchronized retry storms during transient provider issues.
655
- jitter_factor = random.uniform(0.9, 1.2)
656
- time.sleep(backoff_sec * retry_attempt * jitter_factor)
657
 
658
- while True:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  start = time.perf_counter()
660
  try:
661
- resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
662
- except Exception as exc:
663
  latency_ms = (time.perf_counter() - start) * 1000
 
 
 
 
 
 
 
 
664
  log_model_call(
665
  LOGGER,
666
- provider=provider,
667
- model=model,
668
- endpoint=url,
669
  latency_ms=latency_ms,
670
  input_tokens=None,
671
  output_tokens=None,
672
- status="error",
673
- error_class=exc.__class__.__name__,
674
- error_message=str(exc),
675
  task_type=task_type,
676
- request_tag=request_tag,
677
  retry_attempt=attempt + 1,
678
  fallback_depth=fallback_depth,
679
  route=route,
680
  )
681
- if attempt >= max_retries - 1:
682
- self._bump_metric("requests_error", 1)
683
- raise
684
- attempt += 1
685
- self._bump_metric("retries_total", 1)
686
- _retry_sleep(attempt)
687
- continue
 
 
688
 
689
- latency_ms = (time.perf_counter() - start) * 1000
690
- if resp.status_code in {408, 429, 500, 502, 503, 504} and attempt < max_retries - 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
  log_model_call(
692
  LOGGER,
693
- provider=provider,
694
- model=model,
695
- endpoint=url,
696
  latency_ms=latency_ms,
697
  input_tokens=None,
698
  output_tokens=None,
699
  status="error",
700
- error_class="HTTPRetry",
701
- error_message=f"status={resp.status_code}",
702
  task_type=task_type,
703
- request_tag=request_tag,
704
  retry_attempt=attempt + 1,
705
  fallback_depth=fallback_depth,
706
  route=route,
707
  )
708
- attempt += 1
709
- self._bump_metric("retries_total", 1)
710
- _retry_sleep(attempt)
711
- continue
712
- return resp, latency_ms, attempt + 1
713
 
714
- def _call_hf_inference_direct(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
715
- """
716
- Call Qwen models via Featherless AI provider.
717
- Uses HF InferenceClient with provider="featherless-ai" for direct model access.
718
- """
719
- if not self.hf_token:
720
- raise RuntimeError("HF_TOKEN is not set")
721
 
 
722
  target_model = req.model or self.default_model
723
- target_model_base = target_model.split(":")[0] if ":" in target_model else target_model
724
-
 
 
 
 
 
 
 
 
 
 
 
 
725
  timeout = self._timeout_for(req, provider)
 
 
 
 
 
 
 
726
  start = time.perf_counter()
727
-
728
  try:
729
- # Use HF InferenceClient with featherless-ai provider for Qwen models.
730
- client = HFInferenceClient(
731
- model=target_model_base,
732
- token=self.hf_token,
733
- provider="featherless-ai",
734
- timeout=timeout
735
- )
736
-
737
- response = client.chat_completion(
738
- messages=req.messages,
739
- max_tokens=req.max_new_tokens or self.default_max_new_tokens,
740
- temperature=req.temperature or self.default_temperature,
741
- top_p=req.top_p or self.default_top_p,
742
- )
743
- latency_ms = (time.perf_counter() - start) * 1000
744
-
745
- # Extract text from response
746
- if hasattr(response, "choices") and response.choices:
747
- content = response.choices[0].message.content or ""
748
- text = content.strip()
749
- else:
750
- text = self._extract_text(response)
751
-
752
- log_model_call(
753
- LOGGER,
754
- provider="featherless-ai",
755
- model=target_model_base,
756
- endpoint="featherless-ai_inference",
757
- latency_ms=latency_ms,
758
- input_tokens=None,
759
- output_tokens=None,
760
- status="ok",
761
- task_type=req.task_type,
762
- request_tag=req.request_tag,
763
- retry_attempt=1,
764
- fallback_depth=fallback_depth,
765
- route=route,
766
- )
767
- self._record_attempt(
768
- task_type=req.task_type,
769
- provider="featherless-ai",
770
- route=route,
771
- fallback_depth=fallback_depth,
772
- )
773
- self._bump_metric("requests_ok", 1)
774
- return text
775
-
776
  except Exception as exc:
777
  latency_ms = (time.perf_counter() - start) * 1000
778
- self._bump_metric("requests_error", 1)
779
  log_model_call(
780
  LOGGER,
781
- provider="featherless-ai",
782
- model=target_model_base,
783
- endpoint="featherless-ai_inference",
784
  latency_ms=latency_ms,
785
  input_tokens=None,
786
  output_tokens=None,
@@ -793,255 +927,10 @@ class InferenceClient:
793
  fallback_depth=fallback_depth,
794
  route=route,
795
  )
796
- LOGGER.warning(
797
- "task=%s provider=featherless-ai model=%s fallback_depth=%s failed: %s",
798
- req.task_type,
799
- target_model_base,
800
- fallback_depth,
801
- exc,
802
- )
803
- raise
804
-
805
- def _call_hf_inference(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
806
- if not self.hf_token:
807
- raise RuntimeError("HF_TOKEN is not set")
808
-
809
- target_model = req.model or self.default_model
810
- chat_model = target_model if ":" in target_model else f"{target_model}:fastest"
811
- url = self.hf_chat_url
812
-
813
- # Log which model is actually being used
814
- model_base = target_model.split(":")[0] if ":" in target_model else target_model
815
- LOGGER.debug(
816
- f"๐Ÿ“Œ Calling HF inference: task={req.task_type} model={model_base} "
817
- f"route={route} depth={fallback_depth}"
818
- )
819
-
820
- payload: Dict[str, object] = {
821
- "model": chat_model,
822
- "messages": req.messages,
823
- "stream": False,
824
- "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
825
- "temperature": req.temperature,
826
- "top_p": req.top_p,
827
- }
828
- headers = {
829
- "Authorization": f"Bearer {self.hf_token}",
830
- "Content-Type": "application/json",
831
- "X-MathPulse-Task": (req.task_type or "default").strip().lower(),
832
- }
833
- if route == "pro-priority" and self.pro_route_header_name.strip():
834
- headers[self.pro_route_header_name.strip()] = self.pro_route_header_value
835
-
836
- timeout = self._timeout_for(req, provider)
837
-
838
- resp, latency_ms, retry_attempt = self._post_with_retry(
839
- url,
840
- headers=headers,
841
- payload=payload,
842
- timeout=timeout,
843
- provider=provider,
844
- model=target_model,
845
- task_type=req.task_type,
846
- request_tag=req.request_tag,
847
- fallback_depth=fallback_depth,
848
- route=route,
849
- )
850
- self._bump_bucket("status_code_counts", str(resp.status_code), 1)
851
- if resp.status_code != 200:
852
- self._bump_metric("requests_error", 1)
853
- raise RuntimeError(f"HF Inference error {resp.status_code}: {resp.text}")
854
-
855
- data = resp.json()
856
- text = self._extract_text(data)
857
-
858
- # Log successful inference with actual model and response time
859
- LOGGER.info(
860
- f"โœ… HF inference success: task={req.task_type} model={model_base} "
861
- f"latency={latency_ms:.0f}ms tokens_out={len(text.split())}"
862
- )
863
-
864
- log_model_call(
865
- LOGGER,
866
- provider=provider,
867
- model=target_model,
868
- endpoint=url,
869
- latency_ms=latency_ms,
870
- input_tokens=None,
871
- output_tokens=None,
872
- status="ok",
873
- task_type=req.task_type,
874
- request_tag=req.request_tag,
875
- retry_attempt=retry_attempt,
876
- fallback_depth=fallback_depth,
877
- route=route,
878
- )
879
- self._bump_metric("requests_ok", 1)
880
- return text
881
-
882
- def _call_featherless(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
883
- if not self.featherless_api_key:
884
- raise RuntimeError("FEATHERLESS_API_KEY is not set")
885
-
886
- target_model = req.model or self.default_model
887
- url = self.featherless_chat_url
888
-
889
- payload: Dict[str, object] = {
890
- "model": target_model,
891
- "messages": req.messages,
892
- "stream": False,
893
- "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
894
- "temperature": req.temperature,
895
- "top_p": req.top_p,
896
- }
897
- headers = {
898
- "Authorization": f"Bearer {self.featherless_api_key}",
899
- "Content-Type": "application/json",
900
- "X-MathPulse-Task": (req.task_type or "default").strip().lower(),
901
- }
902
-
903
- timeout = self._timeout_for(req, provider)
904
-
905
- resp, latency_ms, retry_attempt = self._post_with_retry(
906
- url,
907
- headers=headers,
908
- payload=payload,
909
- timeout=timeout,
910
- provider=provider,
911
- model=target_model,
912
- task_type=req.task_type,
913
- request_tag=req.request_tag,
914
- fallback_depth=fallback_depth,
915
- route=route,
916
- )
917
- self._bump_bucket("status_code_counts", str(resp.status_code), 1)
918
- if resp.status_code != 200:
919
- self._bump_metric("requests_error", 1)
920
- raise RuntimeError(f"Featherless API error {resp.status_code}: {resp.text}")
921
-
922
- data = resp.json()
923
- text = self._extract_text(data)
924
- log_model_call(
925
- LOGGER,
926
- provider=provider,
927
- model=target_model,
928
- endpoint=url,
929
- latency_ms=latency_ms,
930
- input_tokens=None,
931
- output_tokens=None,
932
- status="ok",
933
- task_type=req.task_type,
934
- request_tag=req.request_tag,
935
- retry_attempt=retry_attempt,
936
- fallback_depth=fallback_depth,
937
- route=route,
938
- )
939
- self._bump_metric("requests_ok", 1)
940
- return text
941
-
942
- def _call_deepseek(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
943
- """Call DeepSeek API (OpenAI-compatible endpoint)."""
944
- if not self.deepseek_api_key:
945
- raise RuntimeError("DEEPSEEK_API_KEY is not set")
946
-
947
- target_model = req.model or self.default_model
948
- url = self.deepseek_chat_url
949
-
950
- model_base = target_model.split(":")[0] if ":" in target_model else target_model
951
- LOGGER.debug(
952
- f"๐Ÿ“Œ Calling DeepSeek: task={req.task_type} model={model_base} "
953
- f"route={route} depth={fallback_depth}"
954
- )
955
-
956
- payload: Dict[str, object] = {
957
- "model": target_model,
958
- "messages": req.messages,
959
- "stream": False,
960
- "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
961
- "temperature": req.temperature,
962
- "top_p": req.top_p,
963
- }
964
- headers = {
965
- "Authorization": f"Bearer {self.deepseek_api_key}",
966
- "Content-Type": "application/json",
967
- "X-MathPulse-Task": (req.task_type or "default").strip().lower(),
968
- }
969
-
970
- timeout = self._timeout_for(req, provider)
971
-
972
- resp, latency_ms, retry_attempt = self._post_with_retry(
973
- url,
974
- headers=headers,
975
- payload=payload,
976
- timeout=timeout,
977
- provider=provider,
978
- model=target_model,
979
- task_type=req.task_type,
980
- request_tag=req.request_tag,
981
- fallback_depth=fallback_depth,
982
- route=route,
983
- )
984
- self._bump_bucket("status_code_counts", str(resp.status_code), 1)
985
- if resp.status_code != 200:
986
  self._bump_metric("requests_error", 1)
987
- raise RuntimeError(f"DeepSeek API error {resp.status_code}: {resp.text}")
988
-
989
- data = resp.json()
990
- text = self._extract_text(data)
991
-
992
- LOGGER.info(
993
- f"โœ… DeepSeek success: task={req.task_type} model={model_base} "
994
- f"latency={latency_ms:.0f}ms tokens_out={len(text.split())}"
995
- )
996
-
997
- log_model_call(
998
- LOGGER,
999
- provider=provider,
1000
- model=target_model,
1001
- endpoint=url,
1002
- latency_ms=latency_ms,
1003
- input_tokens=None,
1004
- output_tokens=None,
1005
- status="ok",
1006
- task_type=req.task_type,
1007
- request_tag=req.request_tag,
1008
- retry_attempt=retry_attempt,
1009
- fallback_depth=fallback_depth,
1010
- route=route,
1011
- )
1012
- self._bump_metric("requests_ok", 1)
1013
- return text
1014
-
1015
- def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
1016
- target_model = req.model or self.default_model
1017
- url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"
1018
-
1019
- prompt = self._messages_to_prompt(req.messages)
1020
- payload: Dict[str, object] = {
1021
- "data": [
1022
- prompt,
1023
- [],
1024
- req.temperature,
1025
- req.top_p,
1026
- req.max_new_tokens,
1027
- ]
1028
- }
1029
- headers = {"Content-Type": "application/json"}
1030
-
1031
- timeout = self._timeout_for(req, provider)
1032
 
1033
- resp, latency_ms, retry_attempt = self._post_with_retry(
1034
- url,
1035
- headers=headers,
1036
- payload=payload,
1037
- timeout=timeout,
1038
- provider=provider,
1039
- model=target_model,
1040
- task_type=req.task_type,
1041
- request_tag=req.request_tag,
1042
- fallback_depth=fallback_depth,
1043
- route=route,
1044
- )
1045
  self._bump_bucket("status_code_counts", str(resp.status_code), 1)
1046
 
1047
  if resp.status_code != 200:
@@ -1080,7 +969,7 @@ class InferenceClient:
1080
  status="ok",
1081
  task_type=req.task_type,
1082
  request_tag=req.request_tag,
1083
- retry_attempt=retry_attempt,
1084
  fallback_depth=fallback_depth,
1085
  route=route,
1086
  )
@@ -1121,32 +1010,39 @@ class InferenceClient:
1121
 
1122
  def _clean_response_text(self, text: str) -> str:
1123
  """Strip JSON braces, template artifacts, and whitespace from response text."""
1124
- # Strip leading/trailing whitespace
1125
  text = text.strip()
1126
-
1127
- # Remove wrapping JSON braces or artifact markers
1128
  if text.startswith("{") and text.endswith("}"):
1129
  try:
1130
- # Try to parse as JSON - if it fails, return as-is
1131
  parsed = json.loads(text)
1132
- # If it's a dict with a "content" or "text" field, use that
1133
  if isinstance(parsed, dict):
1134
  if "content" in parsed:
1135
  text = str(parsed["content"]).strip()
1136
  elif "text" in parsed:
1137
  text = str(parsed["text"]).strip()
1138
  except json.JSONDecodeError:
1139
- # Not valid JSON, just clean up braces
1140
  text = text.strip("{}")
1141
-
1142
- # Remove any trailing artifact markers
1143
  if text.startswith("```json") or text.startswith("```"):
1144
  text = re.sub(r"^```(?:json)?", "", text).strip()
1145
  if text.endswith("```"):
1146
  text = text[:-3].strip()
1147
-
1148
  return text.strip()
1149
 
1150
 
1151
- def create_default_client() -> InferenceClient:
1152
- return InferenceClient()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  import requests
12
  import yaml
13
+ from openai import OpenAI, APIError, RateLimitError, APITimeoutError
14
 
15
+ from .ai_client import get_deepseek_client, CHAT_MODEL, REASONER_MODEL, DEEPSEEK_BASE_URL
16
  from .logging_utils import configure_structured_logging, log_model_call
17
 
18
  LOGGER = configure_structured_logging("mathpulse.inference")
19
  TEMP_CHAT_MODEL_OVERRIDE_ENV = "INFERENCE_CHAT_MODEL_TEMP_OVERRIDE"
20
 
21
+ # ── Model Profiles ────────────────────────────────────────────────────────────
22
+ # A profile sets multiple env defaults in one shot.
23
+ # Individual env vars (DEEPSEEK_MODEL, DEEPSEEK_REASONER_MODEL, etc.) still override.
24
+ # Usage: MODEL_PROFILE=dev or MODEL_PROFILE=prod or MODEL_PROFILE=budget
25
+ # Profiles can also be applied at runtime via the admin panel without restart.
26
+
27
+ _MODEL_PROFILES: dict[str, dict[str, str]] = {
28
+ "dev": {
29
+ "INFERENCE_MODEL_ID": CHAT_MODEL,
30
+ "INFERENCE_CHAT_MODEL_ID": CHAT_MODEL,
31
+ "HF_QUIZ_MODEL_ID": CHAT_MODEL,
32
+ "HF_RAG_MODEL_ID": CHAT_MODEL,
33
+ "INFERENCE_LOCK_MODEL_ID": CHAT_MODEL,
34
+ },
35
+ "prod": {
36
+ "INFERENCE_MODEL_ID": CHAT_MODEL,
37
+ "INFERENCE_CHAT_MODEL_ID": CHAT_MODEL,
38
+ "HF_QUIZ_MODEL_ID": CHAT_MODEL,
39
+ "HF_RAG_MODEL_ID": REASONER_MODEL,
40
+ "INFERENCE_LOCK_MODEL_ID": CHAT_MODEL,
41
+ },
42
+ "budget": {
43
+ "INFERENCE_MODEL_ID": CHAT_MODEL,
44
+ "INFERENCE_CHAT_MODEL_ID": CHAT_MODEL,
45
+ "HF_QUIZ_MODEL_ID": CHAT_MODEL,
46
+ "HF_RAG_MODEL_ID": CHAT_MODEL,
47
+ "INFERENCE_LOCK_MODEL_ID": CHAT_MODEL,
48
+ },
49
+ }
50
+
51
+ # ── Runtime Override Store ────────────────────────────────────────────────────
52
+ # Mutated at runtime by the admin panel via /api/admin/model-config.
53
+ # Priority: above env vars, below INFERENCE_ENFORCE_LOCK_MODEL.
54
+ # Persisted to Firestore so backend cold-restarts restore the last admin-set config.
55
+
56
+ _RUNTIME_OVERRIDES: dict[str, str] = {}
57
+ _RUNTIME_PROFILE: str = ""
58
+
59
+ _FS_COLLECTION = "system_config"
60
+ _FS_DOC = "active_model_config"
61
+
62
+
63
def _save_runtime_config_to_firestore() -> None:
    """Write the active runtime profile and overrides to Firestore.

    Best effort: every failure is logged as a warning and swallowed, so a
    persistence problem never propagates to the caller.
    """
    try:
        # Imported inside the try so an ImportError is handled like any
        # other persistence failure.
        from firebase_admin import firestore as fs

        document = fs.client().collection(_FS_COLLECTION).document(_FS_DOC)
        payload = {
            "profile": _RUNTIME_PROFILE,
            "overrides": _RUNTIME_OVERRIDES,
            "updatedAt": fs.SERVER_TIMESTAMP,
        }
        document.set(payload)
    except Exception as e:
        LOGGER.warning("Could not persist model config to Firestore: %s", e)
77
+
78
+
79
def _load_runtime_config_from_firestore() -> None:
    """Restore the persisted runtime profile and overrides from Firestore.

    If the stored document names a known profile, that profile becomes the
    runtime profile and its values replace the current overrides. Stored
    per-key overrides are then layered on top. Nothing happens when the
    document does not exist.

    Best effort: a missing ``firebase_admin`` package is logged at debug
    level, any other failure as a warning; neither is raised.
    """
    # Declared up front rather than inside a conditional branch.
    global _RUNTIME_PROFILE
    try:
        from firebase_admin import firestore as fs

        doc = fs.client().collection(_FS_COLLECTION).document(_FS_DOC).get()
        if not doc.exists:
            return
        data = doc.to_dict() or {}
        # "or ''" keeps a stored null from turning into the string "none".
        profile = str(data.get("profile") or "").strip().lower()
        overrides = data.get("overrides", {})
        if profile in _MODEL_PROFILES:
            _RUNTIME_PROFILE = profile
            _RUNTIME_OVERRIDES.clear()
            _RUNTIME_OVERRIDES.update(_MODEL_PROFILES[profile])
        if isinstance(overrides, dict):
            for key, value in overrides.items():
                if value is None:
                    # A null would otherwise be stored as the model id "None".
                    continue
                _RUNTIME_OVERRIDES[str(key)] = str(value)
        LOGGER.info("Restored runtime model config from Firestore: profile=%s", profile)
    except ImportError:
        LOGGER.debug("Firebase not available (optional for DeepSeek-only)")
    except Exception as e:
        LOGGER.warning("Could not restore model config from Firestore: %s", e)
103
+
104
+
105
def _apply_model_profile() -> None:
    """Seed model-related environment variables from ``MODEL_PROFILE``.

    Does nothing when ``MODEL_PROFILE`` is unset or blank, and only warns when
    it names an unknown profile. Environment variables that already hold a
    non-empty value are left untouched.
    """
    requested = os.getenv("MODEL_PROFILE", "").strip().lower()
    if not requested:
        return

    defaults = _MODEL_PROFILES.get(requested)
    if defaults is None:
        LOGGER.warning("MODEL_PROFILE='%s' is not a known profile.", requested)
    else:
        for env_key, model_id in defaults.items():
            # An empty string counts as "not set", so setdefault() would not do.
            if not os.environ.get(env_key):
                os.environ[env_key] = model_id
        LOGGER.info("Startup model profile applied: %s", requested)
117
+
118
+
119
+ _apply_model_profile()
120
+ _load_runtime_config_from_firestore()
121
+
122
+
123
def set_runtime_model_profile(profile_name: str) -> None:
    """Switch to a named model profile at runtime, without a restart.

    The profile's values replace all current runtime overrides, and the new
    state is then persisted to Firestore.

    Raises:
        ValueError: If ``profile_name`` (trimmed, case-insensitive) is not a
            known profile.
    """
    global _RUNTIME_PROFILE
    wanted = profile_name.strip().lower()
    settings = _MODEL_PROFILES.get(wanted)
    if not settings:
        raise ValueError(
            f"Unknown profile: '{profile_name}'. Valid values: {list(_MODEL_PROFILES.keys())}"
        )

    # The overrides dict is mutated in place, so it needs no global declaration.
    _RUNTIME_OVERRIDES.clear()
    _RUNTIME_OVERRIDES.update(settings)
    _RUNTIME_PROFILE = wanted
    LOGGER.info("Runtime model profile switched to: %s", profile_name)
    _save_runtime_config_to_firestore()
137
+
138
+
139
def set_runtime_model_override(key: str, value: str) -> None:
    """Set one runtime override and persist the resulting configuration.

    Args:
        key: Name of the model setting to override (e.g. ``INFERENCE_MODEL_ID``).
        value: Model id to use for that setting.
    """
    _RUNTIME_OVERRIDES.update({key: value})
    LOGGER.info("Runtime model override set: %s = %s", key, value)
    _save_runtime_config_to_firestore()
144
+
145
+
146
def reset_runtime_overrides() -> None:
    """Drop every runtime override, clear the active profile, and persist that."""
    global _RUNTIME_PROFILE
    _RUNTIME_PROFILE = ""
    _RUNTIME_OVERRIDES.clear()
    LOGGER.info("Runtime model overrides cleared.")
    _save_runtime_config_to_firestore()
153
+
154
+
155
def get_current_runtime_config() -> dict:
    """Return a snapshot of the runtime model configuration.

    Returns:
        A dict with three entries:
        ``"profile"`` is the active runtime profile name ("" when none),
        ``"overrides"`` is a copy of the runtime overrides, and
        ``"resolved"`` is the effective value of each tracked model key as
        computed by ``_resolve_key``.
    """
    # A tuple (not a set) so the keys of "resolved" always come out in the
    # same order, independent of string hash randomisation.
    tracked_keys = (
        "INFERENCE_MODEL_ID",
        "INFERENCE_CHAT_MODEL_ID",
        "HF_QUIZ_MODEL_ID",
        "HF_RAG_MODEL_ID",
        "INFERENCE_LOCK_MODEL_ID",
    )
    return {
        "profile": _RUNTIME_PROFILE,
        "overrides": dict(_RUNTIME_OVERRIDES),
        "resolved": {key: _resolve_key(key) for key in tracked_keys},
    }
167
+
168
+
169
def _resolve_key(key: str) -> str:
    """Resolve a model setting to its effective value.

    Lookup order: runtime override, then the active runtime profile, then the
    environment variable of the same name. Empty values are skipped at each
    step; the result is "" when nothing is set.
    """
    override = _RUNTIME_OVERRIDES.get(key)
    if override:
        return override

    if _RUNTIME_PROFILE and _RUNTIME_PROFILE in _MODEL_PROFILES:
        from_profile = _MODEL_PROFILES[_RUNTIME_PROFILE].get(key)
        if from_profile:
            return from_profile

    return os.getenv(key, "")
176
+
177
+
178
def get_model_for_task(task_type: str) -> str:
    """Pick the model id to use for ``task_type``.

    When ``INFERENCE_ENFORCE_LOCK_MODEL`` is truthy (the default), every task
    gets the lock model: the runtime override of ``INFERENCE_LOCK_MODEL_ID``,
    else that environment variable, else ``CHAT_MODEL``. Otherwise the task's
    own setting is resolved, falling back to ``INFERENCE_MODEL_ID`` and finally
    ``CHAT_MODEL``.
    """
    task = (task_type or "default").strip().lower()

    lock_flag = os.getenv("INFERENCE_ENFORCE_LOCK_MODEL", "true").strip().lower()
    if lock_flag in {"1", "true", "yes", "on"}:
        locked_model = _RUNTIME_OVERRIDES.get("INFERENCE_LOCK_MODEL_ID")
        if not locked_model:
            locked_model = os.getenv("INFERENCE_LOCK_MODEL_ID")
        return locked_model or CHAT_MODEL

    # All RAG-flavoured tasks share one setting.
    task_key_map = {
        "chat": "INFERENCE_CHAT_MODEL_ID",
        "quiz_generation": "HF_QUIZ_MODEL_ID",
        **dict.fromkeys(
            ("rag_lesson", "rag_problem", "rag_analysis_context"),
            "HF_RAG_MODEL_ID",
        ),
    }
    env_key = task_key_map.get(task)
    if env_key:
        task_model = _resolve_key(env_key)
        if task_model:
            return task_model
    return _resolve_key("INFERENCE_MODEL_ID") or CHAT_MODEL
199
+
200
+
201
def model_supports_thinking(model_id: str = "") -> bool:
    """Report whether ``model_id`` is the reasoner model.

    Args:
        model_id: Model to check. When empty, the effective
            ``INFERENCE_MODEL_ID`` is used instead.

    Returns:
        True only when the (whitespace-trimmed) id equals ``REASONER_MODEL``.
    """
    # Resolve through _resolve_key so a model switched at runtime is honoured:
    # runtime overrides rank above environment variables, and _resolve_key
    # still falls back to the environment when no override is set.
    mid = (model_id or _resolve_key("INFERENCE_MODEL_ID") or "").strip()
    return mid == REASONER_MODEL
204
+
205
 
206
  def _normalize_local_space_url(raw_url: str) -> str:
207
  """Accept either hf.space host or huggingface.co/spaces URL for local_space provider."""
 
209
  if not cleaned:
210
  return "http://127.0.0.1:7860"
211
 
 
 
212
  match = re.match(r"^https?://huggingface\.co/spaces/([^/]+)/([^/]+)$", cleaned, re.IGNORECASE)
213
  if match:
214
  owner = match.group(1).strip().lower()
 
224
  model: Optional[str] = None
225
  task_type: str = "default"
226
  request_tag: str = ""
227
+ max_new_tokens: int = 900
228
  temperature: float = 0.2
229
  top_p: float = 0.9
230
  repetition_penalty: float = 1.15
231
  timeout_sec: Optional[int] = None
232
+ enable_thinking: bool = False
233
 
234
 
235
  class InferenceClient:
236
+ def __init__(self, firestore_client: Optional[Any] = None) -> None:
237
+ self.firestore = firestore_client
238
+ self._last_persist_time = 0.0
239
+ self._persist_throttle_sec = 30.0
240
+
241
  config_paths = [
242
+ Path("./config/models.yaml"),
243
+ Path("/config/models.yaml"),
244
+ Path("/app/config/models.yaml"),
245
+ Path.cwd() / "config" / "models.yaml",
246
+ Path(__file__).resolve().parents[2] / "config" / "models.yaml",
247
  ]
248
+
249
  config: Dict[str, object] = {}
250
  config_path = None
251
+
252
  for path in config_paths:
253
  if path.exists():
254
  config_path = path
 
256
  config = yaml.safe_load(fh) or {}
257
  LOGGER.info(f"โœ… Loaded config from {config_path}")
258
  break
259
+
260
  if not config_path:
261
  LOGGER.warning(f"โš ๏ธ Config file not found. Checked: {[str(p) for p in config_paths]}")
262
  LOGGER.warning(f" CWD: {Path.cwd()}")
 
270
  if isinstance(primary_cfg, dict):
271
  primary = primary_cfg
272
 
273
+ self.provider = "deepseek"
274
+ self.ds_api_key = os.getenv("DEEPSEEK_API_KEY", "")
275
+ self.ds_base_url = os.getenv("DEEPSEEK_BASE_URL", DEEPSEEK_BASE_URL)
276
+ self.ds_chat_model = os.getenv("DEEPSEEK_MODEL", CHAT_MODEL)
277
+ self.ds_reasoner_model = os.getenv("DEEPSEEK_REASONER_MODEL", REASONER_MODEL)
278
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  self.local_space_url = _normalize_local_space_url(
280
  os.getenv("INFERENCE_LOCAL_SPACE_URL", "http://127.0.0.1:7860")
281
  )
282
  self.local_generate_path = os.getenv("INFERENCE_LOCAL_SPACE_GENERATE_PATH", "/gradio_api/call/generate")
 
 
283
 
284
+ self.enforce_lock_model = os.getenv("INFERENCE_ENFORCE_LOCK_MODEL", "true").strip().lower() in {"1", "true", "yes", "on"}
285
+ self.lock_model_id = os.getenv("INFERENCE_LOCK_MODEL_ID", CHAT_MODEL).strip() or CHAT_MODEL
286
 
287
+ default_model_fallback = str(primary.get("id") or CHAT_MODEL)
288
  env_model_id = os.getenv("INFERENCE_MODEL_ID", "").strip()
289
  self.default_model = env_model_id or default_model_fallback
290
+
291
  default_max_tokens = str(primary.get("max_new_tokens") or 512)
292
  self.default_max_new_tokens = int(os.getenv("INFERENCE_MAX_NEW_TOKENS", default_max_tokens))
293
+
294
  default_temp = str(primary.get("temperature") or 0.2)
295
  self.default_temperature = float(os.getenv("INFERENCE_TEMPERATURE", default_temp))
296
+
297
  default_top_p = str(primary.get("top_p") or 0.9)
298
  self.default_top_p = float(os.getenv("INFERENCE_TOP_P", default_top_p))
299
+
 
300
  self.chat_model_override = os.getenv("INFERENCE_CHAT_MODEL_ID", "").strip()
301
  self.chat_model_temp_override = os.getenv(TEMP_CHAT_MODEL_OVERRIDE_ENV, "").strip()
302
  self.chat_strict_model_only = os.getenv("INFERENCE_CHAT_STRICT_MODEL_ONLY", "true").strip().lower() in {"1", "true", "yes", "on"}
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ self.ds_timeout_sec = int(os.getenv("INFERENCE_HF_TIMEOUT_SEC", "90"))
305
  self.local_timeout_sec = int(os.getenv("INFERENCE_LOCAL_SPACE_TIMEOUT_SEC", "90"))
306
  self.max_retries = int(os.getenv("INFERENCE_MAX_RETRIES", "3"))
307
  self.backoff_sec = float(os.getenv("INFERENCE_BACKOFF_SEC", "1.5"))
308
+ self.interactive_timeout_sec = int(os.getenv("INFERENCE_INTERACTIVE_TIMEOUT_SEC", str(self.ds_timeout_sec)))
309
+ self.background_timeout_sec = int(os.getenv("INFERENCE_BACKGROUND_TIMEOUT_SEC", str(self.ds_timeout_sec)))
310
  self.interactive_max_retries = int(os.getenv("INFERENCE_INTERACTIVE_MAX_RETRIES", str(self.max_retries)))
311
  self.background_max_retries = int(os.getenv("INFERENCE_BACKGROUND_MAX_RETRIES", str(self.max_retries)))
312
  self.interactive_backoff_sec = float(os.getenv("INFERENCE_INTERACTIVE_BACKOFF_SEC", str(self.backoff_sec)))
 
327
  )
328
  self.cpu_only_tasks = {v.strip().lower() for v in cpu_tasks_raw.split(",") if v.strip()}
329
 
 
 
 
 
 
 
330
  interactive_tasks_raw = os.getenv(
331
  "INFERENCE_INTERACTIVE_TASKS",
332
  "chat,verify_solution,daily_insight",
 
338
  )
339
 
340
  # Default task-to-model routing.
 
341
  self.task_model_map: Dict[str, str] = {
342
+ "chat": CHAT_MODEL,
343
+ "verify_solution": CHAT_MODEL,
344
+ "lesson_generation": CHAT_MODEL,
345
+ "quiz_generation": CHAT_MODEL,
346
+ "learning_path": CHAT_MODEL,
347
+ "daily_insight": CHAT_MODEL,
348
+ "risk_classification": CHAT_MODEL,
349
+ "risk_narrative": CHAT_MODEL,
350
  }
 
351
  self.task_fallback_model_map: Dict[str, List[str]] = {
352
+ "chat": [CHAT_MODEL],
353
+ "verify_solution": [CHAT_MODEL],
 
 
 
 
 
 
354
  }
 
355
  self.model_provider_map: Dict[str, str] = {}
356
  self.task_provider_map: Dict[str, str] = {}
357
  if isinstance(config, dict):
 
364
  for task, model in task_models.items()
365
  if str(task).strip() and str(model).strip()
366
  }
 
367
  self.task_model_map.update(config_task_models)
368
  task_fallback_models = routing_cfg.get("task_fallback_model_map", {})
369
  if isinstance(task_fallback_models, dict):
 
404
  else:
405
  env_override_note = ""
406
 
407
+ if self.enforce_lock_model:
408
+ lock_map_before = dict(self.task_model_map)
409
+ self.default_model = self.lock_model_id
410
  for task_key in list(self.task_model_map.keys()):
411
+ self.task_model_map[task_key] = self.lock_model_id
412
  self.fallback_models = []
413
  self.task_fallback_model_map = {
414
  task_key: [] for task_key in self.task_model_map.keys()
415
  }
416
+ LOGGER.info(f"๐Ÿ”’ INFERENCE_ENFORCE_LOCK_MODEL enabled: locking all inference tasks to {self.lock_model_id}")
417
+ LOGGER.info(f" Cleared fallback models")
418
+ LOGGER.info(f" Task model mappings forced from: {lock_map_before}")
 
419
 
 
420
  config_status = "from file" if config_path else "hardcoded defaults (no config file found)"
421
  effective_chat_model_for_logs = self.chat_model_override or self.task_model_map.get("chat", self.default_model)
422
  LOGGER.info(f"โœ… InferenceClient initialized {config_status}{env_override_note}")
 
424
  LOGGER.info(f" Chat model: {effective_chat_model_for_logs}")
425
  LOGGER.info(f" Chat temp override ({TEMP_CHAT_MODEL_OVERRIDE_ENV}): {self.chat_model_temp_override or 'disabled'}")
426
  LOGGER.info(f" Chat strict model lock: {self.chat_strict_model_only}")
427
+ LOGGER.info(f" Global model lock: {self.enforce_lock_model}")
428
  LOGGER.info(f" Verify solution model: {self.task_model_map.get('verify_solution', self.default_model)}")
429
  LOGGER.info(f" Full task_model_map: {self.task_model_map}")
430
 
 
436
  "requests_error": 0,
437
  "retries_total": 0,
438
  "fallback_attempts": 0,
439
+ "latency_sum_ms": 0.0,
440
+ "latency_count": 0,
441
  "route_counts": {},
442
  "task_counts": {},
443
  "provider_counts": {},
444
  "status_code_counts": {},
445
  }
446
 
447
+ self._load_persistent_metrics()
448
+
449
  def _bump_metric(self, key: str, inc: int = 1) -> None:
450
  with self._metrics_lock:
451
  current = self._metrics.get(key) or 0
452
  if not isinstance(current, int):
453
  current = 0
454
  self._metrics[key] = current + inc
455
+ self._persist_metrics()
456
 
457
  def _bump_bucket(self, key: str, bucket: str, inc: int = 1) -> None:
458
  with self._metrics_lock:
 
464
  if not isinstance(current, int):
465
  current = 0
466
  mapping[bucket] = current + inc
467
+ self._persist_metrics()
468
+
469
+ def _record_completion(self, *, latency_ms: float) -> None:
470
+ with self._metrics_lock:
471
+ self._metrics["latency_sum_ms"] = (self._metrics.get("latency_sum_ms") or 0.0) + latency_ms
472
+ self._metrics["latency_count"] = (self._metrics.get("latency_count") or 0) + 1
473
+ self._persist_metrics()
474
+
475
+ def _load_persistent_metrics(self) -> None:
476
+ if not self.firestore:
477
+ return
478
+ try:
479
+ doc_ref = self.firestore.collection("system_metrics").document("inference_stats")
480
+ doc = doc_ref.get()
481
+ if doc.exists:
482
+ data = doc.to_dict() or {}
483
+ with self._metrics_lock:
484
+ for k, v in data.items():
485
+ if k in self._metrics:
486
+ if isinstance(v, (int, float)):
487
+ self._metrics[k] = v
488
+ elif isinstance(v, dict) and isinstance(self._metrics[k], dict):
489
+ self._metrics[k].update(v)
490
+ LOGGER.info("โœ… Persistent inference metrics loaded from Firestore")
491
+ except Exception as e:
492
+ LOGGER.warning(f"โš ๏ธ Failed to load persistent metrics: {e}")
493
+
494
+ def _persist_metrics(self, force: bool = False) -> None:
495
+ if not self.firestore:
496
+ return
497
+
498
+ now = time.time()
499
+ if not force and (now - self._last_persist_time < self._persist_throttle_sec):
500
+ return
501
+
502
+ try:
503
+ self._last_persist_time = now
504
+ doc_ref = self.firestore.collection("system_metrics").document("inference_stats")
505
+ with self._metrics_lock:
506
+ snapshot = dict(self._metrics)
507
+
508
+ doc_ref.set(snapshot, merge=True)
509
+ except Exception as e:
510
+ LOGGER.warning(f"โš ๏ธ Failed to persist metrics: {e}")
511
 
512
  def _record_attempt(self, *, task_type: str, provider: str, route: str, fallback_depth: int) -> None:
513
  self._bump_metric("requests_total", 1)
 
519
 
520
  def snapshot_metrics(self) -> Dict[str, Any]:
521
  with self._metrics_lock:
522
+ l_sum = self._metrics.get("latency_sum_ms") or 0.0
523
+ l_count = self._metrics.get("latency_count") or 0
524
+ avg_latency = round(l_sum / l_count, 2) if l_count > 0 else 0.0
525
+
526
  snapshot = {
527
  "uptime_sec": round(max(0.0, time.time() - self._metrics_started_at), 2),
528
  "requests_total": self._metrics.get("requests_total") or 0,
 
530
  "requests_error": self._metrics.get("requests_error") or 0,
531
  "retries_total": self._metrics.get("retries_total") or 0,
532
  "fallback_attempts": self._metrics.get("fallback_attempts") or 0,
533
+ "avg_latency_ms": avg_latency,
534
+ "active_model": self.default_model,
535
+ "primary_provider": self.provider,
536
  "route_counts": dict(self._metrics.get("route_counts") or {}),
537
  "task_counts": dict(self._metrics.get("task_counts") or {}),
538
  "provider_counts": dict(self._metrics.get("provider_counts") or {}),
 
544
  effective_task = (req.task_type or "default").strip().lower()
545
  request_tag = req.request_tag.strip() or f"{effective_task}-{int(time.time() * 1000)}"
546
  selected_model, model_selection_source = self._resolve_primary_model(req)
547
+
548
  model_chain = self._model_chain_for_task(effective_task, selected_model)
549
  last_error: Optional[Exception] = None
550
+
551
+ model_base = selected_model
552
+
 
 
 
553
  LOGGER.info(
554
+ f"๐Ÿ“ค request_tag={request_tag} task={effective_task} source={model_selection_source} "
555
+ f"selected_model={model_base} (primary)"
556
  )
557
  LOGGER.info(f" fallback_chain={model_chain[1:] if len(model_chain) > 1 else 'none'}")
558
 
 
559
  for fallback_depth, model_name in enumerate(model_chain):
560
  request_for_model = InferenceRequest(
561
  messages=req.messages,
 
568
  repetition_penalty=req.repetition_penalty,
569
  timeout_sec=req.timeout_sec,
570
  )
571
+
572
+ try:
573
+ result = self._call_deepseek(request_for_model, fallback_depth)
574
+ if fallback_depth > 0:
575
+ LOGGER.info(f"โœ… Fallback succeeded at depth={fallback_depth} model={model_name}")
576
+ return result
577
+ except Exception as exc:
578
+ last_error = exc
579
+ fallback_hint = f" (depth {fallback_depth})" if fallback_depth > 0 else ""
580
+ LOGGER.warning(
581
+ f"โš ๏ธ Attempt failed{fallback_hint}: task={request_for_model.task_type} "
582
+ f"model={model_name} error={exc.__class__.__name__}: {str(exc)[:100]}"
583
+ )
 
584
 
585
  if last_error:
586
  raise last_error
 
593
  effective_task = (req.task_type or "default").strip().lower()
594
  runtime_chat_override = self._runtime_chat_model_override()
595
 
 
 
 
 
596
  if effective_task == "chat" and runtime_chat_override:
597
  selected_model = runtime_chat_override
598
  model_selection_source = "chat_temp_override_env"
 
606
  selected_model = self.task_model_map.get(effective_task, self.default_model)
607
  model_selection_source = "task_map"
608
 
609
+ if self.enforce_lock_model:
610
+ effective_lock_model_id = self.lock_model_id
611
  if effective_task == "chat":
612
+ effective_lock_model_id = runtime_chat_override or self.chat_model_override or self.lock_model_id
613
 
614
+ selected_base = (selected_model or "").split(":", 1)[0].strip()
615
+ lock_base = (effective_lock_model_id or "").split(":", 1)[0].strip()
616
  if selected_base != lock_base:
617
  LOGGER.warning(
618
+ f"โš ๏ธ Model lock replaced requested model {selected_model} with {effective_lock_model_id}"
619
  )
620
+ selected_model = effective_lock_model_id
621
+ model_selection_source = f"{model_selection_source}:model_lock"
622
 
623
  if effective_task == "chat" and self.chat_strict_model_only:
624
  return selected_model, f"{model_selection_source}:chat_strict_model_only"
625
 
 
 
 
 
 
626
  return selected_model, model_selection_source
627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  def _model_chain_for_task(self, task_type: str, selected_model: str) -> List[str]:
629
  normalized = (task_type or "default").strip().lower()
630
  runtime_chat_override = self._runtime_chat_model_override() if normalized == "chat" else ""
631
+ chat_lock_model_id = runtime_chat_override or (self.chat_model_override if normalized == "chat" else "")
632
 
633
+ if self.enforce_lock_model:
634
  if normalized == "chat":
635
+ locked_model = (chat_lock_model_id or self.lock_model_id or "").strip()
636
  else:
637
+ locked_model = (self.lock_model_id or "").strip()
638
  return [locked_model] if locked_model else []
639
 
640
  if normalized == "chat" and self.chat_strict_model_only:
641
+ chat_model = (chat_lock_model_id or selected_model or "").strip()
642
  return [chat_model] if chat_model else []
643
 
644
  per_task_candidates = self.task_fallback_model_map.get(task_type, [])
 
658
  return deduped[:max_models]
659
  return deduped
660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
  def _retry_profile(self, task_type: str) -> Tuple[int, float]:
662
  normalized = (task_type or "default").strip().lower()
663
  if normalized in self.interactive_tasks:
 
674
  return self.interactive_timeout_sec
675
  return self.background_timeout_sec
676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
678
  parts: List[str] = []
679
  for msg in messages:
 
686
  prefix = "SYSTEM"
687
  elif role == "assistant":
688
  prefix = "ASSISTANT"
689
+ parts.append(f"{prefix}:\n{content}")
690
  parts.append("ASSISTANT:")
691
+ return "\n\n".join(parts)
692
 
693
  def _latest_user_message(self, messages: List[Dict[str, str]]) -> str:
694
  for msg in reversed(messages):
 
698
  return content
699
  return self._messages_to_prompt(messages)
700
 
701
def _call_deepseek(self, req: InferenceRequest, fallback_depth: int) -> str:
    """Call the DeepSeek chat-completions endpoint, retrying API errors.

    Args:
        req: Request carrying the messages, sampling settings, optional model
            override, task type and request tag.
        fallback_depth: Depth value forwarded unchanged to attempt recording
            and call logging.

    Returns:
        The stripped completion text. When the response message exposes a
        non-empty ``reasoning_content`` attribute, it is prepended on its own
        line.

    Raises:
        RuntimeError: If ``self.ds_api_key`` is empty, if the API still reports
            a rate limit, timeout or other API error on the last attempt, or
            if the retry profile allows no attempts at all.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    if not self.ds_api_key:
        raise RuntimeError("DEEPSEEK_API_KEY is not set")

    target_model = req.model or self.default_model
    route = "deepseek"
    task_type = req.task_type or "default"

    LOGGER.debug(
        f"๐Ÿ“ž Calling DeepSeek: task={task_type} model={target_model} "
        f"route={route} depth={fallback_depth}"
    )

    timeout = self._timeout_for(req, "deepseek")
    max_retries, backoff_sec = self._retry_profile(task_type)
    client = get_deepseek_client()

    # Build chat completions params
    params: Dict[str, Any] = {
        "model": target_model,
        "messages": req.messages,
        "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
    }
    if target_model == REASONER_MODEL:
        # The reasoner gets its own token default and no sampling parameters.
        params["max_tokens"] = req.max_new_tokens or 1024
    else:
        params["temperature"] = req.temperature
        params["top_p"] = req.top_p
        # Use JSON mode for quiz generation
        if task_type == "quiz_generation":
            params["response_format"] = {"type": "json_object"}

    def _log_attempt(attempt_no: int, latency_ms: float, status: str, **error_fields: Any) -> None:
        # Single place for the structured log record shared by every outcome.
        log_model_call(
            LOGGER,
            provider="deepseek",
            model=target_model,
            endpoint=self.ds_base_url,
            latency_ms=latency_ms,
            input_tokens=None,
            output_tokens=None,
            status=status,
            task_type=task_type,
            request_tag=req.request_tag,
            retry_attempt=attempt_no,
            fallback_depth=fallback_depth,
            route=route,
            **error_fields,
        )

    for attempt in range(max_retries):
        # Recorded exactly once per attempt; the success path must not record
        # it again or every successful call is counted twice.
        self._record_attempt(
            task_type=task_type,
            provider="deepseek",
            route=route,
            fallback_depth=fallback_depth,
        )
        start = time.perf_counter()
        try:
            response = client.chat.completions.create(**params, timeout=timeout)
            latency_ms = (time.perf_counter() - start) * 1000

            message = response.choices[0].message
            text = (message.content or "").strip()
            reasoning = getattr(message, "reasoning_content", None)
            if reasoning:
                text = f"{reasoning}\n{text}"

            _log_attempt(attempt + 1, latency_ms, "ok")
            self._record_completion(latency_ms=latency_ms)
            self._bump_metric("requests_ok", 1)
            return text

        except (RateLimitError, APITimeoutError, APIError) as exc:
            latency_ms = (time.perf_counter() - start) * 1000
            # Most specific classes first, mirroring the handler order these
            # checks replace.
            if isinstance(exc, RateLimitError):
                error_class, error_message = "RateLimitError", "rate limited"
                final_message = "DeepSeek API rate limit reached. Please try again shortly."
            elif isinstance(exc, APITimeoutError):
                error_class, error_message = "APITimeoutError", "timeout"
                final_message = "DeepSeek API timed out. Please retry."
            else:
                error_class, error_message = "APIError", str(exc)[:200]
                final_message = f"DeepSeek API error: {str(exc)}"

            # Log every failed attempt, including the last one, so the final
            # failure is visible in the structured logs too.
            _log_attempt(
                attempt + 1,
                latency_ms,
                "error",
                error_class=error_class,
                error_message=error_message,
            )
            if attempt < max_retries - 1:
                self._bump_metric("retries_total", 1)
                # Linear backoff with a little jitter.
                time.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
                continue
            self._bump_metric("requests_error", 1)
            raise RuntimeError(final_message) from exc

        except Exception as exc:
            latency_ms = (time.perf_counter() - start) * 1000
            self._bump_metric("requests_error", 1)
            _log_attempt(
                attempt + 1,
                latency_ms,
                "error",
                error_class=exc.__class__.__name__,
                error_message=str(exc)[:200],
            )
            raise

    # Only reachable when the retry profile yields no attempts.
    raise RuntimeError(f"DeepSeek call failed after {max_retries} attempts")
 
 
 
 
 
 
882
 
883
+ def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
884
  target_model = req.model or self.default_model
885
+ url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"
886
+
887
+ prompt = self._messages_to_prompt(req.messages)
888
+ payload: Dict[str, object] = {
889
+ "data": [
890
+ prompt,
891
+ [],
892
+ req.temperature,
893
+ req.top_p,
894
+ req.max_new_tokens,
895
+ ]
896
+ }
897
+ headers = {"Content-Type": "application/json"}
898
+
899
  timeout = self._timeout_for(req, provider)
900
+
901
+ self._record_attempt(
902
+ task_type=req.task_type,
903
+ provider=provider,
904
+ route=route,
905
+ fallback_depth=fallback_depth,
906
+ )
907
  start = time.perf_counter()
908
+
909
  try:
910
+ resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  except Exception as exc:
912
  latency_ms = (time.perf_counter() - start) * 1000
 
913
  log_model_call(
914
  LOGGER,
915
+ provider=provider,
916
+ model=target_model,
917
+ endpoint=url,
918
  latency_ms=latency_ms,
919
  input_tokens=None,
920
  output_tokens=None,
 
927
  fallback_depth=fallback_depth,
928
  route=route,
929
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  self._bump_metric("requests_error", 1)
931
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
932
 
933
+ latency_ms = (time.perf_counter() - start) * 1000
 
 
 
 
 
 
 
 
 
 
 
934
  self._bump_bucket("status_code_counts", str(resp.status_code), 1)
935
 
936
  if resp.status_code != 200:
 
969
  status="ok",
970
  task_type=req.task_type,
971
  request_tag=req.request_tag,
972
+ retry_attempt=1,
973
  fallback_depth=fallback_depth,
974
  route=route,
975
  )
 
1010
 
1011
  def _clean_response_text(self, text: str) -> str:
1012
  """Strip JSON braces, template artifacts, and whitespace from response text."""
 
1013
  text = text.strip()
1014
+
 
1015
  if text.startswith("{") and text.endswith("}"):
1016
  try:
 
1017
  parsed = json.loads(text)
 
1018
  if isinstance(parsed, dict):
1019
  if "content" in parsed:
1020
  text = str(parsed["content"]).strip()
1021
  elif "text" in parsed:
1022
  text = str(parsed["text"]).strip()
1023
  except json.JSONDecodeError:
 
1024
  text = text.strip("{}")
1025
+
 
1026
  if text.startswith("```json") or text.startswith("```"):
1027
  text = re.sub(r"^```(?:json)?", "", text).strip()
1028
  if text.endswith("```"):
1029
  text = text[:-3].strip()
1030
+
1031
  return text.strip()
1032
 
1033
 
1034
def create_default_client(firestore_client: Optional[Any] = None) -> InferenceClient:
    """Build an ``InferenceClient`` with its default settings.

    Args:
        firestore_client: Optional object forwarded unchanged to the
            ``InferenceClient`` constructor.

    Returns:
        The newly constructed client.
    """
    client = InferenceClient(firestore_client=firestore_client)
    return client
1036
+
1037
+
1038
def is_sequential_model(model_id: str = "") -> bool:
    """Report whether the reasoner model is selected or locked at runtime.

    Args:
        model_id: Model identifier to check. When empty, the
            ``INFERENCE_MODEL_ID`` environment variable is used instead.

    Returns:
        ``False`` when no model id can be resolved. Otherwise ``True`` when
        the resolved id equals ``REASONER_MODEL``, or when the runtime
        overrides set ``INFERENCE_LOCK_MODEL_ID`` to ``REASONER_MODEL``.
    """
    resolved = (model_id or os.getenv("INFERENCE_MODEL_ID") or "").strip()
    if resolved == "":
        return False
    if resolved == REASONER_MODEL:
        return True
    if not _RUNTIME_OVERRIDES:
        return False
    # The lock is compared as stored; it is not stripped or normalised.
    return _RUNTIME_OVERRIDES.get("INFERENCE_LOCK_MODEL_ID", "") == REASONER_MODEL
startup_validation.py CHANGED
@@ -30,28 +30,33 @@ def validate_imports() -> None:
30
  import uvicorn # noqa
31
  import pydantic # noqa
32
  logger.info(" โœ“ FastAPI, Uvicorn, Pydantic OK")
33
-
34
  # Backend services (use ABSOLUTE imports like deployed code)
35
- from services.inference_client import InferenceClient, create_default_client # noqa
 
 
 
 
 
36
  logger.info(" โœ“ InferenceClient imports OK")
37
-
38
  from automation_engine import automation_engine # noqa
39
  logger.info(" โœ“ automation_engine imports OK")
40
-
41
  from analytics import compute_competency_analysis # noqa
42
  logger.info(" โœ“ analytics imports OK")
43
-
44
  # Firebase
45
  try:
46
  import firebase_admin # noqa
47
  logger.info(" โœ“ firebase_admin imports OK")
48
  except ImportError:
49
  logger.warning(" โš  firebase_admin not available (OK if Firebase not needed)")
50
-
51
  # ML & inference
52
- from huggingface_hub import InferenceClient as HFInferenceClient # noqa
53
- logger.info(" โœ“ HuggingFace Hub imports OK")
54
-
55
  logger.info("โœ… All critical imports validated")
56
  except ImportError as e:
57
  raise StartupError(
@@ -72,47 +77,79 @@ def validate_imports() -> None:
72
  def validate_environment() -> None:
73
  """Verify required environment variables are set."""
74
  logger.info("๐Ÿ” Validating environment variables...")
75
-
76
- # CRITICAL: HF_TOKEN for inference
77
- hf_token = os.environ.get("HF_TOKEN")
78
- api_key = os.environ.get("HUGGING_FACE_API_TOKEN")
79
- legacy_api_key = os.environ.get("HUGGINGFACE_API_TOKEN")
80
- if not hf_token and not api_key and not legacy_api_key:
81
  logger.warning(
82
- "โš  WARNING: HF_TOKEN is not set as an environment variable.\n"
83
- " On HF Spaces, this should be set as a SPACE SECRET.\n"
84
  " AI inference will fail without this token.\n"
85
- " Use: python set-hf-secrets.py to set the secret."
86
  )
87
  else:
88
- logger.info(" โœ“ HF_TOKEN/HUGGING_FACE_API_TOKEN/HUGGINGFACE_API_TOKEN is set")
89
-
90
  # Check inference provider config
91
- inference_provider = os.getenv("INFERENCE_PROVIDER", "hf_inference")
92
  logger.info(f" โœ“ INFERENCE_PROVIDER: {inference_provider}")
93
-
94
  # Check model IDs
95
  chat_model = os.getenv("INFERENCE_CHAT_MODEL_ID") or os.getenv("INFERENCE_MODEL_ID") or "deepseek-chat"
96
  logger.info(f" โœ“ Chat model configured: {chat_model}")
97
 
98
  chat_strict = os.getenv("INFERENCE_CHAT_STRICT_MODEL_ONLY", "true").strip().lower() in {"1", "true", "yes", "on"}
99
  chat_hard_trigger = os.getenv("INFERENCE_CHAT_HARD_TRIGGER_ENABLED", "false").strip().lower() in {"1", "true", "yes", "on"}
100
- enforce_qwen_only = os.getenv("INFERENCE_ENFORCE_QWEN_ONLY", "false").strip().lower() in {"1", "true", "yes", "on"}
101
- qwen_lock_model = os.getenv("INFERENCE_QWEN_LOCK_MODEL", "deepseek-chat").strip() or "deepseek-chat"
102
- logger.info(f" โœ“ INFERENCE_CHAT_STRICT_MODEL_ONLY: {chat_strict}")
103
- logger.info(f" โœ“ INFERENCE_CHAT_HARD_TRIGGER_ENABLED: {chat_hard_trigger}")
104
- logger.info(f" โœ“ INFERENCE_ENFORCE_QWEN_ONLY: {enforce_qwen_only}")
105
- logger.info(f" โœ“ INFERENCE_QWEN_LOCK_MODEL: {qwen_lock_model}")
 
 
 
 
106
  if not chat_strict:
107
  logger.warning(" โš  Chat strict model lock is disabled; chat may fallback to alternate models")
108
  if chat_strict and chat_hard_trigger:
109
  logger.warning(
110
  " โš  Chat hard trigger is enabled while strict chat lock is on; hard escalation will be bypassed"
111
  )
112
-
 
 
113
  logger.info("โœ… Environment variables OK")
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def validate_config_files() -> None:
117
  """Verify config files exist and are readable."""
118
  logger.info("๐Ÿ” Validating configuration files...")
@@ -154,7 +191,9 @@ def validate_config_files() -> None:
154
  )
155
 
156
  logger.info(f" โœ“ Using model config: {readable_model_config}")
157
-
 
 
158
  logger.info("โœ… Configuration files OK")
159
 
160
 
@@ -202,26 +241,26 @@ def validate_file_structure() -> None:
202
  logger.info(
203
  f" โ„น Optional build file not present at runtime: {joined}"
204
  )
205
-
206
  logger.info("โœ… File structure OK")
207
 
208
 
209
  def validate_inference_client_config() -> None:
210
  """Validate InferenceClient can load its config."""
211
  logger.info("๐Ÿ” Validating InferenceClient configuration...")
212
-
213
  try:
214
  # Try to create the client (this will load config from YAML)
215
  from services.inference_client import create_default_client
216
  client = create_default_client()
217
-
218
  # Verify critical attributes
219
  if not hasattr(client, 'task_model_map'):
220
  raise StartupError("โŒ InferenceClient missing task_model_map attribute")
221
-
222
  if not hasattr(client, 'task_provider_map'):
223
  raise StartupError("โŒ InferenceClient missing task_provider_map attribute")
224
-
225
  # Check that required tasks are mapped
226
  required_tasks = ['chat', 'verify_solution', 'lesson_generation', 'quiz_generation']
227
  for task in required_tasks:
@@ -245,9 +284,9 @@ def validate_inference_client_config() -> None:
245
  "โŒ Chat strict model lock is enabled but effective chat model chain is not singular.\n"
246
  " Check INFERENCE_CHAT_STRICT_MODEL_ONLY and routing.task_fallback_model_map.chat\n"
247
  )
248
-
249
  logger.info("โœ… InferenceClient configuration OK")
250
-
251
  except StartupError:
252
  raise
253
  except Exception as e:
@@ -258,15 +297,49 @@ def validate_inference_client_config() -> None:
258
  ) from e
259
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  def run_all_validations() -> None:
262
  """Run comprehensive startup validation.
263
-
264
  If any check fails, exits with clear error message visible in logs.
265
  """
266
  logger.info("=" * 70)
267
  logger.info("๐Ÿš€ STARTUP VALIDATION - Checking all critical dependencies")
268
  logger.info("=" * 70)
269
-
270
  strict_mode = os.getenv("STARTUP_VALIDATION_STRICT", "false").strip().lower() in {"1", "true", "yes", "on"}
271
 
272
  try:
@@ -275,11 +348,11 @@ def run_all_validations() -> None:
275
  validate_environment()
276
  validate_config_files()
277
  validate_inference_client_config()
278
-
279
  logger.info("=" * 70)
280
  logger.info("โœ… ALL STARTUP VALIDATIONS PASSED")
281
  logger.info("=" * 70)
282
-
283
  except StartupError as e:
284
  logger.error("=" * 70)
285
  logger.error(str(e))
@@ -298,4 +371,4 @@ def run_all_validations() -> None:
298
  logger.warning(
299
  "โš ๏ธ Continuing startup after unexpected validation error because "
300
  "STARTUP_VALIDATION_STRICT is disabled."
301
- )
 
30
  import uvicorn # noqa
31
  import pydantic # noqa
32
  logger.info(" โœ“ FastAPI, Uvicorn, Pydantic OK")
33
+
34
  # Backend services (use ABSOLUTE imports like deployed code)
35
+ from services.inference_client import (
36
+ InferenceClient, create_default_client, is_sequential_model,
37
+ get_current_runtime_config, get_model_for_task, model_supports_thinking,
38
+ set_runtime_model_profile, set_runtime_model_override, reset_runtime_overrides,
39
+ _MODEL_PROFILES,
40
+ ) # noqa
41
  logger.info(" โœ“ InferenceClient imports OK")
42
+
43
  from automation_engine import automation_engine # noqa
44
  logger.info(" โœ“ automation_engine imports OK")
45
+
46
  from analytics import compute_competency_analysis # noqa
47
  logger.info(" โœ“ analytics imports OK")
48
+
49
  # Firebase
50
  try:
51
  import firebase_admin # noqa
52
  logger.info(" โœ“ firebase_admin imports OK")
53
  except ImportError:
54
  logger.warning(" โš  firebase_admin not available (OK if Firebase not needed)")
55
+
56
  # ML & inference
57
+ from services.ai_client import get_deepseek_client, CHAT_MODEL, REASONER_MODEL # noqa
58
+ logger.info(" โœ“ DeepSeek AI client imports OK")
59
+
60
  logger.info("โœ… All critical imports validated")
61
  except ImportError as e:
62
  raise StartupError(
 
77
def validate_environment() -> None:
    """Log the inference-related environment settings and warn on gaps.

    Nothing is raised here: a missing ``DEEPSEEK_API_KEY`` and questionable
    chat-lock combinations are reported as warnings only. The embedding model
    setting is checked by ``_validate_embedding_model``.
    """
    truthy = {"1", "true", "yes", "on"}

    def env_flag(name: str, default: str) -> bool:
        # Same parsing for every boolean switch read below.
        return os.getenv(name, default).strip().lower() in truthy

    logger.info("๐Ÿ” Validating environment variables...")

    # CRITICAL: DEEPSEEK_API_KEY for inference
    if os.environ.get("DEEPSEEK_API_KEY"):
        logger.info(" โœ“ DEEPSEEK_API_KEY is set")
    else:
        logger.warning(
            "โš  WARNING: DEEPSEEK_API_KEY is not set as an environment variable.\n"
            " AI inference will fail without this token.\n"
            " Use: Set DEEPSEEK_API_KEY in your .env or space secrets."
        )

    # Inference provider
    provider = os.getenv("INFERENCE_PROVIDER", "deepseek")
    logger.info(f" โœ“ INFERENCE_PROVIDER: {provider}")

    # Chat model: dedicated variable first, then the generic one, then default
    configured_chat_model = (
        os.getenv("INFERENCE_CHAT_MODEL_ID")
        or os.getenv("INFERENCE_MODEL_ID")
        or "deepseek-chat"
    )
    logger.info(f" โœ“ Chat model configured: {configured_chat_model}")

    strict_chat = env_flag("INFERENCE_CHAT_STRICT_MODEL_ONLY", "true")
    hard_trigger = env_flag("INFERENCE_CHAT_HARD_TRIGGER_ENABLED", "false")
    lock_enforced = env_flag("INFERENCE_ENFORCE_LOCK_MODEL", "true")
    locked_model = os.getenv("INFERENCE_LOCK_MODEL_ID", "deepseek-chat").strip() or "deepseek-chat"
    logger.info(f" โœ“ INFERENCE_ENFORCE_LOCK_MODEL: {lock_enforced}")
    logger.info(f" โœ“ INFERENCE_LOCK_MODEL_ID: {locked_model}")

    profile = os.getenv("MODEL_PROFILE", "").strip().lower()
    quiz_model_id = os.getenv("HF_QUIZ_MODEL_ID", "").strip()
    rag_model_id = os.getenv("HF_RAG_MODEL_ID", "").strip()
    logger.info(f" โœ“ MODEL_PROFILE: {profile or 'not set (using individual env vars)'}")
    logger.info(f" โœ“ HF_QUIZ_MODEL_ID: {quiz_model_id or 'not set (using defaults)'}")
    logger.info(f" โœ“ HF_RAG_MODEL_ID: {rag_model_id or 'not set (using defaults)'}")

    if not strict_chat:
        logger.warning(" โš  Chat strict model lock is disabled; chat may fallback to alternate models")
    elif hard_trigger:
        logger.warning(
            " โš  Chat hard trigger is enabled while strict chat lock is on; hard escalation will be bypassed"
        )

    _validate_embedding_model()

    logger.info("โœ… Environment variables OK")
122
 
123
 
124
+ EXPECTED_EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
125
+
126
def _validate_embedding_model() -> None:
    """Warn when ``EMBEDDING_MODEL`` is unset, unexpected, or a chat model.

    Only logs; nothing is raised. A value equal to one of the generation
    model ids (``CHAT_MODEL`` / ``REASONER_MODEL``) gets an extra warning,
    otherwise the configured value is logged at info level.
    """
    configured = os.getenv("EMBEDDING_MODEL", "").strip()

    if configured == "":
        logger.warning(
            "WARNING: EMBEDDING_MODEL env var is not set. "
            f"Expected: {EXPECTED_EMBEDDING_MODEL}. "
            "RAG retrieval will fail without an embedding model."
        )
    elif configured != EXPECTED_EMBEDDING_MODEL:
        logger.warning(
            f"WARNING: EMBEDDING_MODEL is set to '{configured}' โ€” "
            f"expected '{EXPECTED_EMBEDDING_MODEL}'. "
            "Confirm this is intentional before deploying."
        )

    # Imported here, as in the original, rather than at module level.
    from services.ai_client import CHAT_MODEL, REASONER_MODEL  # noqa

    if configured in (CHAT_MODEL, REASONER_MODEL):
        logger.warning(
            f"CRITICAL: EMBEDDING_MODEL is set to a generation model ('{configured}'). "
            "This will break RAG retrieval. Set it to 'BAAI/bge-small-en-v1.5'."
        )
        return

    logger.info(f" EMBEDDING_MODEL: {configured or 'not set'}")
151
+
152
+
153
  def validate_config_files() -> None:
154
  """Verify config files exist and are readable."""
155
  logger.info("๐Ÿ” Validating configuration files...")
 
191
  )
192
 
193
  logger.info(f" โœ“ Using model config: {readable_model_config}")
194
+
195
+ _validate_model_config_fields(readable_model_config)
196
+
197
  logger.info("โœ… Configuration files OK")
198
 
199
 
 
241
  logger.info(
242
  f" โ„น Optional build file not present at runtime: {joined}"
243
  )
244
+
245
  logger.info("โœ… File structure OK")
246
 
247
 
248
  def validate_inference_client_config() -> None:
249
  """Validate InferenceClient can load its config."""
250
  logger.info("๐Ÿ” Validating InferenceClient configuration...")
251
+
252
  try:
253
  # Try to create the client (this will load config from YAML)
254
  from services.inference_client import create_default_client
255
  client = create_default_client()
256
+
257
  # Verify critical attributes
258
  if not hasattr(client, 'task_model_map'):
259
  raise StartupError("โŒ InferenceClient missing task_model_map attribute")
260
+
261
  if not hasattr(client, 'task_provider_map'):
262
  raise StartupError("โŒ InferenceClient missing task_provider_map attribute")
263
+
264
  # Check that required tasks are mapped
265
  required_tasks = ['chat', 'verify_solution', 'lesson_generation', 'quiz_generation']
266
  for task in required_tasks:
 
284
  "โŒ Chat strict model lock is enabled but effective chat model chain is not singular.\n"
285
  " Check INFERENCE_CHAT_STRICT_MODEL_ONLY and routing.task_fallback_model_map.chat\n"
286
  )
287
+
288
  logger.info("โœ… InferenceClient configuration OK")
289
+
290
  except StartupError:
291
  raise
292
  except Exception as e:
 
297
  ) from e
298
 
299
 
300
def _validate_model_config_fields(config_path: str) -> None:
    """Check that the model config YAML has the fields startup relies on.

    Verifies that ``models.rag_primary`` and ``models.model_capabilities``
    exist and that ``routing.task_model_map`` maps every RAG task.

    Args:
        config_path: Path of the YAML file to read.

    Raises:
        StartupError: If the file cannot be read or parsed, if its top level
            is not a mapping, or if any required section, field or RAG task
            mapping is missing.
    """
    try:
        import yaml
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
        # A list or scalar document would otherwise fail later with a bare
        # AttributeError on ``config.get``; report it as a parse problem.
        if not isinstance(config, dict):
            raise TypeError(
                f"top-level document is {type(config).__name__}, expected a mapping"
            )
    except Exception as e:
        raise StartupError(f"โŒ Cannot parse {config_path} as YAML: {e}") from e

    models = config.get("models", {})
    if not isinstance(models, dict):
        raise StartupError(f"โŒ {config_path}: 'models' section missing or invalid")

    if "rag_primary" not in models:
        raise StartupError(f"โŒ {config_path}: missing 'models.rag_primary' field")
    rag_primary = models["rag_primary"]
    if isinstance(rag_primary, dict):
        logger.info(f" โœ“ rag_primary model: {rag_primary.get('id', 'UNSET')}")
    else:
        logger.warning(" โš  rag_primary is not a dict, may cause issues")

    capabilities = models.get("model_capabilities")
    if not isinstance(capabilities, dict):
        raise StartupError(f"โŒ {config_path}: missing 'models.model_capabilities' section")
    logger.info(f" โœ“ model_capabilities: sequential_only={capabilities.get('sequential_only')}, supports_thinking={capabilities.get('supports_thinking')}")

    # ``routing:`` or ``task_model_map:`` may be present but empty (None) or of
    # the wrong type; treat that as "no tasks mapped" so the check below
    # reports it as a StartupError instead of crashing.
    routing = config.get("routing")
    tasks = routing.get("task_model_map") if isinstance(routing, dict) else None
    if not isinstance(tasks, dict):
        tasks = {}

    rag_tasks = {"rag_lesson", "rag_problem", "rag_analysis_context"}
    mapped_tasks = {str(t).strip().lower() for t in tasks}
    missing_rag = rag_tasks - mapped_tasks
    if missing_rag:
        # Sorted so the message is stable from run to run.
        raise StartupError(f"โŒ {config_path}: missing RAG task mappings: {sorted(missing_rag)}")

    logger.info(" โœ“ All RAG task mappings present")
332
+
333
+
334
  def run_all_validations() -> None:
335
  """Run comprehensive startup validation.
336
+
337
  If any check fails, exits with clear error message visible in logs.
338
  """
339
  logger.info("=" * 70)
340
  logger.info("๐Ÿš€ STARTUP VALIDATION - Checking all critical dependencies")
341
  logger.info("=" * 70)
342
+
343
  strict_mode = os.getenv("STARTUP_VALIDATION_STRICT", "false").strip().lower() in {"1", "true", "yes", "on"}
344
 
345
  try:
 
348
  validate_environment()
349
  validate_config_files()
350
  validate_inference_client_config()
351
+
352
  logger.info("=" * 70)
353
  logger.info("โœ… ALL STARTUP VALIDATIONS PASSED")
354
  logger.info("=" * 70)
355
+
356
  except StartupError as e:
357
  logger.error("=" * 70)
358
  logger.error(str(e))
 
371
  logger.warning(
372
  "โš ๏ธ Continuing startup after unexpected validation error because "
373
  "STARTUP_VALIDATION_STRICT is disabled."
374
+ )