tharunchndrn committed on
Commit
4acbfb3
·
verified ·
1 Parent(s): 84c9ad6

Update backend_app/rag_hf.py

Browse files
Files changed (1) hide show
  1. backend_app/rag_hf.py +92 -57
backend_app/rag_hf.py CHANGED
@@ -13,21 +13,28 @@ from .config import (
13
  EMBED_MODEL_NAME,
14
  MIN_TOP_SCORE,
15
  WEB_MAX_RESULTS,
 
 
16
  )
17
  from .fetcher import fetch_page_text
18
  from .web_search import web_search
19
 
20
- HF_TOKEN = os.getenv("HF_TOKEN", "")
21
- HF_MODEL = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta") # you can change later
22
 
23
  class RAGEngineHF:
24
  def __init__(self):
25
  self.embedder = SentenceTransformer(EMBED_MODEL_NAME)
 
 
26
  self.index = faiss.read_index(FAISS_INDEX_PATH)
27
  with open(DOCSTORE_PATH, "rb") as f:
28
  self.docs: List[Dict] = pickle.load(f)
29
 
30
- self.client = InferenceClient(model=HF_MODEL, token=HF_TOKEN)
 
 
 
 
 
31
  self.TOP_K = 5
32
  self.MAX_CONTEXT_CHARS_PER_DOC = 1800
33
 
@@ -36,47 +43,69 @@ class RAGEngineHF:
36
  q_emb = np.array(q_emb, dtype="float32")
37
  scores, ids = self.index.search(q_emb, k)
38
 
39
- out = []
40
  for rank, doc_id in enumerate(ids[0]):
41
  if doc_id == -1:
42
  continue
43
  d = self.docs[int(doc_id)]
44
- out.append({
45
- "rank": rank + 1,
46
- "score": float(scores[0][rank]),
47
- "text": d["text"],
48
- "meta": d["meta"],
49
- })
 
 
50
  return out
51
 
52
  def _needs_web_fallback(self, contexts: List[Dict]) -> bool:
53
  return (not contexts) or (contexts[0]["score"] < MIN_TOP_SCORE)
54
 
55
  def fetch_web_context(self, query: str) -> Tuple[List[Dict], List[Dict]]:
 
 
 
 
56
  queries = [f"site:foodsystemsdashboard.org {query}", query]
57
- links, seen = [], set()
 
58
 
59
  for q in queries:
60
  for r in web_search(q, max_results=WEB_MAX_RESULTS):
61
- if r["url"] not in seen:
 
62
  links.append(r)
63
- seen.add(r["url"])
64
  if len(links) >= WEB_MAX_RESULTS:
65
  break
66
 
67
- contexts, sources = [], []
 
 
68
  for r in links[:WEB_MAX_RESULTS]:
69
  try:
70
  page = fetch_page_text(r["url"], use_cache=True)
71
- contexts.append({
72
- "rank": len(contexts) + 1,
73
- "score": 0.0,
74
- "text": page["text"],
75
- "meta": {"url": page["url"], "title": page["title"], "chunk": 0},
76
- })
77
- sources.append({"title": page["title"], "url": page["url"]})
78
- except:
 
 
 
 
 
 
 
 
 
 
 
79
  continue
 
80
  return contexts, sources
81
 
82
  def answer(self, query: str, preferred_lang: Optional[str] = None) -> Dict:
@@ -85,6 +114,7 @@ class RAGEngineHF:
85
  contexts = local
86
  sources = self._unique_sources(local)
87
 
 
88
  if self._needs_web_fallback(local):
89
  web_ctx, web_src = self.fetch_web_context(query)
90
  if web_ctx:
@@ -93,8 +123,10 @@ class RAGEngineHF:
93
  sources = web_src
94
 
95
  context_block = "\n\n".join(
96
- [f"[{i+1}] {c['meta']['title']}\n{c['text'][:self.MAX_CONTEXT_CHARS_PER_DOC]}"
97
- for i, c in enumerate(contexts)]
 
 
98
  )
99
 
100
  lang_line = f"Respond in {preferred_lang}.\n" if preferred_lang else ""
