fikri0o0 committed on
Commit
2dc9a8a
·
verified ·
1 Parent(s): d6223d1

Add query rewriting + corrective RAG + 3-stage RAGAS ablation

Browse files
Files changed (1) hide show
  1. rag_chain.py +111 -25
rag_chain.py CHANGED
@@ -17,6 +17,19 @@ from config import (
17
  CHUNK_SIZE, CHUNK_OVERLAP, DEVICE, PROVIDER_KEYS,
18
  USE_HYBRID_SEARCH, MAX_HISTORY_TURNS,
19
  USE_RERANKER, RERANKER_MODEL, RETRIEVAL_FETCH_K, RRF_K,
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
 
22
  SYSTEM_PROMPT = (
@@ -197,10 +210,41 @@ def _rerank(
197
  return [docs[i] for i in order], [float(probs[i]) for i in order]
198
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def retrieve_docs(
201
  input_text: str, philosopher: str = "All"
202
  ) -> tuple[list[Document], list[float]]:
203
- """Two-stage retrieval: hybrid (RRF) candidate pool → cross-encoder rerank.
204
 
205
  Returns (docs, scores). With reranking on, scores are cross-encoder
206
  relevance ∈ [0, 1]; in the fallback path, semantic cosine relevance,
@@ -208,30 +252,38 @@ def retrieve_docs(
208
  """
209
  vectorstore = _get_vectorstore()
210
  fetch_k = RETRIEVAL_FETCH_K if USE_RERANKER else RETRIEVAL_K
211
- search_kwargs: dict = {"k": fetch_k}
212
- if philosopher != "All":
213
- search_kwargs["filter"] = {"philosopher": philosopher}
214
-
215
- with warnings.catch_warnings():
216
- warnings.filterwarnings("ignore", message="Relevance scores must be between")
217
- semantic_pairs = vectorstore.similarity_search_with_relevance_scores(
218
- input_text, **search_kwargs
219
- )
220
- semantic_docs = [d for d, _ in semantic_pairs]
221
- sem_score = {d.page_content: s for d, s in semantic_pairs}
222
 
223
- bm25_docs: list[Document] = []
224
- if USE_HYBRID_SEARCH and philosopher == "All":
 
225
  try:
226
- bm25_docs = _get_bm25_retriever().invoke(input_text)
227
  except Exception:
228
- bm25_docs = []
229
-
230
- # Stage 1 — fuse the two ranked lists into one candidate pool.
231
- fused = _reciprocal_rank_fusion([semantic_docs, bm25_docs])
232
- pool = [d for d, _ in fused][:fetch_k] or semantic_docs[:fetch_k]
233
-
234
- # Stage 2 — cross-encoder rerank.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  if USE_RERANKER and pool:
236
  try:
237
  return _rerank(input_text, pool, RETRIEVAL_K)
@@ -249,6 +301,37 @@ def retrieve_docs(
249
  return docs, scores
250
 
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  # ---------------------------------------------------------------------------
253
  # LLM calls — non-streaming
254
  # ---------------------------------------------------------------------------
@@ -400,12 +483,15 @@ def stream_llm(
400
  def query(
401
  input_text: str, philosopher: str = "All", llm_label: str = DEFAULT_LLM
402
  ) -> dict:
403
- """Non-streaming query. Returns answer + context + scores."""
404
  provider, model_id = LLM_OPTIONS.get(llm_label, LLM_OPTIONS[DEFAULT_LLM])
405
- docs, scores = retrieve_docs(input_text, philosopher)
 
 
 
406
  context_str = "\n\n".join(d.page_content for d in docs)
407
  answer = _call_llm(provider, model_id, context_str, input_text)
408
- return {"answer": answer, "context": docs, "scores": scores}
409
 
410
 
411
  # ---------------------------------------------------------------------------
 
17
  CHUNK_SIZE, CHUNK_OVERLAP, DEVICE, PROVIDER_KEYS,
18
  USE_HYBRID_SEARCH, MAX_HISTORY_TURNS,
19
  USE_RERANKER, RERANKER_MODEL, RETRIEVAL_FETCH_K, RRF_K,
20
+ USE_QUERY_REWRITE, QUERY_REWRITE_MODEL, N_QUERY_VARIANTS,
21
+ USE_CORRECTIVE_RAG, CRAG_ABSTAIN_THRESHOLD,
22
+ )
23
+
24
+ # Google's OpenAI-compatible endpoint (httpx). Used for query rewriting so it
25
+ # never touches the grpc google.genai client (which segfaults beside torch).
26
+ GOOGLE_OPENAI_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
27
+
28
+ ABSTAIN_MESSAGE = (
29
+ "I don't have enough grounded context in the knowledge base to answer that "
30
+ "confidently. My sources are 12 Western philosophy texts (Nietzsche, Plato, "
31
+ "Kant, Hume, Schopenhauer, Mill, Marcus Aurelius, Epictetus, Russell) β€” try "
32
+ "rephrasing, or ask about themes from those works."
33
  )
34
 
35
  SYSTEM_PROMPT = (
 
210
  return [docs[i] for i in order], [float(probs[i]) for i in order]
211
 
212
 
213
+ @lru_cache(maxsize=256)
214
+ def _rewrite_query(question: str) -> tuple[str, ...]:
215
+ """Multi-query expansion: original question + LLM-generated paraphrases.
216
+
217
+ Cached so repeated/identical questions don't re-call the LLM. Uses the
218
+ OpenAI-compatible endpoint (httpx) to stay off the grpc google.genai client.
219
+ """
220
+ n = max(1, N_QUERY_VARIANTS - 1)
221
+ from openai import OpenAI
222
+ client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GOOGLE_OPENAI_BASE)
223
+ prompt = (
224
+ "You rewrite search queries for a Western-philosophy retrieval system. "
225
+ f"Generate {n} alternative phrasings of the question that would help "
226
+ "retrieve relevant passages β€” vary wording, add synonyms and related "
227
+ "concepts, name the likely philosopher/work. One per line, no numbering, "
228
+ "no preamble.\n\nQuestion: " + question
229
+ )
230
+ resp = client.chat.completions.create(
231
+ model=QUERY_REWRITE_MODEL,
232
+ messages=[{"role": "user", "content": prompt}],
233
+ temperature=0.5,
234
+ max_tokens=200,
235
+ )
236
+ variants = [
237
+ ln.strip(" -β€’\t").strip()
238
+ for ln in (resp.choices[0].message.content or "").splitlines()
239
+ if ln.strip()
240
+ ]
241
+ return tuple([question] + [v for v in variants if v][:n])
242
+
243
+
244
  def retrieve_docs(
245
  input_text: str, philosopher: str = "All"
246
  ) -> tuple[list[Document], list[float]]:
247
+ """Multi-query → hybrid (RRF) candidate pool → cross-encoder rerank.
248
 
249
  Returns (docs, scores). With reranking on, scores are cross-encoder
250
  relevance ∈ [0, 1]; in the fallback path, semantic cosine relevance,
 
252
  """
253
  vectorstore = _get_vectorstore()
254
  fetch_k = RETRIEVAL_FETCH_K if USE_RERANKER else RETRIEVAL_K
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ # Query rewriting (multi-query). Only when not filtering to one philosopher.
257
+ queries = [input_text]
258
+ if USE_QUERY_REWRITE and philosopher == "All":
259
  try:
260
+ queries = list(_rewrite_query(input_text))
261
  except Exception:
262
+ queries = [input_text]
263
+
264
+ ranked_lists: list[list[Document]] = []
265
+ sem_score: dict[str, float] = {}
266
+ for q in queries:
267
+ search_kwargs: dict = {"k": fetch_k}
268
+ if philosopher != "All":
269
+ search_kwargs["filter"] = {"philosopher": philosopher}
270
+ with warnings.catch_warnings():
271
+ warnings.filterwarnings("ignore", message="Relevance scores must be between")
272
+ pairs = vectorstore.similarity_search_with_relevance_scores(q, **search_kwargs)
273
+ ranked_lists.append([d for d, _ in pairs])
274
+ for d, s in pairs:
275
+ sem_score.setdefault(d.page_content, s)
276
+ if USE_HYBRID_SEARCH and philosopher == "All":
277
+ try:
278
+ ranked_lists.append(_get_bm25_retriever().invoke(q))
279
+ except Exception:
280
+ pass
281
+
282
+ # Stage 1 — fuse all ranked lists (across query variants) into one pool.
283
+ fused = _reciprocal_rank_fusion(ranked_lists)
284
+ pool = [d for d, _ in fused][:fetch_k] or (ranked_lists[0][:fetch_k] if ranked_lists else [])
285
+
286
+ # Stage 2 — cross-encoder rerank against the ORIGINAL question.
287
  if USE_RERANKER and pool:
288
  try:
289
  return _rerank(input_text, pool, RETRIEVAL_K)
 
301
  return docs, scores
302
 
303
 
304
+ def retrieve_corrective(
305
+ input_text: str, philosopher: str = "All"
306
+ ) -> tuple[list[Document], list[float], str]:
307
+ """retrieve_docs + a confidence label from the top semantic-cosine score.
308
+
309
+ Returns (docs, scores, confidence) where confidence is "ok" or "low".
310
+ "low" means the best retrieved chunk is below CRAG_ABSTAIN_THRESHOLD — the
311
+ caller should abstain rather than answer from weak context.
312
+ """
313
+ docs, scores = retrieve_docs(input_text, philosopher)
314
+ confidence = "ok"
315
+ if USE_CORRECTIVE_RAG:
316
+ # Abstain gate on semantic cosine (cleanly separates off-corpus queries;
317
+ # the reranker sigmoid hovers ~0.5 for both relevant and irrelevant).
318
+ search_kwargs: dict = {"k": 3}
319
+ if philosopher != "All":
320
+ search_kwargs["filter"] = {"philosopher": philosopher}
321
+ try:
322
+ with warnings.catch_warnings():
323
+ warnings.filterwarnings("ignore", message="Relevance scores must be between")
324
+ pairs = _get_vectorstore().similarity_search_with_relevance_scores(
325
+ input_text, **search_kwargs
326
+ )
327
+ top_cos = max((s for _, s in pairs), default=0.0)
328
+ if top_cos < CRAG_ABSTAIN_THRESHOLD:
329
+ confidence = "low"
330
+ except Exception:
331
+ pass
332
+ return docs, scores, confidence
333
+
334
+
335
  # ---------------------------------------------------------------------------
336
  # LLM calls — non-streaming
337
  # ---------------------------------------------------------------------------
 
483
  def query(
484
  input_text: str, philosopher: str = "All", llm_label: str = DEFAULT_LLM
485
  ) -> dict:
486
+ """Non-streaming query. Returns answer + context + scores (+ abstained)."""
487
  provider, model_id = LLM_OPTIONS.get(llm_label, LLM_OPTIONS[DEFAULT_LLM])
488
+ docs, scores, confidence = retrieve_corrective(input_text, philosopher)
489
+ if confidence == "low":
490
+ return {"answer": ABSTAIN_MESSAGE, "context": docs,
491
+ "scores": scores, "abstained": True}
492
  context_str = "\n\n".join(d.page_content for d in docs)
493
  answer = _call_llm(provider, model_id, context_str, input_text)
494
+ return {"answer": answer, "context": docs, "scores": scores, "abstained": False}
495
 
496
 
497
  # ---------------------------------------------------------------------------