MrA7A1 committed on
Commit
88b7e97
·
verified ·
1 Parent(s): 9a8a7a4

KAPO rollout fix: HF chat completion fallback

Browse files
Files changed (1) hide show
  1. brain_server/api/main.py +22 -1
brain_server/api/main.py CHANGED
@@ -283,7 +283,11 @@ def _load_embed_model() -> None:
283
  if EMBED_MODEL is not None:
284
  return
285
 
286
- from sentence_transformers import SentenceTransformer
 
 
 
 
287
 
288
  model_name = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
289
  EMBED_MODEL = SentenceTransformer(model_name)
@@ -1677,6 +1681,19 @@ def _generate_response(user_input: str, history: list[dict[str, str]], context_b
1677
  max_tokens = 80 if language == "ar" else 96
1678
  model_repo = str(os.getenv("MODEL_REPO", DEFAULT_MODEL_REPO) or DEFAULT_MODEL_REPO).strip()
1679
  client = InferenceClient(model=model_repo, api_key=(str(os.getenv("HF_TOKEN", "") or "").strip() or None))
 
 
 
 
 
 
 
 
 
 
 
 
 
1680
  generated = client.text_generation(
1681
  prompt,
1682
  max_new_tokens=max_tokens,
@@ -1717,6 +1734,8 @@ def _store_chat_trace(request_id: str, payload: dict[str, Any]) -> None:
1717
  executor_url = os.getenv("EXECUTOR_URL", "").strip().rstrip("/")
1718
  if not _executor_roundtrip_allowed("BRAIN_REMOTE_TRACE_STORE_ENABLED", default=True):
1719
  return
 
 
1720
  try:
1721
  requests.post(
1722
  f"{executor_url}/memory/store",
@@ -1729,6 +1748,8 @@ def _store_chat_trace(request_id: str, payload: dict[str, Any]) -> None:
1729
  )
1730
  except requests.exceptions.ReadTimeout:
1731
  logger.info("Chat trace store timed out; continuing")
 
 
1732
  except Exception:
1733
  logger.warning("Failed to store chat trace on executor", exc_info=True)
1734
 
 
283
  if EMBED_MODEL is not None:
284
  return
285
 
286
+ try:
287
+ from sentence_transformers import SentenceTransformer
288
+ except ModuleNotFoundError:
289
+ logger.info("Skipping embedding model load because sentence_transformers is unavailable")
290
+ return
291
 
292
  model_name = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
293
  EMBED_MODEL = SentenceTransformer(model_name)
 
1681
  max_tokens = 80 if language == "ar" else 96
1682
  model_repo = str(os.getenv("MODEL_REPO", DEFAULT_MODEL_REPO) or DEFAULT_MODEL_REPO).strip()
1683
  client = InferenceClient(model=model_repo, api_key=(str(os.getenv("HF_TOKEN", "") or "").strip() or None))
1684
+ try:
1685
+ chat_result = client.chat_completion(
1686
+ messages=[{"role": "user", "content": prompt}],
1687
+ max_tokens=max_tokens,
1688
+ )
1689
+ choices = getattr(chat_result, "choices", None) or []
1690
+ if choices:
1691
+ message = getattr(choices[0], "message", None)
1692
+ generated_text = str(getattr(message, "content", "") or "").strip()
1693
+ if generated_text:
1694
+ return generated_text
1695
+ except Exception:
1696
+ pass
1697
  generated = client.text_generation(
1698
  prompt,
1699
  max_new_tokens=max_tokens,
 
1734
  executor_url = os.getenv("EXECUTOR_URL", "").strip().rstrip("/")
1735
  if not _executor_roundtrip_allowed("BRAIN_REMOTE_TRACE_STORE_ENABLED", default=True):
1736
  return
1737
+ if not executor_url:
1738
+ return
1739
  try:
1740
  requests.post(
1741
  f"{executor_url}/memory/store",
 
1748
  )
1749
  except requests.exceptions.ReadTimeout:
1750
  logger.info("Chat trace store timed out; continuing")
1751
+ except requests.exceptions.ConnectionError:
1752
+ logger.info("Chat trace store skipped because executor is unreachable")
1753
  except Exception:
1754
  logger.warning("Failed to store chat trace on executor", exc_info=True)
1755