Inframat-x committed on
Commit
79f6c83
·
verified ·
1 Parent(s): d15804c

Create rag_core.py

Browse files
Files changed (1) hide show
  1. rag_core.py +699 -0
rag_core.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag_core.py — RAG core + logging + grid evaluation (no UI)
2
+
3
+ import os
4
+ import re
5
+ import json
6
+ import time
7
+ import uuid
8
+ import traceback
9
+ from pathlib import Path
10
+ from typing import List, Dict, Any, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+ # ---------------------- Optional deps ---------------------- #
16
+
17
# Dense retrieval is enabled only when sentence-transformers can be imported.
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    USE_DENSE = False
else:
    USE_DENSE = True

# BM25 is optional too; when it is missing the name is bound to None.
try:
    from rank_bm25 import BM25Okapi
except Exception:
    BM25Okapi = None
    print("rank_bm25 not installed; BM25 disabled (TF-IDF still works).")

# Optional OpenAI (for LLM synthesis; not needed for retrieval eval)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-5")
try:
    from openai import OpenAI
except Exception:
    OpenAI = None

# The LLM path requires a non-blank API key and an importable client class.
LLM_AVAILABLE = bool(OPENAI_API_KEY and OPENAI_API_KEY.strip()) and OpenAI is not None
42
+
43
+ # -------------------------- Paths & artifacts --------------------------- #
44
+
45
# Both working directories are created at import time if they do not exist.
ARTIFACT_DIR = Path("rag_artifacts")
ARTIFACT_DIR.mkdir(exist_ok=True)
LOCAL_PDF_DIR = Path("papers")
LOCAL_PDF_DIR.mkdir(exist_ok=True)

# Cached index artifacts, all stored under ARTIFACT_DIR.
TFIDF_VECT_PATH = ARTIFACT_DIR.joinpath("tfidf_vectorizer.joblib")
TFIDF_MAT_PATH = ARTIFACT_DIR.joinpath("tfidf_matrix.joblib")
BM25_TOK_PATH = ARTIFACT_DIR.joinpath("bm25_tokens.joblib")
EMB_NPY_PATH = ARTIFACT_DIR.joinpath("chunk_embeddings.npy")
RAG_META_PATH = ARTIFACT_DIR.joinpath("chunks.parquet")

# One JSON record per line is appended here for every rag_reply() call.
LOG_PATH = ARTIFACT_DIR.joinpath("rag_logs.jsonl")

USE_ONLINE_SOURCES = os.getenv("USE_ONLINE_SOURCES", "false").lower() == "true"

# default hybrid weights: the dense share is 0 when dense retrieval is off
if USE_DENSE:
    W_TFIDF_DEFAULT, W_BM25_DEFAULT, W_EMB_DEFAULT = 0.30, 0.30, 0.40
else:
    W_TFIDF_DEFAULT, W_BM25_DEFAULT, W_EMB_DEFAULT = 0.50, 0.50, 0.00

# -------------------------- basic text helpers -------------------------- #

# Sentence boundary: whitespace after ., ! or ?, or any run of newlines.
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+|\n+")
# Token: a run of letters, digits and the symbols _ # + - / . %
TOKEN_RE = re.compile(r"[A-Za-z0-9_#+\-/\.%]+")
69
+
70
def sent_split(text: str) -> List[str]:
    """Split *text* into sentences, keeping only those with five or more words.

    Boundaries are whatever ``_SENT_SPLIT_RE`` matches. Each piece is stripped
    of surrounding whitespace before its words are counted.
    """
    kept = []
    for piece in _SENT_SPLIT_RE.split(text):
        piece = piece.strip()
        # An empty piece has zero words, so it is rejected by the same test.
        if len(piece.split()) >= 5:
            kept.append(piece)
    return kept
73
+
74
def tokenize(text: str) -> List[str]:
    """Return every ``TOKEN_RE`` match in *text*, lower-cased, in order."""
    return list(map(str.lower, TOKEN_RE.findall(text)))
