Oleksii Obolonskyi committed on
Commit
052a978
·
1 Parent(s): 8faa6a7

Harden HF inference routing

Browse files
Files changed (2) hide show
  1. README.md +5 -7
  2. app.py +89 -14
README.md CHANGED
@@ -57,11 +57,7 @@ export HF_TOKEN=hf_your_token_here
57
  export RAG_HF_MODEL=meta-llama/Llama-3.2-1B-Instruct
58
  ```
59
 
60
- Optional override if you use a dedicated Inference Endpoint:
61
-
62
- ```bash
63
- export RAG_HF_API_URL=https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-1B-Instruct
64
- ```
65
 
66
  ### 3) Prepare sources
67
 
@@ -106,7 +102,8 @@ export RAG_ARTICLE_MANIFEST_PATH=data/normalized/manifest_articles.json
106
  export RAG_EMBED_MODEL=sentence-transformers/all-MiniLM-L6-v2
107
  export HF_TOKEN=hf_your_token_here
108
  export RAG_HF_MODEL=meta-llama/Llama-3.2-1B-Instruct
109
- export RAG_HF_API_URL=https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-1B-Instruct
 
110
  export RAG_OUT_DIR=data/normalized
111
  export RAG_ARTICLE_SOURCES=sources_articles.json
112
  ```
@@ -115,7 +112,8 @@ export RAG_ARTICLE_SOURCES=sources_articles.json
115
 
116
  1. Create a new Space (Streamlit SDK) and push this repo.
117
  2. In Space Settings → Secrets, set `HF_TOKEN` (required) and optionally `GITHUB_TOKEN`.
118
- 3. In Space Settings → Variables, set `RAG_HF_MODEL` or `RAG_HF_API_URL` if you want to override defaults.
 
119
 
120
  ## Common maintenance tasks
121
 
 
57
  export RAG_HF_MODEL=meta-llama/Llama-3.2-1B-Instruct
58
  ```
59
 
60
+ Optional: set `RAG_HF_FALLBACK_MODEL` to retry if the primary model is gated or unavailable.
 
 
 
 
61
 
62
  ### 3) Prepare sources
63
 
 
102
  export RAG_EMBED_MODEL=sentence-transformers/all-MiniLM-L6-v2
103
  export HF_TOKEN=hf_your_token_here
104
  export RAG_HF_MODEL=meta-llama/Llama-3.2-1B-Instruct
105
+ export RAG_LLM_BACKEND=hf
106
+ export RAG_HF_FALLBACK_MODEL=HuggingFaceH4/zephyr-7b-beta
107
  export RAG_OUT_DIR=data/normalized
108
  export RAG_ARTICLE_SOURCES=sources_articles.json
109
  ```
 
112
 
113
  1. Create a new Space (Streamlit SDK) and push this repo.
114
  2. In Space Settings → Secrets, set `HF_TOKEN` (required) and optionally `GITHUB_TOKEN`.
115
+ 3. In Space Settings → Variables, set `RAG_HF_MODEL` (required) and `RAG_LLM_BACKEND=hf`.
116
+ 4. Optional: `RAG_HF_FALLBACK_MODEL` to retry if the primary model is gated or unavailable.
117
 
118
  ## Common maintenance tasks
119
 
app.py CHANGED
@@ -15,7 +15,6 @@ import numpy as np
15
  import faiss
16
  import requests
17
  from huggingface_hub import InferenceClient
18
- from huggingface_hub import InferenceClient
19
  from sentence_transformers import SentenceTransformer
20
 
21
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
@@ -39,6 +38,7 @@ HF_TOKEN = (
39
  or ""
40
  ).strip()
