MukulRay commited on
Commit
2a847ac
Β·
1 Parent(s): e4ec4e1

feat: metadata routing, Tavily fallback, retrieval_method in response

Browse files
Files changed (3) hide show
  1. main.py +3 -1
  2. rag.py +149 -30
  3. requirements.txt +1 -0
main.py CHANGED
@@ -53,6 +53,7 @@ class GenerateResponse(BaseModel):
53
  sources: list[str]
54
  latency_ms: float
55
  blocked: bool = False
 
56
 
57
 
58
  @app.get("/")
@@ -83,7 +84,7 @@ def generate(req: GenerateRequest):
83
  )
84
 
85
  start = time.time()
86
- answer, sources = rag_chain.query(req.query, top_k=req.top_k)
87
  latency_ms = (time.time() - start) * 1000
88
 
89
  is_clean, answer = validate_output(answer)
@@ -100,6 +101,7 @@ def generate(req: GenerateRequest):
100
  sources=sources,
101
  latency_ms=round(latency_ms, 1),
102
  blocked=False,
 
103
  )
104
 
105
 
 
53
  sources: list[str]
54
  latency_ms: float
55
  blocked: bool = False
56
+ retrieval_method: str = "rag" # "rag" | "web_fallback" | "guardrail_blocked"
57
 
58
 
59
  @app.get("/")
 
84
  )
85
 
86
  start = time.time()
87
+ answer, sources, retrieval_method = rag_chain.query(req.query, top_k=req.top_k)
88
  latency_ms = (time.time() - start) * 1000
89
 
90
  is_clean, answer = validate_output(answer)
 
101
  sources=sources,
102
  latency_ms=round(latency_ms, 1),
103
  blocked=False,
104
+ retrieval_method=retrieval_method,
105
  )
106
 
107
 
rag.py CHANGED
@@ -23,9 +23,10 @@ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
23
  PINECONE_INDEX = os.getenv("PINECONE_INDEX", "llmops-rag")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
 
26
 
27
  # Minimum Pinecone score to trust corpus β€” below this triggers web fallback
28
- CORPUS_CONFIDENCE_THRESHOLD = 0.35
29
 
30
  PROMPT_TEMPLATE = """You are Akasha β€” the living memory of Teyvat, an omniscient Genshin Impact assistant with the depth of a master theorycafter and the storytelling of a lore scholar.
31
 
@@ -113,6 +114,115 @@ def _extract_subject(query: str) -> str:
113
  return query.strip().title()
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def _build_groq_llm():
117
  from langchain_groq import ChatGroq
118
 
@@ -213,6 +323,7 @@ class RAGChain:
213
 
214
  def _corpus_has_coverage(self, question: str) -> tuple[bool, list]:
215
  """Check if Pinecone has meaningful coverage for this query."""
 
216
  try:
217
  docs_with_scores = self.vectorstore.similarity_search_with_score(
218
  question, k=3
@@ -228,43 +339,51 @@ class RAGChain:
228
  logger.warning(f"Coverage check failed: {e}")
229
  return True, [] # fail open β€” try corpus anyway
230
 
231
- def query(self, question: str, top_k: int = 8) -> tuple[str, list[str]]:
 
 
 
 
232
  if not self.ready:
233
  raise RuntimeError("RAG chain is not loaded.")
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  self.chain.retriever.search_kwargs["k"] = top_k
 
 
 
 
236
 
237
- # Check corpus coverage
238
- has_coverage, _ = self._corpus_has_coverage(question)
239
-
240
- if not has_coverage:
241
- logger.info("Low corpus coverage β€” attempting web fallback")
242
- subject = _extract_subject(question)
243
- web_content = _fetch_wiki_page(subject)
244
-
245
- if web_content:
246
- # Answer from web data using the LLM directly
247
- web_prompt = PromptTemplate(
248
- template=WEB_PROMPT_TEMPLATE,
249
- input_variables=["context", "question"],
250
- )
251
- from langchain_core.output_parsers import StrOutputParser
252
- web_chain = web_prompt | self.llm | StrOutputParser()
253
- try:
254
- answer = web_chain.invoke({
255
- "context": web_content,
256
- "question": question,
257
- })
258
- answer = answer.strip().replace("</s>", "").strip()
259
- return answer, ["web: wiki.gg/game8.co (live)"]
260
- except Exception as e:
261
- logger.warning(f"Web chain failed: {e}")
262
-
263
- # Default: corpus RAG
264
  result = self.chain.invoke({"query": question})
265
  answer = result["result"].strip().replace("</s>", "").strip()
266
  sources = [
267
  doc.metadata.get("source", "unknown")
268
  for doc in result.get("source_documents", [])
269
  ]
270
- return answer, list(dict.fromkeys(sources))
 
23
  PINECONE_INDEX = os.getenv("PINECONE_INDEX", "llmops-rag")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
26
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
27
 
28
  # Minimum Pinecone score to trust corpus β€” below this triggers web fallback
29
+ CORPUS_CONFIDENCE_THRESHOLD = 0.60
30
 
31
  PROMPT_TEMPLATE = """You are Akasha β€” the living memory of Teyvat, an omniscient Genshin Impact assistant with the depth of a master theorycafter and the storytelling of a lore scholar.