@@ -102,56 +134,59 @@ class RAGEngineHF:
102
  prompt = f"""
103
  You are the SysLink Food System assistant.
104
 
105
- You MUST answer using ONLY the information provided in the retrieved context.
 
106
 
107
  Write in simple, clear language.
108
  Keep responses MEDIUM length (8–14 lines).
109
- Avoid being too brief or too detailed.
110
-
111
- Your focus areas are:
112
- - Food systems and agri-food value chains
113
- - Farmers, markets, logistics, and distribution
114
- - Sustainability and food security
115
- - Policy, programs, and institutional support
116
-
117
- Rules:
118
- - Be factual, neutral, and helpful
119
- - Avoid technical jargon unless it appears in the context
120
- - Do not include opinions or speculation
121
- - Do not summarize irrelevant information
122
- - Do not mention the word “context” or “retrieved documents” in the final answer
123
 
124
  {lang_line}
125
  QUESTION: {query}
126
 
127
- CONTEXT:
128
  {context_block}
129
 
130
  ANSWER:
131
- """.strip()
132
- messages = [
133
- {"role": "system", "content": "You are the SysLink Food System assistant."},
134
- {"role": "user", "content": prompt},
135
- ]
136
-
137
- resp = self.client.chat_completion(
138
- messages=messages,
139
- max_tokens=250,
140
- temperature=0.2,
141
- )
142
-
143
- out = resp.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  if not out:
146
- out = "I couldn’t find enough reliable information in the provided sources. Please rephrase or share more details."
147
 
148
  return {"answer": out, "sources": sources, "used": used}
149
 
150
  def _unique_sources(self, contexts: List[Dict]) -> List[Dict]:
151
  seen, out = set(), []
152
  for c in contexts:
153
- u = c["meta"]["url"]
154
- if u not in seen:
155
- out.append({"title": c["meta"]["title"], "url": u})
 
156
  seen.add(u)
157
  return out
 
13
  EMBED_MODEL_NAME,
14
  MIN_TOP_SCORE,
15
  WEB_MAX_RESULTS,
16
+ HF_TOKEN,
17
+ HF_MODEL,
18
  )
19
  from .fetcher import fetch_page_text
20
  from .web_search import web_search
21
 
 
 
22
 
23
  class RAGEngineHF:
24
  def __init__(self):
25
  self.embedder = SentenceTransformer(EMBED_MODEL_NAME)
26
+
27
+ # Load FAISS index + docs
28
  self.index = faiss.read_index(FAISS_INDEX_PATH)
29
  with open(DOCSTORE_PATH, "rb") as f:
30
  self.docs: List[Dict] = pickle.load(f)
31
 
32
+ # Prefer config values; give safe default model if empty
33
+ model_name = (HF_MODEL or "google/gemma-2-2b-it").strip()
34
+ token = (HF_TOKEN or "").strip()
35
+
36
+ self.client = InferenceClient(model=model_name, token=token)
37
+
38
  self.TOP_K = 5
39
  self.MAX_CONTEXT_CHARS_PER_DOC = 1800
40
 
 
43
  q_emb = np.array(q_emb, dtype="float32")
44
  scores, ids = self.index.search(q_emb, k)
45
 
46
+ out: List[Dict] = []
47
  for rank, doc_id in enumerate(ids[0]):
48
  if doc_id == -1:
49
  continue
50
  d = self.docs[int(doc_id)]
51
+ out.append(
52
+ {
53
+ "rank": rank + 1,
54
+ "score": float(scores[0][rank]),
55
+ "text": d.get("text", ""),
56
+ "meta": d.get("meta", {}),
57
+ }
58
+ )
59
  return out
60
 
61
  def _needs_web_fallback(self, contexts: List[Dict]) -> bool:
62
  return (not contexts) or (contexts[0]["score"] < MIN_TOP_SCORE)
63
 
64
  def fetch_web_context(self, query: str) -> Tuple[List[Dict], List[Dict]]:
65
+ """
66
+ Optional fallback: uses web_search() -> fetch_page_text().
67
+ web_search() should return [] when rate-limited, so this won't crash.
68
+ """
69
  queries = [f"site:foodsystemsdashboard.org {query}", query]
