MrA7A1 committed on
Commit
53a5ee8
·
verified ·
1 Parent(s): 88b7e97

KAPO self-heal fix: conversational HF fallback

Browse files
Files changed (1) hide show
  1. brain_server/api/main.py +48 -12
brain_server/api/main.py CHANGED
@@ -1664,6 +1664,48 @@ def _project_specific_fast_reply(user_input: str) -> str:
1664
  return ""
1665
 
1666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1667
  def _generate_response(user_input: str, history: list[dict[str, str]], context_block: str) -> str:
1668
  language = _detect_language(user_input)
1669
  exact_reply = _extract_exact_reply_instruction_safe(user_input)
@@ -1678,22 +1720,16 @@ def _generate_response(user_input: str, history: list[dict[str, str]], context_b
1678
  try:
1679
  from huggingface_hub import InferenceClient
1680
  prompt = _build_chat_prompt(user_input, history, context_block)
 
1681
  max_tokens = 80 if language == "ar" else 96
1682
  model_repo = str(os.getenv("MODEL_REPO", DEFAULT_MODEL_REPO) or DEFAULT_MODEL_REPO).strip()
1683
  client = InferenceClient(model=model_repo, api_key=(str(os.getenv("HF_TOKEN", "") or "").strip() or None))
1684
  try:
1685
- chat_result = client.chat_completion(
1686
- messages=[{"role": "user", "content": prompt}],
1687
- max_tokens=max_tokens,
1688
- )
1689
- choices = getattr(chat_result, "choices", None) or []
1690
- if choices:
1691
- message = getattr(choices[0], "message", None)
1692
- generated_text = str(getattr(message, "content", "") or "").strip()
1693
- if generated_text:
1694
- return generated_text
1695
- except Exception:
1696
- pass
1697
  generated = client.text_generation(
1698
  prompt,
1699
  max_new_tokens=max_tokens,
 
1664
  return ""
1665
 
1666
 
1667
+ def _hf_chat_messages(user_input: str, history: list[dict[str, str]], context_block: str) -> list[dict[str, str]]:
1668
+ messages: list[dict[str, str]] = []
1669
+ context_text = str(context_block or "").strip()
1670
+ if context_text:
1671
+ messages.append({"role": "system", "content": context_text})
1672
+ for item in history or []:
1673
+ role = str((item or {}).get("role") or "").strip().lower()
1674
+ content = str((item or {}).get("content") or "").strip()
1675
+ if role in {"system", "user", "assistant"} and content:
1676
+ messages.append({"role": role, "content": content})
1677
+ messages.append({"role": "user", "content": str(user_input or "").strip()})
1678
+ return messages
1679
+
1680
+
1681
+ def _hf_chat_completion_text(client: Any, messages: list[dict[str, str]], max_tokens: int) -> str:
1682
+ result = client.chat_completion(messages=messages, max_tokens=max_tokens)
1683
+ choices = getattr(result, "choices", None)
1684
+ if choices is None and isinstance(result, dict):
1685
+ choices = result.get("choices")
1686
+ choices = choices or []
1687
+ if not choices:
1688
+ return ""
1689
+ first = choices[0]
1690
+ message = getattr(first, "message", None)
1691
+ if message is None and isinstance(first, dict):
1692
+ message = first.get("message")
1693
+ content = getattr(message, "content", None)
1694
+ if content is None and isinstance(message, dict):
1695
+ content = message.get("content")
1696
+ if isinstance(content, list):
1697
+ parts: list[str] = []
1698
+ for item in content:
1699
+ if isinstance(item, dict):
1700
+ text = item.get("text")
1701
+ if text:
1702
+ parts.append(str(text))
1703
+ elif item:
1704
+ parts.append(str(item))
1705
+ content = "\n".join(part for part in parts if part.strip())
1706
+ return str(content or "").strip()
1707
+
1708
+
1709
  def _generate_response(user_input: str, history: list[dict[str, str]], context_block: str) -> str:
1710
  language = _detect_language(user_input)
1711
  exact_reply = _extract_exact_reply_instruction_safe(user_input)
 
1720
  try:
1721
  from huggingface_hub import InferenceClient
1722
  prompt = _build_chat_prompt(user_input, history, context_block)
1723
+ messages = _hf_chat_messages(user_input, history, context_block)
1724
  max_tokens = 80 if language == "ar" else 96
1725
  model_repo = str(os.getenv("MODEL_REPO", DEFAULT_MODEL_REPO) or DEFAULT_MODEL_REPO).strip()
1726
  client = InferenceClient(model=model_repo, api_key=(str(os.getenv("HF_TOKEN", "") or "").strip() or None))
1727
  try:
1728
+ generated_text = _hf_chat_completion_text(client, messages, max_tokens)
1729
+ if generated_text:
1730
+ return generated_text
1731
+ except Exception as exc:
1732
+ logger.info("HF chat-completion path failed; falling back to text-generation (%s)", exc)
 
 
 
 
 
 
 
1733
  generated = client.text_generation(
1734
  prompt,
1735
  max_new_tokens=max_tokens,