32
 
 
114
  return query.strip().title()
115
 
116
 
117
+ def route_query(question: str) -> dict:
118
+ """
119
+ Detect query intent and return a Pinecone metadata filter dict.
120
+ Applied per-query, not at startup.
121
+ """
122
+ q = question.lower()
123
+
124
+ # Build/optimization intent
125
+ build_keywords = ["build", "weapon", "artifact", "bis", "best in slot",
126
+ "team", "rotation", "er threshold", "em", "crit",
127
+ "f2p", "free to play", "comps", "comp"]
128
+ # Lore intent
129
+ lore_keywords = ["lore", "story", "who is", "personality", "history",
130
+ "background", "quest", "backstory", "relationship"]
131
+ # Stats/numbers intent
132
+ stats_keywords = ["stats", "talent", "constellation", "scaling",
133
+ "multiplier", "numbers", "c0", "c1", "c2", "c3",
134
+ "c4", "c5", "c6", "a1", "a4"]
135
+ # Mechanics intent
136
+ mechanics_keywords = ["reaction", "mechanic", "how does", "damage formula",
137
+ "icd", "internal cooldown", "vaporize", "melt",
138
+ "swirl", "freeze", "superconduct", "hyperbloom",
139
+ "burgeon", "quicken", "aggravate", "spread"]
140
+
141
+ # Known Genshin character names for character-specific filter
142
+ # This list covers the major characters β€” not exhaustive
143
+ known_characters = [
144
+ "hu tao", "zhongli", "venti", "kazuha", "raiden", "raiden shogun",
145
+ "bennett", "xingqiu", "yelan", "xiangling", "fischl", "beidou",
146
+ "sucrose", "albedo", "ganyu", "ayaka", "ayato", "itto", "gorou",
147
+ "kokomi", "sara", "yoimiya", "thoma", "shenhe", "yunjin",
148
+ "nahida", "cyno", "tighnari", "collei", "dori", "layla", "faruzan",
149
+ "wanderer", "scaramouche", "alhaitham", "dehya", "mika", "baizhu",
150
+ "kaveh", "nilou", "candace",
151
+ "neuvillette", "furina", "wriothesley", "navia", "charlotte",
152
+ "freminet", "lyney", "lynette", "arlecchino", "clorinde",
153
+ "sigewinne", "emilie", "chevreuse",
154
+ "mualani", "kinich", "kachina", "xilonen", "chasca", "ororon",
155
+ "mavuika", "citlali",
156
+ "lumine", "aether", "paimon",
157
+ "keqing", "diluc", "jean", "qiqi", "mona", "klee", "childe",
158
+ "tartaglia", "eula", "amber", "barbara", "noelle", "razor",
159
+ "lisa", "traveler", "xinyan", "ningguang", "chongyun", "diona",
160
+ "rosaria", "yanfei", "hutao", "sayu", "shogun",
161
+ "yae miko", "yae", "heizou", "shinobu", "tighnari",
162
+ "wanderer", "alhaitham", "baizhu",
163
+ ]
164
+
165
+ filter_dict = {}
166
+
167
+ # Detect character name in query
168
+ detected_character = None
169
+ for char in known_characters:
170
+ if char in q:
171
+ # Normalize to title case for metadata match
172
+ detected_character = char.title()
173
+ break
174
+
175
+ # Determine tier/content_type filter based on intent keywords
176
+ if any(kw in q for kw in build_keywords + mechanics_keywords):
177
+ filter_dict = {"tier": {"$in": ["tcl", "structured"]}}
178
+ elif any(kw in q for kw in lore_keywords):
179
+ filter_dict = {"tier": "wiki"}
180
+ elif any(kw in q for kw in stats_keywords):
181
+ filter_dict = {"content_type": {"$in": ["stats", "ability"]}}
182
+ else:
183
+ filter_dict = {} # ambiguous β€” search all tiers
184
+
185
+ # Add character filter on top if detected
186
+ if detected_character:
187
+ if filter_dict:
188
+ filter_dict["character"] = detected_character
189
+ else:
190
+ filter_dict = {"character": detected_character}
191
+
192
+ logger.info(f"Query routed β€” filter: {filter_dict}")
193
+ return filter_dict
194
+
195
+
196
+ def _tavily_search(question: str) -> tuple[str, str]:
197
+ """
198
+ Call Tavily search API as web fallback.
199
+ Returns (answer_text, source_url).
200
+ Falls back to empty strings if API key not set or call fails.
201
+ """
202
+ if not TAVILY_API_KEY:
203
+ logger.warning("TAVILY_API_KEY not set β€” web fallback unavailable")
204
+ return "", ""
205
+ try:
206
+ from tavily import TavilyClient
207
+ client = TavilyClient(api_key=TAVILY_API_KEY)
208
+ # Scope search to Genshin sources for quality
209
+ response = client.search(
210
+ query=f"Genshin Impact {question}",
211
+ search_depth="basic",
212
+ max_results=3,
213
+ include_answer=True,
214
+ )
215
+ answer = response.get("answer", "")
216
+ # Get top source URL
217
+ results = response.get("results", [])
218
+ source_url = results[0]["url"] if results else "web search"
219
+ logger.info(f"Tavily returned answer length: {len(answer)} chars")
220
+ return answer, source_url
221
+ except Exception as e:
222
+ logger.warning(f"Tavily search failed: {e}")
223
+ return "", ""
224
+
225
+
226
  def _build_groq_llm():