76
+
77
+ # -------------------------- PDF text extraction ------------------------ #
78
+
79
def _extract_pdf_text(pdf_path: Path) -> str:
    """Extract the text of a PDF, tagging each page with ``[[PAGE=n]]``.

    PyMuPDF (``fitz``) is tried first. If importing or using it raises,
    ``pypdf`` is tried instead. Pages are numbered from 1, each page's text is
    preceded by a ``[[PAGE=n]]`` line, and pages are joined by a blank line.

    Returns:
        The concatenated page text, or ``""`` when both readers fail (the
        error from the second reader is printed).
    """
    try:
        import fitz  # PyMuPDF

        # The context manager closes the document even if a page fails to
        # render, so the file handle is not leaked.
        with fitz.open(pdf_path) as doc:
            pages = [
                f"[[PAGE={number}]]\n{page.get_text('text') or ''}"
                for number, page in enumerate(doc, start=1)
            ]
        return "\n\n".join(pages)
    except Exception:
        # Best-effort fallback: any PyMuPDF problem (including it not being
        # installed) falls through to pypdf.
        try:
            from pypdf import PdfReader

            reader = PdfReader(str(pdf_path))
            pages = [
                f"[[PAGE={number}]]\n{page.extract_text() or ''}"
                for number, page in enumerate(reader.pages, start=1)
            ]
            return "\n\n".join(pages)
        except Exception as e:
            print(f"PDF read error ({pdf_path}): {e}")
            return ""
99
+
100
def chunk_by_sentence_windows(text: str, win_size: int = 8, overlap: int = 2) -> List[str]:
    """Cut *text* into overlapping windows of sentences.

    Sentences come from ``sent_split``. Windows hold up to *win_size*
    sentences and start every ``max(1, win_size - overlap)`` sentences.

    ``[[PAGE=n]]`` markers are handled explicitly. A marker alone on its line
    has fewer than five words, so ``sent_split`` would discard it and no chunk
    would ever carry page information. Instead, each sentence is associated
    with the most recent marker before it, and every chunk re-emits a
    ``[[PAGE=n]]`` line in front of each run of sentences from that page.
    Text with no markers produces plain space-joined windows.

    Returns:
        The list of chunk strings (empty when no sentence survives).
    """
    marker_re = re.compile(r"\[\[PAGE=(\d+)\]\]")

    # (page, sentence) pairs; page is None for text before the first marker.
    tagged: List[Tuple[Optional[str], str]] = []
    page: Optional[str] = None
    cursor = 0
    for match in marker_re.finditer(text):
        tagged.extend([(page, s) for s in sent_split(text[cursor:match.start()])])
        page = match.group(1)
        cursor = match.end()
    tagged.extend([(page, s) for s in sent_split(text[cursor:])])

    def _render(run_page: Optional[str], run: List[str]) -> str:
        # The marker goes on its own line so later sentence splitting drops it
        # again instead of gluing it onto a quoted sentence.
        body = " ".join(run)
        return body if run_page is None else f"[[PAGE={run_page}]]\n{body}"

    chunks: List[str] = []
    step = max(1, win_size - overlap)
    for start in range(0, len(tagged), step):
        window = tagged[start:start + win_size]
        if not window:
            break
        parts: List[str] = []
        run: List[str] = []
        run_page = window[0][0]
        for sent_page, sentence in window:
            if sent_page != run_page:
                parts.append(_render(run_page, run))
                run, run_page = [], sent_page
            run.append(sentence)
        parts.append(_render(run_page, run))
        chunks.append("\n".join(parts))
    return chunks
109
+
110
+ # -------------------------- dense encoder -------------------------- #
111
+
112
def _safe_init_st_model(name: str):
    """Construct ``SentenceTransformer(name)``, or return ``None``.

    Returns ``None`` straight away when the module-level ``USE_DENSE`` flag is
    already false. If construction raises, the error is printed, ``USE_DENSE``
    is set to ``False`` and ``None`` is returned.
    """
    global USE_DENSE
    model = None
    if USE_DENSE:
        try:
            model = SentenceTransformer(name)
        except Exception as err:
            print("Dense embeddings unavailable:", err)
            USE_DENSE = False
    return model