41
  HF_MODEL = os.environ.get("RAG_HF_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
 
42
 
43
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
44
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
@@ -508,12 +508,15 @@ def answer_question(
508
  return "Model error: Empty response from model", citations, False
509
  return sanitize_answer(answer), citations, True
510
 
511
- def build_hf_prompt(user_prompt: str, model_id: str) -> str:
512
- system_msg = (
513
  f"You are an assistant for {COMPANY_NAME}. Contact: {COMPANY_EMAIL}, "
514
  f"{COMPANY_PHONE}. {COMPANY_ABOUT}. Answer only from the provided context. "
515
  "Keep answers concise. Cite sources using the provided citation tags exactly."
516
  )
 
 
 
517
  if "llama-3" in model_id.lower():
518
  return (
519
  "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
@@ -523,15 +526,44 @@ def build_hf_prompt(user_prompt: str, model_id: str) -> str:
523
  return f"System: {system_msg}\nUser: {user_prompt}\nAssistant:"
524
 
525
  @st.cache_resource(show_spinner=False)
526
- def get_hf_client() -> InferenceClient:
527
- return InferenceClient(model=HF_MODEL, token=HF_TOKEN)
528
 
529
  def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
530
  if not HF_TOKEN:
531
  return "", "Missing HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  try:
533
- client = get_hf_client()
534
- inp = build_hf_prompt(prompt, HF_MODEL)
535
  out = client.text_generation(
536
  inp,
537
  max_new_tokens=512,
@@ -541,14 +573,41 @@ def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Opt
541
  )
542
  return (out or "").strip(), None
543
  except Exception as e:
544
- return "", str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
  def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
547
  url = f"{OLLAMA_BASE_URL}/api/chat"
548
  payload = {
549
  "model": OLLAMA_MODEL,
550
  "messages": [
551
- {"role": "system", "content": f"You are an assistant for {COMPANY_NAME}. Contact: {COMPANY_EMAIL}, {COMPANY_PHONE}. {COMPANY_ABOUT}. Answer only from the provided context. Keep answers concise. Cite sources using the provided citation tags exactly."},
552
  {"role": "user", "content": prompt},
553
  ],
554
  "stream": False,
@@ -575,10 +634,17 @@ def llm_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Op
575
  return hf_chat(prompt, timeout=timeout)
576
  if backend == "ollama":
577
  return ollama_chat(prompt, timeout=timeout)
 
 
578
  if (HF_TOKEN or "").strip():
579
  return hf_chat(prompt, timeout=timeout)
580
  return ollama_chat(prompt, timeout=timeout)
581
 
 
 
 
 
 
582
  def github_create_issue(title: str, body: str, labels: Optional[List[str]] = None) -> Tuple[Optional[int], Optional[str]]:
583
  if not GITHUB_TOKEN:
584
  return None, "Missing GITHUB_TOKEN"
@@ -646,9 +712,16 @@ with st.sidebar:
646
  st.write("")
647
  st.subheader("LLM")
648
  st.markdown(f"- Model: `{HF_MODEL}`")
649
- st.markdown(f"- URL: `{HF_API_URL}`")
650
- if not HF_TOKEN:
651
- st.warning("HF_TOKEN is not set. LLM requests will fail until you add it.")
 
 
 
 
 
 
 
652
  st.write("")
653
  st.subheader("Embedding model (retrieval)")
654
  st.code(EMBED_MODEL)
@@ -1019,8 +1092,10 @@ if st.session_state.get("active_action"):
1019
  if ok:
1020
  push_message("assistant", answer, citations=citations, not_found=False)
1021
  else:
1022
- push_message("assistant", answer, citations=[], not_found=True)
1023
- st.session_state["ticket_prefill"] = {"question": q_norm, "citations": citations}
 
 
1024
  st.session_state["last_question"] = q_norm
1025
  st.session_state["last_citations"] = citations
1026
  st.session_state["last_answer"] = answer
 
15
  import faiss
16
  import requests
17
  from huggingface_hub import InferenceClient
 
18
  from sentence_transformers import SentenceTransformer
19
 
20
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
 
38
  or ""
39
  ).strip()
40
  HF_MODEL = os.environ.get("RAG_HF_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
41
+ HF_FALLBACK_MODEL = os.environ.get("RAG_HF_FALLBACK_MODEL", "").strip()
42
 
43
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
44
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
 
508
  return "Model error: Empty response from model", citations, False
509
  return sanitize_answer(answer), citations, True
510
 
511
+ def system_message() -> str:
512
+ return (
513
  f"You are an assistant for {COMPANY_NAME}. Contact: {COMPANY_EMAIL}, "
514
  f"{COMPANY_PHONE}. {COMPANY_ABOUT}. Answer only from the provided context. "
515
  "Keep answers concise. Cite sources using the provided citation tags exactly."
516
  )
517
+
518
+ def build_hf_prompt(user_prompt: str, model_id: str) -> str:
519
+ system_msg = system_message()
520
  if "llama-3" in model_id.lower():
521
  return (
522
  "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
 
526
  return f"System: {system_msg}\nUser: {user_prompt}\nAssistant:"
527
 
528
  @st.cache_resource(show_spinner=False)
529
+ def get_hf_client(model_id: str) -> InferenceClient:
530
+ return InferenceClient(model=model_id, token=HF_TOKEN)
531
 
532
  def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
533
  if not HF_TOKEN:
534
  return "", "Missing HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)"
535
+ return hf_chat_with_model(prompt, HF_MODEL, timeout=timeout)
536
+
537
+ def hf_chat_with_model(
538
+ prompt: str,
539
+ model_id: str,
540
+ timeout: Tuple[int, int] = (10, 600),
541
+ ) -> Tuple[str, Optional[str]]:
542
+ client = get_hf_client(model_id)
543
+ messages = [
544
+ {"role": "system", "content": system_message()},
545
+ {"role": "user", "content": prompt},
546
+ ]
547
+ try:
548
+ chat_api = getattr(getattr(client, "chat", None), "completions", None)
549
+ create_fn = getattr(chat_api, "create", None)
550
+ if create_fn:
551
+ resp = create_fn(
552
+ model=model_id,
553
+ messages=messages,
554
+ max_tokens=512,
555
+ temperature=0.2,
556
+ )
557
+ text = (resp.choices[0].message.content or "").strip()
558
+ return text, None
559
+ except Exception as e:
560
+ fallback = hf_fallback_model_error(str(e), model_id)
561
+ if fallback:
562
+ return hf_chat_with_model(prompt, fallback, timeout=timeout)
563
+ return "", hf_format_error(str(e), model_id)
564
+
565
  try:
566
+ inp = build_hf_prompt(prompt, model_id)
 
567
  out = client.text_generation(
568
  inp,
569
  max_new_tokens=512,
 
573
  )
574
  return (out or "").strip(), None
575
  except Exception as e:
576
+ fallback = hf_fallback_model_error(str(e), model_id)
577
+ if fallback:
578
+ return hf_chat_with_model(prompt, fallback, timeout=timeout)
579
+ return "", hf_format_error(str(e), model_id)
580
+
581
+ def hf_fallback_model_error(err: str, model_id: str) -> Optional[str]:
582
+ if not HF_FALLBACK_MODEL:
583
+ return None
584
+ if model_id == HF_FALLBACK_MODEL:
585
+ return None
586
+ err_low = (err or "").lower()
587
+ if any(k in err_low for k in ["401", "403", "gated", "license", "not authorized", "forbidden"]):
588
+ return HF_FALLBACK_MODEL
589
+ if any(k in err_low for k in ["404", "not found", "provider", "unavailable", "service unavailable"]):
590
+ return HF_FALLBACK_MODEL
591
+ return None
592
+
593
+ def hf_format_error(err: str, model_id: str) -> str:
594
+ err_low = (err or "").lower()
595
+ if any(k in err_low for k in ["401", "403", "gated", "license", "not authorized", "forbidden"]):
596
+ return (
597
+ f"{err} This model is gated. Ensure HF_TOKEN has accepted the model license and has access."
598
+ )
599
+ if any(k in err_low for k in ["404", "not found"]):
600
+ return f"{err} Model not found. Verify RAG_HF_MODEL or RAG_HF_API_URL."
601
+ if any(k in err_low for k in ["provider", "unavailable", "service unavailable"]):
602
+ return f"{err} Provider unavailable. Try again later or set RAG_HF_FALLBACK_MODEL."
603
+ return err
604
 
605
  def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
606
  url = f"{OLLAMA_BASE_URL}/api/chat"
607
  payload = {
608
  "model": OLLAMA_MODEL,
609
  "messages": [
610
+ {"role": "system", "content": system_message()},
611
  {"role": "user", "content": prompt},
612
  ],
613
  "stream": False,
 
634
  return hf_chat(prompt, timeout=timeout)
635
  if backend == "ollama":
636
  return ollama_chat(prompt, timeout=timeout)
637
+ if is_running_on_spaces():
638
+ return hf_chat(prompt, timeout=timeout)
639
  if (HF_TOKEN or "").strip():
640
  return hf_chat(prompt, timeout=timeout)
641
  return ollama_chat(prompt, timeout=timeout)
642
 
643
+ def is_running_on_spaces() -> bool:
644
+ if os.environ.get("HF_SPACE_ID") or os.environ.get("SPACE_ID"):
645
+ return True
646
+ return (os.environ.get("SYSTEM") or "").strip().lower() == "spaces"
647
+
648
  def github_create_issue(title: str, body: str, labels: Optional[List[str]] = None) -> Tuple[Optional[int], Optional[str]]:
649
  if not GITHUB_TOKEN:
650
  return None, "Missing GITHUB_TOKEN"
 
712
  st.write("")
713
  st.subheader("LLM")
714
  st.markdown(f"- Model: `{HF_MODEL}`")
715
+ st.markdown(f"- Backend: `{(os.environ.get('RAG_LLM_BACKEND') or 'auto')}`")
716
+ st.markdown(f"- HF token set: `{bool(HF_TOKEN)}`")
717
+ if HF_FALLBACK_MODEL:
718
+ st.markdown(f"- HF fallback: `{HF_FALLBACK_MODEL}`")
719
+ if st.button("Test model", key="hf_test", use_container_width=True, disabled=st.session_state["is_thinking"]):
720
+ test_text, test_err = llm_chat("Say OK.")
721
+ if test_err:
722
+ st.error(test_err)
723
+ else:
724
+ st.success(test_text or "OK")
725
  st.write("")
726
  st.subheader("Embedding model (retrieval)")
727
  st.code(EMBED_MODEL)
 
1092
  if ok:
1093
  push_message("assistant", answer, citations=citations, not_found=False)
1094
  else:
1095
+ is_not_found = answer.strip() == "Not found in dataset."
1096
+ push_message("assistant", answer, citations=[], not_found=is_not_found)
1097
+ if is_not_found:
1098
+ st.session_state["ticket_prefill"] = {"question": q_norm, "citations": citations}
1099
  st.session_state["last_question"] = q_norm
1100
  st.session_state["last_citations"] = citations
1101
  st.session_state["last_answer"] = answer