70
+ links: List[Dict] = []
71
+ seen = set()
72
 
73
  for q in queries:
74
  for r in web_search(q, max_results=WEB_MAX_RESULTS):
75
+ url = r.get("url")
76
+ if url and url not in seen:
77
  links.append(r)
78
+ seen.add(url)
79
  if len(links) >= WEB_MAX_RESULTS:
80
  break
81
 
82
+ contexts: List[Dict] = []
83
+ sources: List[Dict] = []
84
+
85
  for r in links[:WEB_MAX_RESULTS]:
86
  try:
87
  page = fetch_page_text(r["url"], use_cache=True)
88
+ contexts.append(
89
+ {
90
+ "rank": len(contexts) + 1,
91
+ "score": 0.0,
92
+ "text": page.get("text", ""),
93
+ "meta": {
94
+ "url": page.get("url", r["url"]),
95
+ "title": page.get("title", r.get("title", "Source")),
96
+ "chunk": 0,
97
+ },
98
+ }
99
+ )
100
+ sources.append(
101
+ {
102
+ "title": page.get("title", r.get("title", "Source")),
103
+ "url": page.get("url", r["url"]),
104
+ }
105
+ )
106
+ except Exception:
107
  continue
108
+
109
  return contexts, sources
110
 
111
  def answer(self, query: str, preferred_lang: Optional[str] = None) -> Dict:
 
114
  contexts = local
115
  sources = self._unique_sources(local)
116
 
117
+ # Web fallback only if local seems weak
118
  if self._needs_web_fallback(local):
119
  web_ctx, web_src = self.fetch_web_context(query)
120
  if web_ctx:
 
123
  sources = web_src
124
 
125
  context_block = "\n\n".join(
126
+ [
127
+ f"[{i+1}] {c.get('meta', {}).get('title', 'Source')}\n{c.get('text', '')[:self.MAX_CONTEXT_CHARS_PER_DOC]}"
128
+ for i, c in enumerate(contexts)
129
+ ]
130
  )
131
 
132
  lang_line = f"Respond in {preferred_lang}.\n" if preferred_lang else ""
 
134
  prompt = f"""
135
  You are the SysLink Food System assistant.
136
 
137
+ You MUST answer using ONLY the information provided below.
138
+ Do NOT invent facts.
139
 
140
  Write in simple, clear language.
141
  Keep responses MEDIUM length (8–14 lines).
142
+ If information is missing, say what is missing.
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  {lang_line}
145
  QUESTION: {query}
146
 
147
+ INFORMATION:
148
  {context_block}
149
 
150
  ANSWER:
151
+ """.strip()
152
+
153
+ # If token missing, we can still try public inference,
154
+ # but failures are common; return a helpful message.
155
+ token = (HF_TOKEN or "").strip()
156
+ if not token:
157
+ return {
158
+ "answer": "I’m running without an HF_TOKEN right now, so the AI response may fail. Please add HF_TOKEN in Space Settings → Secrets, then retry.",
159
+ "sources": sources,
160
+ "used": used,
161
+ }
162
+
163
+ # Try chat completion (works for conversational providers)
164
+ try:
165
+ messages = [
166
+ {"role": "system", "content": "You are the SysLink Food System assistant."},
167
+ {"role": "user", "content": prompt},
168
+ ]
169
+ resp = self.client.chat_completion(
170
+ messages=messages,
171
+ max_tokens=250,
172
+ temperature=0.2,
173
+ )
174
+ out = (resp.choices[0].message.content or "").strip()
175
+ except Exception as e:
176
+ # Fallback: return a visible error message (so you can debug)
177
+ out = f"Model error: {str(e)}"
178
 
179
  if not out:
180
+ out = "I couldn’t generate an answer right now. Please try again."
181
 
182
  return {"answer": out, "sources": sources, "used": used}
183
 
184
  def _unique_sources(self, contexts: List[Dict]) -> List[Dict]:
185
  seen, out = set(), []
186
  for c in contexts:
187
+ meta = c.get("meta", {})
188
+ u = meta.get("url")
189
+ if u and u not in seen:
190
+ out.append({"title": meta.get("title", "Source"), "url": u})
191
  seen.add(u)
192
  return out