122
+
123
+ # --------------------- build / load hybrid index --------------------- #
124
+
125
def build_or_load_hybrid(pdf_dir: Path):
    """Load the cached hybrid index, or build it from the PDFs under *pdf_dir*.

    The cache is used when the TF-IDF vectorizer, TF-IDF matrix and chunk
    metadata files all exist, the BM25 token file exists (or BM25 is
    unavailable), and the embedding file exists (or dense retrieval is off).
    Otherwise every ``*.pdf`` below *pdf_dir* is extracted, chunked and
    indexed, and the artifacts are written to disk.

    Returns:
        ``(vectorizer, X_tfidf, meta, bm25_tokens, emb)``. ``meta`` is a
        DataFrame with columns ``doc_path``, ``chunk_id`` and ``text``. When
        no chunk could be produced, every element except an empty ``meta`` is
        ``None``. ``emb`` is ``None`` when dense embeddings are disabled or
        fail.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    import joblib

    have_cache = (
        TFIDF_VECT_PATH.exists()
        and TFIDF_MAT_PATH.exists()
        and RAG_META_PATH.exists()
        and (BM25_TOK_PATH.exists() or BM25Okapi is None)
        and (EMB_NPY_PATH.exists() or not USE_DENSE)
    )

    if have_cache:
        vectorizer = joblib.load(TFIDF_VECT_PATH)
        X_tfidf = joblib.load(TFIDF_MAT_PATH)
        meta = pd.read_parquet(RAG_META_PATH)
        bm25_toks = joblib.load(BM25_TOK_PATH) if BM25Okapi is not None else None
        emb = np.load(EMB_NPY_PATH) if (USE_DENSE and EMB_NPY_PATH.exists()) else None
        return vectorizer, X_tfidf, meta, bm25_toks, emb

    rows, all_tokens = [], []
    pdf_paths = list(pdf_dir.glob("**/*.pdf"))
    print(f"Indexing PDFs in {pdf_dir} — found {len(pdf_paths)} file(s).")
    for pdf in pdf_paths:
        raw = _extract_pdf_text(pdf)
        if not raw.strip():
            continue
        for i, ch in enumerate(chunk_by_sentence_windows(raw, win_size=8, overlap=2)):
            rows.append({"doc_path": str(pdf), "chunk_id": i, "text": ch})
            all_tokens.append(tokenize(ch))

    if not rows:
        meta = pd.DataFrame(columns=["doc_path", "chunk_id", "text"])
        return None, None, meta, None, None

    meta = pd.DataFrame(rows)

    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        min_df=1,
        # With one chunk, max_df=0.95 means "< 1 document", which is below
        # min_df=1 and makes fit_transform raise ValueError. Disable the
        # upper cut-off for that degenerate corpus.
        max_df=0.95 if len(meta) > 1 else 1.0,
        sublinear_tf=True,
        smooth_idf=True,
        lowercase=True,
        token_pattern=r"(?u)\b\w[\w\-\./%+#]*\b",
    )
    X_tfidf = vectorizer.fit_transform(meta["text"].tolist())

    emb = None
    if USE_DENSE:
        try:
            st_model = _safe_init_st_model(
                os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
            )
            if st_model is not None:
                from sklearn.preprocessing import normalize as sk_normalize

                em = st_model.encode(
                    meta["text"].tolist(),
                    batch_size=64,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                )
                # Rows are normalised so a dot product with a normalised
                # query vector can be used directly as the dense score.
                emb = sk_normalize(em)
                np.save(EMB_NPY_PATH, emb)
        except Exception as e:
            print("Dense embedding failed:", e)
            emb = None

    # Persist the sparse artifacts so the next import can take the cache path.
    joblib.dump(vectorizer, TFIDF_VECT_PATH)
    joblib.dump(X_tfidf, TFIDF_MAT_PATH)
    if BM25Okapi is not None:
        joblib.dump(all_tokens, BM25_TOK_PATH)
    meta.to_parquet(RAG_META_PATH, index=False)

    return vectorizer, X_tfidf, meta, all_tokens, emb
201
+
202
# Build (or load from cache) all index artifacts once, at import time.
(
    tfidf_vectorizer,
    tfidf_matrix,
    rag_meta,
    bm25_tokens,
    emb_matrix,
) = build_or_load_hybrid(LOCAL_PDF_DIR)

# A BM25 scorer exists only when the library imported and tokens are available.
if BM25Okapi is None or bm25_tokens is None:
    bm25 = None
else:
    bm25 = BM25Okapi(bm25_tokens)

# Encoder used for queries; None when dense retrieval is unavailable.
st_query_model = _safe_init_st_model(
    os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
)
209
+
210
+ # -------------------------- hybrid retrieval -------------------------- #
211
+
212
def _extract_page(text_chunk: str) -> str:
    """Return the number from the last ``[[PAGE=n]]`` marker in *text_chunk*.

    Returns ``"?"`` when the chunk is empty/``None`` or contains no marker.
    """
    page_numbers = re.findall(r"\[\[PAGE=(\d+)\]\]", text_chunk or "")
    if not page_numbers:
        return "?"
    return page_numbers[-1]
215
+
216
def hybrid_search(
    query: str,
    k: int = 8,
    w_tfidf: float = W_TFIDF_DEFAULT,
    w_bm25: float = W_BM25_DEFAULT,
    w_emb: float = W_EMB_DEFAULT,
) -> pd.DataFrame:
    """Rank indexed chunks for *query* by a weighted mix of three scorers.

    Dense, TF-IDF and BM25 scores are each min-max scaled to ``[0, 1]`` and
    combined with the given weights, which are rescaled to sum to 1. A scorer
    that is unavailable contributes zeros and has its weight set to 0.

    Returns:
        The top *k* rows of ``rag_meta`` with added columns ``score_dense``,
        ``score_tfidf``, ``score_bm25`` and ``score``, or an empty DataFrame
        when nothing is indexed.
    """
    if rag_meta is None or rag_meta.empty:
        return pd.DataFrame()

    total = len(rag_meta)

    def _minmax(values):
        arr = np.asarray(values, dtype=float)
        hi, lo = arr.max(), arr.min()
        # A constant vector carries no ranking signal; map it to all zeros.
        if np.allclose(hi, lo):
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    # dense scores
    dense_raw = np.zeros(total)
    dense_ok = False
    if USE_DENSE and st_query_model is not None and emb_matrix is not None and w_emb > 0:
        try:
            from sklearn.preprocessing import normalize as sk_normalize

            encoded = st_query_model.encode([query], convert_to_numpy=True)
            dense_raw = emb_matrix @ sk_normalize(encoded)[0]
            dense_ok = True
        except Exception as e:
            print("Dense query encoding failed:", e)
            dense_raw = np.zeros(total)
    if not dense_ok:
        w_emb = 0.0

    # tf-idf
    if tfidf_vectorizer is None or tfidf_matrix is None:
        tfidf_raw = np.zeros(total)
        w_tfidf = 0.0
    else:
        query_vec = tfidf_vectorizer.transform([query])
        tfidf_raw = (tfidf_matrix @ query_vec.T).toarray().ravel()

    # bm25
    if bm25 is None:
        bm25_raw = np.zeros(total)
        w_bm25 = 0.0
    else:
        bm25_raw = np.array(bm25.get_scores(tokenize(query)), dtype=float)

    s_dense = _minmax(dense_raw)
    s_tfidf = _minmax(tfidf_raw)
    s_bm25 = _minmax(bm25_raw)

    # Rescale the remaining weights to sum to 1 (falls back to 1.0 if all are 0).
    weight_sum = (w_tfidf + w_bm25 + w_emb) or 1.0
    w_tfidf = w_tfidf / weight_sum
    w_bm25 = w_bm25 / weight_sum
    w_emb = w_emb / weight_sum

    combo = w_emb * s_dense + w_tfidf * s_tfidf + w_bm25 * s_bm25
    top = np.argsort(-combo)[:k]

    hits = rag_meta.iloc[top].copy()
    hits["score_dense"] = s_dense[top]
    hits["score_tfidf"] = s_tfidf[top]
    hits["score_bm25"] = s_bm25[top]
    hits["score"] = combo[top]
    return hits.reset_index(drop=True)
285
+
286
+ # --------------------- MMR sentence selection --------------------- #
287
+
288
def split_sentences(text: str) -> List[str]:
    """Return the ``sent_split`` sentences of *text* that have 6 to 60 words."""
    kept = []
    for sentence in sent_split(text):
        n_words = len(sentence.split())
        if 6 <= n_words <= 60:
            kept.append(sentence)
    return kept
291
+
292
def mmr_select_sentences(
    question: str,
    hits: pd.DataFrame,
    top_n: int = 4,
    pool_per_chunk: int = 6,
    lambda_div: float = 0.7,
) -> List[Dict[str, Any]]:
    """Pick up to *top_n* sentences from *hits* by maximal marginal relevance.

    Up to *pool_per_chunk* sentences (at least one) are pooled from each hit.
    Relevance to *question* and sentence-to-sentence similarity use the dense
    encoder when available, otherwise a TF-IDF model fitted on the pool. Each
    round picks the candidate maximising
    ``lambda_div * relevance - (1 - lambda_div) * max_similarity_to_selected``.

    Args:
        question: The user question.
        hits: Retrieved chunks; must provide ``doc_path`` and ``text``.
        top_n: Maximum number of sentences to return; values <= 0 yield ``[]``.
        pool_per_chunk: Sentences taken from the start of each chunk.
        lambda_div: Relevance/diversity trade-off, clipped to ``[0, 1]``.

    Returns:
        A list of ``{"sent", "doc", "page"}`` dicts in selection order.
    """
    # Previously the most relevant sentence was always returned, even when the
    # caller asked for zero sentences.
    if int(top_n) <= 0:
        return []

    pool = []
    for _, row in hits.iterrows():
        doc = Path(row["doc_path"]).name
        page = _extract_page(row["text"])
        sents = split_sentences(row["text"])
        if not sents:
            continue
        for s in sents[:max(1, int(pool_per_chunk))]:
            pool.append({"sent": s, "doc": doc, "page": page})
    if not pool:
        return []

    sent_texts = [p["sent"] for p in pool]
    use_dense = USE_DENSE and st_query_model is not None

    try:
        if use_dense:
            from sklearn.preprocessing import normalize as sk_normalize

            enc = st_query_model.encode([question] + sent_texts, convert_to_numpy=True)
            q_vec = sk_normalize(enc[:1])[0]
            S = sk_normalize(enc[1:])
            rel = S @ q_vec

            def sim_fn(i, j):
                return float(S[i] @ S[j])
        else:
            from sklearn.feature_extraction.text import TfidfVectorizer

            vect = TfidfVectorizer().fit(sent_texts + [question])
            Q = vect.transform([question])
            S = vect.transform(sent_texts)
            rel = (S @ Q.T).toarray().ravel()

            def sim_fn(i, j):
                num = S[i] @ S[j].T
                return float(num.toarray()[0, 0]) if hasattr(num, "toarray") else float(num)
    except Exception as e:
        # Best-effort fallback: every sentence is equally relevant and no
        # diversity penalty applies. Report the failure instead of hiding it.
        print("MMR similarity scoring failed; using uniform relevance:", e)
        rel = np.ones(len(sent_texts))

        def sim_fn(i, j):
            return 0.0

    lambda_div = float(np.clip(lambda_div, 0.0, 1.0))

    first = int(np.argmax(rel))
    selected = [pool[first]]
    remain = [i for i in range(len(pool)) if i != first]
    # Running maximum similarity of each candidate to anything selected so
    # far, so each round only compares against the newest pick.
    max_sim = {i: sim_fn(i, first) for i in remain}

    max_pick = min(int(top_n), len(pool))
    while len(selected) < max_pick and remain:
        # max() over (score, index) tuples matches the original descending
        # sort: highest score wins, ties go to the larger index.
        _, best_i = max(
            (lambda_div * float(rel[i]) - (1.0 - lambda_div) * max_sim[i], i)
            for i in remain
        )
        selected.append(pool[best_i])
        remain.remove(best_i)
        del max_sim[best_i]
        for i in remain:
            max_sim[i] = max(max_sim[i], sim_fn(i, best_i))

    return selected
357
+
358
def compose_extractive(selected: List[Dict[str, Any]]) -> str:
    """Join the selected sentences into one string with inline citations.

    Each entry is rendered as ``<sent> (<doc>, p.<page>)`` and entries are
    separated by a single space. An empty or ``None`` selection gives ``""``.
    """
    cited = [
        f"{entry['sent']} ({entry['doc']}, p.{entry['page']})"
        for entry in (selected or [])
    ]
    return " ".join(cited)
362
+
363
+ # --------------------------- logging helpers --------------------------- #
364
+
365
# Prices per 1000 tokens, read from the environment; both default to 0.
OPENAI_IN_COST_PER_1K = float(os.environ.get("OPENAI_COST_IN_PER_1K", "0"))
OPENAI_OUT_COST_PER_1K = float(os.environ.get("OPENAI_COST_OUT_PER_1K", "0"))
367
+
368
def _safe_write_jsonl(path: Path, record: dict):
    """Append *record* to *path* as one UTF-8 JSON line.

    Non-ASCII characters are written as-is. Any failure (opening, serialising
    or writing) is printed rather than raised.
    """
    try:
        with open(path, mode="a", encoding="utf-8") as handle:
            handle.write(f"{json.dumps(record, ensure_ascii=False)}\n")
    except Exception as err:
        print("[Log] write failed:", err)
374
+
375
def _calc_cost_usd(prompt_toks, completion_toks):
    """Return the cost of a call from its token counts, or ``None``.

    Prompt tokens are priced at ``OPENAI_IN_COST_PER_1K`` and completion
    tokens at ``OPENAI_OUT_COST_PER_1K``, both per 1000 tokens. ``None`` is
    returned when either count is ``None``.
    """
    if prompt_toks is not None and completion_toks is not None:
        prompt_cost = (prompt_toks / 1000.0) * OPENAI_IN_COST_PER_1K
        completion_cost = (completion_toks / 1000.0) * OPENAI_OUT_COST_PER_1K
        return prompt_cost + completion_cost
    return None
381
+
382
+ # ------------------------ optional LLM synthesis ------------------------ #
383
+
384
def synthesize_with_llm(
    question: str,
    sentence_lines: List[str],
    model: Optional[str] = None,
    temperature: float = 0.2,
):
    """Ask the OpenAI Responses API to answer using only the given sentences.

    Args:
        question: The user question.
        sentence_lines: Evidence sentences, each already carrying its inline
            citation; they are listed verbatim in the prompt.
        model: Model name; defaults to ``OPENAI_MODEL``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        ``(text, usage)``. ``usage`` is
        ``{"prompt_tokens": ..., "completion_tokens": ...}`` or ``None`` when
        the response carries no usage. ``(None, None)`` is returned when the
        LLM is not available or the request fails (the error is printed).
    """
    if not LLM_AVAILABLE:
        return None, None
    client = OpenAI(api_key=OPENAI_API_KEY)
    model = model or OPENAI_MODEL

    SYSTEM_PROMPT = (
        "You are a scientific assistant for self-sensing cementitious materials.\n"
        "Answer STRICTLY using the provided sentences.\n"
        "Do not invent facts. Keep it concise (3–6 sentences).\n"
        "Retain inline citations like (Doc.pdf, p.X) exactly as given."
    )
    user_prompt = (
        f"Question: {question}\n\n"
        "Use ONLY these sentences to answer; keep their inline citations:\n"
        + "\n".join(f"- {s}" for s in sentence_lines)
    )

    def _usage_field(usage_obj, *names):
        # Return the first non-None value among *names*, reading from either
        # a mapping or an attribute-style object.
        for name in names:
            if isinstance(usage_obj, dict):
                value = usage_obj.get(name)
            else:
                value = getattr(usage_obj, name, None)
            if value is not None:
                return value
        return None

    try:
        resp = client.responses.create(
            model=model,
            input=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
        )
        out_text = getattr(resp, "output_text", None) or str(resp)
        usage = None
        try:
            u = getattr(resp, "usage", None)
            if u:
                # The Responses API reports input_tokens/output_tokens; the
                # older prompt_/completion_ names are kept as a fallback. The
                # returned keys are unchanged for existing callers.
                usage = {
                    "prompt_tokens": _usage_field(u, "input_tokens", "prompt_tokens"),
                    "completion_tokens": _usage_field(u, "output_tokens", "completion_tokens"),
                }
        except Exception:
            usage = None
        return out_text, usage
    except Exception as e:
        # Still best-effort, but say why synthesis was skipped.
        print("LLM synthesis failed:", e)
        return None, None
429
+
430
+ # ------------------- main RAG reply (with config_id) ------------------- #
431
+
432
def _top_passages(hits: pd.DataFrame, limit: int = 2) -> str:
    """Join the text of the first *limit* hits, separated by blank lines."""
    return "\n\n".join(hits["text"].tolist()[:limit])


def _cited_lines(selected: List[Dict[str, Any]]) -> List[str]:
    """Render each selected sentence as ``<sent> (<doc>, p.<page>)``."""
    return [f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected]


def _write_rag_log(
    run_id: str,
    inputs: dict,
    retr_hits: list,
    latency_ms_retriever: int,
    final: str,
    selected: List[Dict[str, Any]],
    t0_total: float,
    extra: Optional[dict] = None,
) -> None:
    """Assemble one run record and append it to ``LOG_PATH``."""
    record = {
        "run_id": run_id,
        "ts": int(time.time() * 1000),
        "inputs": inputs,
        "retrieval": {"hits": retr_hits, "latency_ms_retriever": latency_ms_retriever},
        "output": {
            "final_answer": final,
            "used_sentences": [
                {"sent": s["sent"], "doc": s["doc"], "page": s["page"]}
                for s in selected
            ],
        },
        "latency_ms_total": int((time.time() - t0_total) * 1000),
    }
    # Trailing keys differ per path; the default matches the retrieval-only shape.
    record.update(extra if extra is not None else {"openai": None})
    _safe_write_jsonl(LOG_PATH, record)


def rag_reply(
    question: str,
    k: int = 8,
    n_sentences: int = 4,
    include_passages: bool = False,
    use_llm: bool = False,
    model: Optional[str] = None,
    temperature: float = 0.2,
    strict_quotes_only: bool = False,
    w_tfidf: float = W_TFIDF_DEFAULT,
    w_bm25: float = W_BM25_DEFAULT,
    w_emb: float = W_EMB_DEFAULT,
    config_id: Optional[str] = None,
) -> str:
    """Answer *question* from the indexed PDFs and log the run.

    Retrieves the top *k* chunks with ``hybrid_search``, selects up to
    *n_sentences* sentences with ``mmr_select_sentences``, and formats a
    Markdown answer. Every call appends one record to ``LOG_PATH``.

    Args:
        question: The user question.
        k: Number of chunks to retrieve.
        n_sentences: Maximum number of sentences to quote.
        include_passages: Append the first two retrieved chunks to the answer.
        use_llm: Try LLM synthesis; falls back to the extractive answer when
            it returns nothing.
        model: Model name forwarded to ``synthesize_with_llm``.
        temperature: Sampling temperature forwarded to ``synthesize_with_llm``.
        strict_quotes_only: Return quoted sentences only; the LLM is never
            called and the log records ``use_llm=False`` and ``model=None``.
        w_tfidf: TF-IDF weight for retrieval.
        w_bm25: BM25 weight for retrieval.
        w_emb: Dense-embedding weight for retrieval.
        config_id: Free-form label stored in the log record.

    Returns:
        The formatted answer, or ``"No indexed PDFs found."`` when retrieval
        yields nothing.
    """
    run_id = str(uuid.uuid4())
    t0_total = time.time()

    t0_retr = time.time()
    hits = hybrid_search(
        question,
        k=int(k),
        w_tfidf=float(w_tfidf),
        w_bm25=float(w_bm25),
        w_emb=float(w_emb),
    )
    latency_ms_retriever = int((time.time() - t0_retr) * 1000)

    inputs = {
        "question": question,
        "top_k": int(k),
        "n_sentences": int(n_sentences),
        "w_tfidf": float(w_tfidf),
        "w_bm25": float(w_bm25),
        "w_emb": float(w_emb),
        "use_llm": bool(use_llm),
        "model": model,
        "temperature": float(temperature),
        "config_id": config_id,
    }

    if hits is None or hits.empty:
        final = "No indexed PDFs found."
        _write_rag_log(run_id, inputs, [], latency_ms_retriever, final, [], t0_total)
        return final

    selected = mmr_select_sentences(
        question, hits, top_n=int(n_sentences), pool_per_chunk=6, lambda_div=0.7
    )
    header_cites = "; ".join(
        f"{Path(r['doc_path']).name} (p.{_extract_page(r['text'])})"
        for _, r in hits.head(6).iterrows()
    )

    srcs = {Path(r["doc_path"]).name for _, r in hits.iterrows()}
    coverage_note = (
        ""
        if len(srcs) >= 3
        else f"\n\n> Note: Only {len(srcs)} unique source(s). Add more PDFs or increase Top-K."
    )
    citations = f"**Citations:** {header_cites}{coverage_note}"

    retr_list = [
        {
            "doc": Path(r["doc_path"]).name,
            "page": _extract_page(r["text"]),
            "score_tfidf": float(r.get("score_tfidf", 0.0)),
            "score_bm25": float(r.get("score_bm25", 0.0)),
            "score_dense": float(r.get("score_dense", 0.0)),
            "combo_score": float(r.get("score", 0.0)),
        }
        for _, r in hits.iterrows()
    ]
    lines = _cited_lines(selected)

    # retrieval-only / strict quotations (useful for grid eval)
    if strict_quotes_only:
        if not selected:
            final = (
                "**Quoted Passages:**\n\n---\n"
                + _top_passages(hits)
                + f"\n\n{citations}"
            )
        else:
            final = "**Quoted Passages:**\n- " + "\n- ".join(lines)
            final += f"\n\n{citations}"
            if include_passages:
                final += "\n\n---\n" + _top_passages(hits)

        strict_inputs = {**inputs, "use_llm": False, "model": None}
        _write_rag_log(
            run_id, strict_inputs, retr_list, latency_ms_retriever, final, selected, t0_total
        )
        return final

    # extractive / LLM synthesis
    extractive = compose_extractive(selected)
    llm_text = None
    llm_usage = None
    llm_latency_ms = None

    if use_llm and selected:
        t0_llm = time.time()
        llm_text, llm_usage = synthesize_with_llm(
            question, lines, model=model, temperature=temperature
        )
        llm_latency_ms = int((time.time() - t0_llm) * 1000)

    if llm_text:
        final = f"**Answer (LLM synthesis):** {llm_text}\n\n{citations}"
        if include_passages:
            final += "\n\n---\n" + _top_passages(hits)
    elif not extractive:
        # Nothing quotable: always show raw passages, whatever include_passages says.
        final = (
            "**Answer:** Here are relevant passages.\n\n"
            f"{citations}\n\n---\n"
            + _top_passages(hits)
        )
    else:
        final = f"**Answer:** {extractive}\n\n{citations}"
        if include_passages:
            final += "\n\n---\n" + _top_passages(hits)

    prompt_toks = llm_usage.get("prompt_tokens") if llm_usage else None
    completion_toks = llm_usage.get("completion_tokens") if llm_usage else None
    openai_info = (
        {
            "prompt_tokens": prompt_toks,
            "completion_tokens": completion_toks,
            "cost_usd": _calc_cost_usd(prompt_toks, completion_toks),
        }
        if use_llm
        else None
    )
    _write_rag_log(
        run_id,
        inputs,
        retr_list,
        latency_ms_retriever,
        final,
        selected,
        t0_total,
        extra={"latency_ms_llm": llm_latency_ms, "openai": openai_info},
    )
    return final
648
+
649
+ # --------------- automated grid evaluation over weights --------------- #
650
+
651
def run_weight_grid_eval(
    gold_csv: str,
    weight_grid: List[Dict[str, float]],
    k: int = 8,
    n_sentences: int = 4,
) -> None:
    """
    Automatically evaluate many (w_tfidf, w_bm25, w_emb) combinations
    on the full gold question set.

    - Reads questions from gold_csv (column 'question'); rows whose
      question is missing or blank are skipped
    - For each configuration in weight_grid, calls rag_reply(...)
      with use_llm=False and strict_quotes_only=True
    - All runs are logged into rag_logs.jsonl with a 'config_id'
      and the exact weights.

    Each configuration is a dict with optional keys 'w_tfidf', 'w_bm25',
    'w_emb' (each defaulting to 0.0) and 'id' (defaulting to a label
    built from the three weights).

    Raises:
        ValueError: if gold_csv has no 'question' column.
    """
    gold_df = pd.read_csv(gold_csv)
    if "question" not in gold_df.columns:
        raise ValueError("gold_csv must contain a 'question' column.")
    # astype(str) alone would turn missing cells into the literal query "nan".
    questions = [
        q for q in gold_df["question"].dropna().astype(str).tolist() if q.strip()
    ]

    for cfg in weight_grid:
        wt = float(cfg.get("w_tfidf", 0.0))
        wb = float(cfg.get("w_bm25", 0.0))
        we = float(cfg.get("w_emb", 0.0))
        cid = cfg.get("id") or f"tfidf{wt}_bm25{wb}_emb{we}"

        print(
            f"\n[GridEval] Running config {cid} "
            f"(w_tfidf={wt}, w_bm25={wb}, w_emb={we}, k={k})"
        )

        for q in questions:
            # The returned text is not needed; rag_reply logs each run itself.
            rag_reply(
                question=q,
                k=int(k),
                n_sentences=int(n_sentences),
                include_passages=False,
                use_llm=False,
                model=None,
                temperature=0.0,
                strict_quotes_only=True,
                w_tfidf=wt,
                w_bm25=wb,
                w_emb=we,
                config_id=cid,
            )
698
+
699
+ print("✅ RAG core + grid evaluation helpers loaded.")