227
  from langchain_groq import ChatGroq
228
 
 
323
 
324
  def _corpus_has_coverage(self, question: str) -> tuple[bool, list]:
325
  """Check if Pinecone has meaningful coverage for this query."""
326
+ # NOTE: not called in query() β€” kept for reference
327
  try:
328
  docs_with_scores = self.vectorstore.similarity_search_with_score(
329
  question, k=3
 
339
  logger.warning(f"Coverage check failed: {e}")
340
  return True, [] # fail open β€” try corpus anyway
341
 
342
+ def query(self, question: str, top_k: int = 8) -> tuple[str, list[str], str]:
343
+ """
344
+ Returns (answer, sources, retrieval_method)
345
+ retrieval_method: "rag" | "web_fallback" | "guardrail_blocked"
346
+ """
347
  if not self.ready:
348
  raise RuntimeError("RAG chain is not loaded.")
349
 
350
+ # Step 1: Route query β€” get metadata filter
351
+ filter_dict = route_query(question)
352
+
353
+ # Step 2: Retrieve with scores to check confidence
354
+ try:
355
+ docs_with_scores = self.vectorstore.similarity_search_with_score(
356
+ question, k=top_k, filter=filter_dict if filter_dict else None
357
+ )
358
+ except Exception as e:
359
+ logger.warning(f"Filtered retrieval failed: {e} β€” retrying without filter")
360
+ docs_with_scores = self.vectorstore.similarity_search_with_score(
361
+ question, k=top_k
362
+ )
363
+
364
+ # Step 3: Check confidence
365
+ max_score = docs_with_scores[0][1] if docs_with_scores else 0.0
366
+ logger.info(f"Top Pinecone score: {max_score:.3f} (threshold: {CORPUS_CONFIDENCE_THRESHOLD})")
367
+
368
+ if max_score < CORPUS_CONFIDENCE_THRESHOLD:
369
+ logger.info(f"Low confidence ({max_score:.2f}) β€” falling back to web search")
370
+ tavily_answer, tavily_source = _tavily_search(question)
371
+ if tavily_answer:
372
+ return tavily_answer, [f"web: {tavily_source}"], "web_fallback"
373
+ else:
374
+ logger.warning("Tavily fallback also failed β€” proceeding with RAG anyway")
375
+
376
+ # Step 4: Apply filter to the chain retriever and run RAG
377
  self.chain.retriever.search_kwargs["k"] = top_k
378
+ if filter_dict:
379
+ self.chain.retriever.search_kwargs["filter"] = filter_dict
380
+ else:
381
+ self.chain.retriever.search_kwargs.pop("filter", None)
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  result = self.chain.invoke({"query": question})
384
  answer = result["result"].strip().replace("</s>", "").strip()
385
  sources = [
386
  doc.metadata.get("source", "unknown")
387
  for doc in result.get("source_documents", [])
388
  ]
389
+ return answer, list(dict.fromkeys(sources)), "rag"
requirements.txt CHANGED
@@ -24,3 +24,4 @@ sentence-transformers==4.1.0
24
  # ── Utilities ──────────────────────────────────────────────────────────────────
25
  python-dotenv==1.0.1
26
  requests>=2.31.0
 
 
24
  # ── Utilities ──────────────────────────────────────────────────────────────────
25
  python-dotenv==1.0.1
26
  requests>=2.31.0
27
+ tavily-python