Oleksii Obolonskyi committed on
Commit
8aac1c0
·
1 Parent(s): 45772d2

Use HF Router OpenAI client

Browse files
Files changed (2) hide show
  1. app.py +33 -100
  2. requirements.txt +1 -1
app.py CHANGED
@@ -14,7 +14,7 @@ import streamlit as st
14
  import numpy as np
15
  import faiss
16
  import requests
17
- from huggingface_hub import InferenceClient
18
  from sentence_transformers import SentenceTransformer
19
 
20
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
@@ -83,15 +83,7 @@ OVERLAP_FILTER = CONFIG.overlap_filter
83
  RETRIEVE_TOPK_MULT = CONFIG.retrieve_topk_mult
84
 
85
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
86
- HF_PROVIDER = os.getenv("RAG_HF_PROVIDER", "hf-inference").strip() or "hf-inference"
87
- HF_MODEL_PRIMARY = os.getenv("RAG_HF_MODEL", os.getenv("RAG_HF_MODEL_PRIMARY", "HuggingFaceTB/SmolLM3-3B")).strip()
88
- HF_MODEL_FALLBACKS_RAW = os.getenv("RAG_HF_MODEL_FALLBACKS", "").strip()
89
- HF_MODEL_FALLBACKS = (
90
- [m.strip() for m in HF_MODEL_FALLBACKS_RAW.split(",") if m.strip()]
91
- if HF_MODEL_FALLBACKS_RAW
92
- else ["HuggingFaceTB/SmolLM3-3B", "HuggingFaceTB/SmolLM2-1.7B", "HuggingFaceTB/SmolLM2-360M"]
93
- )
94
- HF_MODEL_CANDIDATES = [HF_MODEL_PRIMARY] + [m for m in HF_MODEL_FALLBACKS if m != HF_MODEL_PRIMARY]
95
 
96
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
97
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
@@ -623,82 +615,30 @@ def is_running_on_spaces() -> bool:
623
  return True
624
  return (os.environ.get("SYSTEM") or "").strip().lower() == "spaces"
625
 
626
- def get_hf_client(model_id: str) -> InferenceClient:
627
- return InferenceClient(model=model_id, provider=HF_PROVIDER, token=HF_TOKEN)
628
-
629
- def select_active_hf_model() -> str:
630
- if st.session_state.get("hf_active_model"):
631
- return st.session_state["hf_active_model"]
632
- last_err = ""
633
- for model_id in HF_MODEL_CANDIDATES:
634
- try:
635
- client = get_hf_client(model_id)
636
- client.text_generation(
637
- "ping",
638
- max_new_tokens=2,
639
- temperature=0.0,
640
- do_sample=False,
641
- return_full_text=False,
642
- )
643
- st.session_state["hf_active_model"] = model_id
644
- st.session_state.pop("hf_startup_error", None)
645
- return model_id
646
- except Exception as exc:
647
- last_err = str(exc)
648
- st.session_state["hf_active_model"] = HF_MODEL_PRIMARY
649
- if last_err:
650
- st.session_state["hf_startup_error"] = last_err
651
- return HF_MODEL_PRIMARY
652
-
653
- class LLMClient:
654
- def __init__(self, backend: str) -> None:
655
- self.backend = backend
656
-
657
- def generate(self, prompt: str) -> Tuple[str, Optional[str]]:
658
- if self.backend == "ollama":
659
- return ollama_chat(prompt)
660
- return self._hf_generate(prompt)
661
-
662
- def _hf_generate(self, prompt: str) -> Tuple[str, Optional[str]]:
663
- model_id = select_active_hf_model()
664
- client = get_hf_client(model_id)
665
- messages = [
666
- {"role": "system", "content": system_message()},
667
- {"role": "user", "content": prompt},
668
- ]
669
- try:
670
- chat_api = getattr(getattr(client, "chat", None), "completions", None)
671
- create_fn = getattr(chat_api, "create", None)
672
- if create_fn:
673
- resp = create_fn(
674
- model=model_id,
675
- messages=messages,
676
- max_tokens=MAX_GENERATION_TOKENS,
677
- temperature=0.2,
678
- )
679
- text = (resp.choices[0].message.content or "").strip()
680
- return text, None
681
- except Exception as exc:
682
- chat_err = str(exc)
683
- else:
684
- chat_err = ""
685
-
686
- try:
687
- out = client.text_generation(
688
- prompt,
689
- max_new_tokens=MAX_GENERATION_TOKENS,
690
- temperature=0.2,
691
- do_sample=True,
692
- return_full_text=False,
693
- )
694
- return (out or "").strip(), None
695
- except Exception as exc:
696
- err_msg = str(exc) or chat_err
697
- hint = f"HF model: {model_id}; provider: {HF_PROVIDER}."
698
- err_low = err_msg.lower()
699
- if any(k in err_low for k in ["401", "403", "gated", "license", "not authorized", "forbidden"]):
700
- hint += " This model is gated. Ensure HF_TOKEN has accepted the license."
701
- return "", f"{err_msg} ({hint})"
702
 
703
  def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
704
  url = f"{OLLAMA_BASE_URL}/api/chat"
