Oleksii Obolonskyi committed on
Commit
8faa6a7
·
1 Parent(s): 52f5ee4

Use InferenceClient for HF chat

Browse files
Files changed (2) hide show
  1. app.py +19 -27
  2. requirements.txt +1 -0
app.py CHANGED
@@ -14,6 +14,8 @@ import streamlit as st
14
  import numpy as np
15
  import faiss
16
  import requests
 
 
17
  from sentence_transformers import SentenceTransformer
18
 
19
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
@@ -37,9 +39,6 @@ HF_TOKEN = (
37
  or ""
38
  ).strip()
39
  HF_MODEL = os.environ.get("RAG_HF_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
40
- HF_API_URL = os.environ.get("RAG_HF_API_URL", "").strip()
41
- if not HF_API_URL:
42
- HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
43
 
44
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
45
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
@@ -49,14 +48,14 @@ REPO_NAME = "RAG_project"
49
 
50
  GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "").strip()
51
  # Retrieval mix: book-first + article nuance.
52
- BOOK_K = 12
53
  ARTICLE_K = 4
54
  PER_DOC_CAP = 3
55
  OVERLAP_FILTER = True
56
 
57
  # Enhanced answer mix: heavier retrieval for deeper answers.
58
- ENHANCED_BOOK_K = 12
59
- ENHANCED_ARTICLE_K = 5
60
 
61
  AVOID_PHRASES = [
62
  "The article discusses",
@@ -523,31 +522,24 @@ def build_hf_prompt(user_prompt: str, model_id: str) -> str:
523
  )
524
  return f"System: {system_msg}\nUser: {user_prompt}\nAssistant:"
525
 
 
 
 
 
526
  def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
527
  if not HF_TOKEN:
528
  return "", "Missing HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)"
529
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
530
- payload = {
531
- "inputs": build_hf_prompt(prompt, HF_MODEL),
532
- "parameters": {
533
- "temperature": 0.2,
534
- "max_new_tokens": 512,
535
- "return_full_text": False,
536
- },
537
- "options": {"wait_for_model": True},
538
- }
539
  try:
540
- r = requests.post(HF_API_URL, headers=headers, json=payload, timeout=timeout)
541
- r.raise_for_status()
542
- data = r.json()
543
- msg = ""
544
- if isinstance(data, list) and data:
545
- msg = data[0].get("generated_text") or ""
546
- elif isinstance(data, dict):
547
- if "error" in data:
548
- return "", str(data.get("error") or "Hugging Face error")
549
- msg = data.get("generated_text") or ""
550
- return (msg or "").strip(), None
551
  except Exception as e:
552
  return "", str(e)
553
 
 
14
  import numpy as np
15
  import faiss
16
  import requests
17
+ from huggingface_hub import InferenceClient
18
+ from huggingface_hub import InferenceClient
19
  from sentence_transformers import SentenceTransformer
20
 
21
  load_dotenv(Path(__file__).resolve().parent / ".env", override=True)
 
39
  or ""
40
  ).strip()
41
  HF_MODEL = os.environ.get("RAG_HF_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
 
 
 
42
 
43
  OLLAMA_BASE_URL = os.environ.get("RAG_OLLAMA_URL", "http://localhost:11434").rstrip("/")
44
  OLLAMA_MODEL = os.environ.get("RAG_OLLAMA_MODEL", "llama3.2:1b")
 
48
 
49
  GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "").strip()
50
  # Retrieval mix: book-first + article nuance.
51
+ BOOK_K = 8
52
  ARTICLE_K = 4
53
  PER_DOC_CAP = 3
54
  OVERLAP_FILTER = True
55
 
56
  # Enhanced answer mix: heavier retrieval for deeper answers.
57
+ ENHANCED_BOOK_K = 14
58
+ ENHANCED_ARTICLE_K = 7
59
 
60
  AVOID_PHRASES = [
61
  "The article discusses",
 
522
  )
523
  return f"System: {system_msg}\nUser: {user_prompt}\nAssistant:"
524
 
525
+ @st.cache_resource(show_spinner=False)
526
+ def get_hf_client() -> InferenceClient:
527
+ return InferenceClient(model=HF_MODEL, token=HF_TOKEN)
528
+
529
  def hf_chat(prompt: str, timeout: Tuple[int, int] = (10, 600)) -> Tuple[str, Optional[str]]:
530
  if not HF_TOKEN:
531
  return "", "Missing HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)"
 
 
 
 
 
 
 
 
 
 
532
  try:
533
+ client = get_hf_client()
534
+ inp = build_hf_prompt(prompt, HF_MODEL)
535
+ out = client.text_generation(
536
+ inp,
537
+ max_new_tokens=512,
538
+ temperature=0.2,
539
+ do_sample=True,
540
+ return_full_text=False,
541
+ )
542
+ return (out or "").strip(), None
 
543
  except Exception as e:
544
  return "", str(e)
545
 
requirements.txt CHANGED
@@ -5,6 +5,7 @@ requests>=2.31.0
5
  python-dateutil>=2.9.0
6
  tqdm>=4.66.0
7
  python-dotenv>=1.0.0
 
8
 
9
  # -------------------------
10
  # PDF processing (required by normalize_all.py)
 
5
  python-dateutil>=2.9.0
6
  tqdm>=4.66.0
7
  python-dotenv>=1.0.0
8
+ huggingface-hub>=0.30.0
9
 
10
  # -------------------------
11
  # PDF processing (required by normalize_all.py)