@@ -728,15 +668,15 @@ def llm_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Op
728
  """
729
  backend = (os.environ.get("RAG_LLM_BACKEND", "") or "").strip().lower()
730
 
731
- if backend == "hf":
732
- return LLMClient("hf").generate(prompt)
733
  if backend == "ollama":
734
- return LLMClient("ollama").generate(prompt)
735
  if is_running_on_spaces():
736
- return LLMClient("hf").generate(prompt)
737
  if (HF_TOKEN or "").strip():
738
- return LLMClient("hf").generate(prompt)
739
- return LLMClient("ollama").generate(prompt)
740
 
741
  def github_create_issue(title: str, body: str, labels: Optional[List[str]] = None) -> Tuple[Optional[int], Optional[str]]:
742
  if not GITHUB_TOKEN:
@@ -804,14 +744,7 @@ with st.sidebar:
804
  st.session_state["open_ticket_ui"] = True
805
  st.write("")
806
  st.subheader("LLM")
807
- backend = os.getenv("RAG_LLM_BACKEND", "auto").strip().lower()
808
- use_hf = backend == "hf" or (
809
- backend == "auto" and (is_running_on_spaces() or (HF_TOKEN or "").strip())
810
- )
811
- active_model = select_active_hf_model() if use_hf else HF_MODEL_PRIMARY
812
- st.markdown(f"- Active model: `{active_model}`")
813
- if use_hf and st.session_state.get("hf_startup_error"):
814
- st.warning("HF model not available; check token/provider/model list.")
815
  st.write("")
816
  st.subheader("Embedding model (retrieval)")
817
  st.code(EMBED_MODEL)
 
14
  import numpy as np
15
  import faiss
16
  import requests
17
+ from openai import OpenAI
18
  from sentence_transformers import SentenceTransformer
19
 
20
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
 
83
  RETRIEVE_TOPK_MULT = CONFIG.retrieve_topk_mult
84
 
85
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
86
+ HF_MODEL = os.getenv("RAG_HF_MODEL", "Qwen/Qwen2.5-7B-Instruct-1M:featherless-ai").strip()
 
 
 
 
 
 
 
 
87
 
88
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
89
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
 
615
  return True
616
  return (os.environ.get("SYSTEM") or "").strip().lower() == "spaces"
617
 
618
@st.cache_resource(show_spinner=False)
def get_hf_router_client() -> OpenAI:
    """Return a process-wide OpenAI client pointed at the HF Inference Router.

    The client is built once and cached across Streamlit reruns via
    ``st.cache_resource``. It talks to Hugging Face's OpenAI-compatible
    router endpoint, authenticated with the module-level ``HF_TOKEN``.
    """
    router_base = "https://router.huggingface.co/v1"
    client = OpenAI(base_url=router_base, api_key=HF_TOKEN)
    return client
624
+
625
def hf_chat(prompt: str) -> Tuple[str, Optional[str]]:
    """Generate a chat completion for *prompt* via the HF Router backend.

    Returns ``(text, None)`` on success, or ``("", error_message)`` on any
    failure (missing token, API/network error), so callers never need to
    catch exceptions themselves.
    """
    if not HF_TOKEN:
        return "", "Missing HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)"
    try:
        client = get_hf_router_client()
        completion = client.chat.completions.create(
            model=HF_MODEL,
            messages=[
                # Use the app's RAG system prompt — the previous HF backend
                # (removed in this commit) sent system_message() here; a
                # generic "helpful assistant" prompt drops the grounding
                # instructions the retrieval pipeline relies on.
                {"role": "system", "content": system_message()},
                {"role": "user", "content": prompt},
            ],
            max_tokens=MAX_GENERATION_TOKENS,
            temperature=0.2,
        )
        # content may be None for empty/refused completions; normalize to "".
        return (completion.choices[0].message.content or "").strip(), None
    except Exception as e:
        # Surface any router/auth/network failure as an error string rather
        # than crashing the Streamlit app.
        return "", str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
 
643
  def ollama_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
644
  url = f"{OLLAMA_BASE_URL}/api/chat"
 
668
  """
669
  backend = (os.environ.get("RAG_LLM_BACKEND", "") or "").strip().lower()
670
 
671
+ if backend == "hf-router":
672
+ return hf_chat(prompt)
673
  if backend == "ollama":
674
+ return ollama_chat(prompt)
675
  if is_running_on_spaces():
676
+ return hf_chat(prompt)
677
  if (HF_TOKEN or "").strip():
678
+ return hf_chat(prompt)
679
+ return ollama_chat(prompt)
680
 
681
  def github_create_issue(title: str, body: str, labels: Optional[List[str]] = None) -> Tuple[Optional[int], Optional[str]]:
682
  if not GITHUB_TOKEN:
 
744
  st.session_state["open_ticket_ui"] = True
745
  st.write("")
746
  st.subheader("LLM")
747
+ st.markdown(f"- Active model: `{HF_MODEL}`")
 
 
 
 
 
 
 
748
  st.write("")
749
  st.subheader("Embedding model (retrieval)")
750
  st.code(EMBED_MODEL)
requirements.txt CHANGED
@@ -3,7 +3,7 @@
3
  # -------------------------
4
  requests>=2.31.0
5
  python-dotenv>=1.0.0
6
- huggingface_hub>=0.24.0
7
  numpy>=1.24.0
8
  faiss-cpu>=1.8.0
9
  sentence-transformers>=2.6.0
 
3
  # -------------------------
4
  requests>=2.31.0
5
  python-dotenv>=1.0.0
6
+ openai>=1.0.0
7
  numpy>=1.24.0
8
  faiss-cpu>=1.8.0
9
  sentence-transformers>=2.6.0