Inframat-x committed on
Commit
301030b
·
verified ·
1 Parent(s): 996cbb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +705 -91
app.py CHANGED
@@ -1,21 +1,684 @@
1
- # app.py (UI wrapper with download buttons for evaluation outputs)
 
2
 
3
- import os
4
  from pathlib import Path
 
 
 
 
5
  import gradio as gr
6
 
7
- from rag_core import ( # your core RAG logic (previously app, now separated)
8
- rag_reply,
9
- W_TFIDF_DEFAULT,
10
- W_BM25_DEFAULT,
11
- W_EMB_DEFAULT,
12
- LOG_PATH,
13
- ARTIFACT_DIR,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
15
- from rag_eval_metrics import evaluate_rag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # ------------- RAG chat wrapper ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def rag_chat_fn(
20
  message,
21
  history,
@@ -40,24 +703,20 @@ def rag_chat_fn(
40
  w_tfidf=float(w_tfidf),
41
  w_bm25=float(w_bm25),
42
  w_emb=float(w_emb),
43
- config_id=None, # manual chat → no grid config id
44
  )
45
 
46
-
47
- # ------------- Evaluate wrapper (runs metrics) ----------------
48
  def run_eval_ui(gold_file, k):
49
- """
50
- Trigger evaluation:
51
- - Uses uploaded gold.csv (or fallback gold.csv in repo root)
52
- - Compares against rag_artifacts/rag_logs.jsonl
53
- - Writes metrics_per_question.csv, metrics_aggregate.json, metrics_by_weights.csv
54
- """
55
  if gold_file is None:
56
  gold_path = Path("gold.csv")
57
  if not gold_path.exists():
58
  return (
59
  "**No gold.csv provided or found in the working directory.**\n"
60
- "Upload a file or place gold.csv next to app.py."
 
 
 
61
  )
62
  gold_csv = str(gold_path)
63
  else:
@@ -66,7 +725,7 @@ def run_eval_ui(gold_file, k):
66
  logs_jsonl = str(LOG_PATH)
67
  out_dir = str(ARTIFACT_DIR)
68
 
69
- # This prints to the Space logs and writes CSV/JSON into rag_artifacts/
70
  evaluate_rag(
71
  gold_csv,
72
  logs_jsonl,
@@ -75,21 +734,6 @@ def run_eval_ui(gold_file, k):
75
  group_by_weights=True,
76
  )
77
 
78
- return (
79
- "✅ Evaluation finished.\n\n"
80
- f"- Per-question metrics: `"{ARTIFACT_DIR / 'metrics_per_question.csv'}"`\n"
81
- f"- Aggregate metrics: `"{ARTIFACT_DIR / 'metrics_aggregate.json'}"`\n"
82
- f"- Config surface: `"{ARTIFACT_DIR / 'metrics_by_weights.csv'}"`\n\n"
83
- "Use **Download Evaluation Results** below to fetch the files."
84
- )
85
-
86
-
87
- # ------------- Download helper for evaluation outputs -------------
88
- def download_metrics():
89
- """
90
- Return paths to evaluation artifacts so Gradio shows them as downloadable files.
91
- If a file does not exist yet, returns None for that output.
92
- """
93
  base = ARTIFACT_DIR
94
  perq = base / "metrics_per_question.csv"
95
  agg = base / "metrics_aggregate.json"
@@ -99,7 +743,15 @@ def download_metrics():
99
  agg_path = str(agg) if agg.exists() else None
100
  surf_path = str(surf) if surf.exists() else None
101
 
102
- return perq_path, agg_path, surf_path
 
 
 
 
 
 
 
 
103
 
104
 
105
  # ------------- Build Gradio UI -----------------
@@ -116,42 +768,18 @@ with gr.Blocks(title="Self-Sensing Concrete RAG") as demo:
116
  top_k = gr.Slider(3, 15, value=8, step=1, label="Top-K chunks")
117
  n_sentences = gr.Slider(2, 8, value=4, step=1, label="Answer length (sentences)")
118
  include_passages = gr.Checkbox(
119
- value=False,
120
- label="Include supporting passages",
121
  )
122
-
123
  with gr.Row():
124
- w_tfidf = gr.Slider(
125
- 0.0, 1.0,
126
- value=W_TFIDF_DEFAULT,
127
- step=0.05,
128
- label="TF-IDF weight",
129
- )
130
- w_bm25 = gr.Slider(
131
- 0.0, 1.0,
132
- value=W_BM25_DEFAULT,
133
- step=0.05,
134
- label="BM25 weight",
135
- )
136
- w_emb = gr.Slider(
137
- 0.0, 1.0,
138
- value=W_EMB_DEFAULT,
139
- step=0.05,
140
- label="Dense weight",
141
- )
142
 
143
  gr.ChatInterface(
144
  fn=rag_chat_fn,
145
- additional_inputs=[
146
- top_k,
147
- n_sentences,
148
- include_passages,
149
- w_tfidf,
150
- w_bm25,
151
- w_emb,
152
- ],
153
  title="Hybrid RAG Q&A",
154
- description="Hybrid BM25 + TF-IDF + dense retrieval with MMR sentence selection.",
155
  )
156
 
157
  # --------- Evaluation tab ---------
@@ -159,39 +787,25 @@ with gr.Blocks(title="Self-Sensing Concrete RAG") as demo:
159
  gr.Markdown(
160
  "Upload **gold.csv** and compute retrieval metrics against "
161
  "`rag_artifacts/rag_logs.jsonl`.\n\n"
162
- "Then use **Download Evaluation Results** to fetch the CSV/JSON files."
 
163
  )
164
-
165
  gold_file = gr.File(label="gold.csv", file_types=[".csv"])
166
  k_slider = gr.Slider(3, 15, value=8, step=1, label="k for Hit/Recall/nDCG")
167
 
168
- with gr.Row():
169
- btn_eval = gr.Button("Run Evaluation", variant="primary")
170
- eval_out = gr.Markdown(label="Evaluation log")
 
 
 
171
 
172
  btn_eval.click(
173
  fn=run_eval_ui,
174
  inputs=[gold_file, k_slider],
175
- outputs=eval_out,
176
- )
177
-
178
- gr.Markdown("### Download Evaluation Results")
179
-
180
- # Three file outputs (they become download links in the Space)
181
- with gr.Row():
182
- perq_file = gr.File(label="Per-question metrics (CSV)")
183
- agg_file = gr.File(label="Aggregate metrics (JSON)")
184
- surf_file = gr.File(label="Config surface (CSV)")
185
-
186
- btn_download = gr.Button("Download Evaluation Results")
187
-
188
- btn_download.click(
189
- fn=download_metrics,
190
- inputs=None,
191
- outputs=[perq_file, agg_file, surf_file],
192
  )
193
 
194
-
195
  # ------------- Launch app -----------------
196
  if __name__ == "__main__":
197
  demo.queue().launch(
 
1
+ # ===================== app.py =====================
2
+ # RAG core + logging + grid eval + Gradio UI wrapper
3
 
4
+ import os, re, json, time, uuid, traceback
5
  from pathlib import Path
6
+ from typing import List, Dict, Any, Optional
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
  import gradio as gr
11
 
12
+ from rag_eval_metrics import evaluate_rag # must exist as rag_eval_metrics.py
13
+
14
+
15
# ===================== RAG CORE + LOGGING + GRID EVAL =====================

# Optional dense retrieval — disabled automatically when the encoder library
# is missing so the app still works with sparse retrieval only.
USE_DENSE = True
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    USE_DENSE = False

# Optional BM25 — TF-IDF remains available without it.
try:
    from rank_bm25 import BM25Okapi
except Exception:
    BM25Okapi = None
    print("rank_bm25 not installed; BM25 disabled (TF-IDF still works).")

# Optional OpenAI client (LLM synthesis only; grid search never needs it).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-5")
try:
    from openai import OpenAI
except Exception:
    OpenAI = None

# LLM synthesis requires both a non-blank API key and an importable client.
LLM_AVAILABLE = (
    OPENAI_API_KEY is not None
    and OPENAI_API_KEY.strip() != ""
    and OpenAI is not None
)

# -------------------------- Paths & artifacts ---------------------------
ARTIFACT_DIR = Path("rag_artifacts")
ARTIFACT_DIR.mkdir(exist_ok=True)
LOCAL_PDF_DIR = Path("papers")
LOCAL_PDF_DIR.mkdir(exist_ok=True)

TFIDF_VECT_PATH = ARTIFACT_DIR / "tfidf_vectorizer.joblib"
TFIDF_MAT_PATH = ARTIFACT_DIR / "tfidf_matrix.joblib"
BM25_TOK_PATH = ARTIFACT_DIR / "bm25_tokens.joblib"
EMB_NPY_PATH = ARTIFACT_DIR / "chunk_embeddings.npy"
RAG_META_PATH = ARTIFACT_DIR / "chunks.parquet"

LOG_PATH = ARTIFACT_DIR / "rag_logs.jsonl"

USE_ONLINE_SOURCES = os.getenv("USE_ONLINE_SOURCES", "false").lower() == "true"

# default hybrid weights — sparse-only 50/50 when dense is off, else 30/30/40
W_TFIDF_DEFAULT = 0.50 if not USE_DENSE else 0.30
W_BM25_DEFAULT = 0.50 if not USE_DENSE else 0.30
W_EMB_DEFAULT = 0.00 if not USE_DENSE else 0.40
60
+
61
# ---------- basic text helpers ----------
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+|\n+")
TOKEN_RE = re.compile(r"[A-Za-z0-9_#+\-/\.%]+")


def sent_split(text: str) -> List[str]:
    """Split *text* into sentences, keeping only those with >= 5 words."""
    stripped = (piece.strip() for piece in _SENT_SPLIT_RE.split(text))
    nonempty = [piece for piece in stripped if piece]
    return [s for s in nonempty if len(s.split()) >= 5]


def tokenize(text: str) -> List[str]:
    """Lower-cased tokens; keeps word chars plus #, +, -, /, ., %."""
    return [match.lower() for match in TOKEN_RE.findall(text)]
71
+
72
# ---------- PDF text extraction ----------
def _extract_pdf_text(pdf_path: Path) -> str:
    """Extract the full text of *pdf_path*, prefixing each page with a
    ``[[PAGE=n]]`` marker (used later to recover page numbers per chunk).

    Tries PyMuPDF first, then falls back to pypdf; returns "" when both fail.
    """
    try:
        import fitz  # PyMuPDF
        doc = fitz.open(pdf_path)
        pages = [
            f"[[PAGE={num + 1}]]\n{page.get_text('text') or ''}"
            for num, page in enumerate(doc)
        ]
        return "\n\n".join(pages)
    except Exception:
        pass
    try:
        from pypdf import PdfReader
        reader = PdfReader(str(pdf_path))
        pages = [
            f"[[PAGE={num + 1}]]\n{page.extract_text() or ''}"
            for num, page in enumerate(reader.pages)
        ]
        return "\n\n".join(pages)
    except Exception as exc:
        print(f"PDF read error ({pdf_path}): {exc}")
        return ""
93
+
94
def chunk_by_sentence_windows(text: str, win_size=8, overlap=2) -> List[str]:
    """Chunk *text* into overlapping windows of *win_size* sentences.

    Consecutive windows share *overlap* sentences; the stride is clamped to
    at least 1 so a large overlap can never cause an infinite loop.
    """
    sentences = sent_split(text)
    stride = max(1, win_size - overlap)
    windows: List[str] = []
    start = 0
    while start < len(sentences):
        piece = sentences[start:start + win_size]
        if not piece:
            break
        windows.append(" ".join(piece))
        start += stride
    return windows
103
+
104
# ---------- dense encoder ----------
def _safe_init_st_model(name: str):
    """Instantiate a SentenceTransformer by *name*, or return None.

    On any failure the module-wide USE_DENSE flag is switched off so later
    code stops attempting dense retrieval.
    """
    global USE_DENSE
    if not USE_DENSE:
        return None
    try:
        model = SentenceTransformer(name)
    except Exception as exc:
        print("Dense embeddings unavailable:", exc)
        USE_DENSE = False
        return None
    return model
115
+
116
# ---------- build / load hybrid index ----------
def build_or_load_hybrid(pdf_dir: Path):
    """Build (or load from cache in ARTIFACT_DIR) the hybrid retrieval index.

    Scans ``pdf_dir`` recursively for PDFs, chunks each into overlapping
    sentence windows, and fits a TF-IDF index, BM25 token lists (when
    rank_bm25 is installed), and optional dense embeddings.

    Returns a 5-tuple ``(vectorizer, X_tfidf, meta, bm25_tokens, emb)``:
      - vectorizer / X_tfidf: fitted TfidfVectorizer and chunk matrix
      - meta: DataFrame with columns doc_path / chunk_id / text
      - bm25_tokens: per-chunk token lists, or None when BM25 is unavailable
      - emb: L2-normalized chunk embeddings, or None when dense is disabled

    When no PDF yields text, returns (None, None, empty-DataFrame, None, None).
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    import joblib  # imported once; reused below (original re-imported it mid-function)

    # The cache is only trusted when every artifact that *should* exist does.
    have_cache = (
        TFIDF_VECT_PATH.exists()
        and TFIDF_MAT_PATH.exists()
        and RAG_META_PATH.exists()
        and (BM25_TOK_PATH.exists() or BM25Okapi is None)
        and (EMB_NPY_PATH.exists() or not USE_DENSE)
    )

    if have_cache:
        vectorizer = joblib.load(TFIDF_VECT_PATH)
        X_tfidf = joblib.load(TFIDF_MAT_PATH)
        meta = pd.read_parquet(RAG_META_PATH)
        bm25_toks = joblib.load(BM25_TOK_PATH) if BM25Okapi is not None else None
        emb = np.load(EMB_NPY_PATH) if (USE_DENSE and EMB_NPY_PATH.exists()) else None
        return vectorizer, X_tfidf, meta, bm25_toks, emb

    # ---- fresh build: extract, chunk, and tokenize every PDF ----
    rows, all_tokens = [], []
    pdf_paths = list(pdf_dir.glob("**/*.pdf"))
    print(f"Indexing PDFs in {pdf_dir} — found {len(pdf_paths)} file(s).")
    for pdf in pdf_paths:
        raw = _extract_pdf_text(pdf)
        if not raw.strip():
            continue
        for i, ch in enumerate(chunk_by_sentence_windows(raw, win_size=8, overlap=2)):
            rows.append({"doc_path": str(pdf), "chunk_id": i, "text": ch})
            all_tokens.append(tokenize(ch))

    if not rows:
        meta = pd.DataFrame(columns=["doc_path", "chunk_id", "text"])
        return None, None, meta, None, None

    meta = pd.DataFrame(rows)

    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        min_df=1,
        max_df=0.95,
        sublinear_tf=True,
        smooth_idf=True,
        lowercase=True,
        token_pattern=r"(?u)\b\w[\w\-\./%+#]*\b",
    )
    X_tfidf = vectorizer.fit_transform(meta["text"].tolist())

    # ---- optional dense embeddings (best effort; never fatal) ----
    emb = None
    if USE_DENSE:
        try:
            st_model = _safe_init_st_model(
                os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
            )
            if st_model is not None:
                from sklearn.preprocessing import normalize as sk_normalize
                em = st_model.encode(
                    meta["text"].tolist(),
                    batch_size=64,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                )
                emb = sk_normalize(em)
                np.save(EMB_NPY_PATH, emb)
        except Exception as e:
            print("Dense embedding failed:", e)
            emb = None

    # ---- persist artifacts so the next start-up hits the cache ----
    joblib.dump(vectorizer, TFIDF_VECT_PATH)
    joblib.dump(X_tfidf, TFIDF_MAT_PATH)
    if BM25Okapi is not None:
        joblib.dump(all_tokens, BM25_TOK_PATH)
    meta.to_parquet(RAG_META_PATH, index=False)

    return vectorizer, X_tfidf, meta, all_tokens, emb
193
+
194
# Build (or load cached) retrieval artifacts once at import time.
(
    tfidf_vectorizer,
    tfidf_matrix,
    rag_meta,
    bm25_tokens,
    emb_matrix,
) = build_or_load_hybrid(LOCAL_PDF_DIR)

# BM25 index only when the library is installed and tokens were produced.
if BM25Okapi is not None and bm25_tokens is not None:
    bm25 = BM25Okapi(bm25_tokens)
else:
    bm25 = None

# Query-time sentence encoder (None when dense retrieval is disabled).
st_query_model = _safe_init_st_model(
    os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
)
201
+
202
+ # ---------- hybrid retrieval ----------
203
+ def _extract_page(text_chunk: str) -> str:
204
+ m = list(re.finditer(r"\[\[PAGE=(\d+)\]\]", text_chunk or ""))
205
+ return m[-1].group(1) if m else "?"
206
+
207
def hybrid_search(
    query: str,
    k: int = 8,
    w_tfidf: float = W_TFIDF_DEFAULT,
    w_bm25: float = W_BM25_DEFAULT,
    w_emb: float = W_EMB_DEFAULT,
) -> pd.DataFrame:
    """Rank chunks by a weighted blend of dense, TF-IDF, and BM25 scores.

    Each retriever's raw scores are min-max normalized; a retriever that is
    unavailable (or fails) gets zero weight, and the weights are then
    re-normalized to sum to 1. Returns the top-*k* rows of rag_meta with
    score_dense / score_tfidf / score_bm25 / score columns attached, or an
    empty DataFrame when nothing is indexed.
    """
    if rag_meta is None or rag_meta.empty:
        return pd.DataFrame()

    n_chunks = len(rag_meta)

    # --- dense retriever (best effort; drops to zero weight on failure) ---
    dense_raw = np.zeros(n_chunks)
    if USE_DENSE and st_query_model is not None and emb_matrix is not None and w_emb > 0:
        try:
            from sklearn.preprocessing import normalize as sk_normalize
            encoded = st_query_model.encode([query], convert_to_numpy=True)
            dense_raw = emb_matrix @ sk_normalize(encoded)[0]
        except Exception as exc:
            print("Dense query encoding failed:", exc)
            dense_raw = np.zeros(n_chunks)
            w_emb = 0.0
    else:
        w_emb = 0.0

    # --- TF-IDF retriever ---
    if tfidf_vectorizer is not None and tfidf_matrix is not None:
        query_vec = tfidf_vectorizer.transform([query])
        tfidf_raw = (tfidf_matrix @ query_vec.T).toarray().ravel()
    else:
        tfidf_raw = np.zeros(n_chunks)
        w_tfidf = 0.0

    # --- BM25 retriever ---
    if bm25 is not None:
        query_tokens = [t.lower() for t in TOKEN_RE.findall(query)]
        bm25_raw = np.array(bm25.get_scores(query_tokens), dtype=float)
    else:
        bm25_raw = np.zeros(n_chunks)
        w_bm25 = 0.0

    def _minmax(values):
        # Scale to [0, 1]; a constant vector collapses to all zeros.
        arr = np.asarray(values, dtype=float)
        lo, hi = arr.min(), arr.max()
        if np.allclose(hi, lo):
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    s_dense = _minmax(dense_raw)
    s_tfidf = _minmax(tfidf_raw)
    s_bm25 = _minmax(bm25_raw)

    # Re-normalize whatever weights survived (guard against all-zero sum).
    weight_sum = (w_tfidf + w_bm25 + w_emb) or 1.0
    w_tfidf /= weight_sum
    w_bm25 /= weight_sum
    w_emb /= weight_sum

    blended = w_emb * s_dense + w_tfidf * s_tfidf + w_bm25 * s_bm25
    top_idx = np.argsort(-blended)[:k]

    result = rag_meta.iloc[top_idx].copy()
    result["score_dense"] = s_dense[top_idx]
    result["score_tfidf"] = s_tfidf[top_idx]
    result["score_bm25"] = s_bm25[top_idx]
    result["score"] = blended[top_idx]
    return result.reset_index(drop=True)
268
+
269
# ---------- MMR sentence selection ----------
def split_sentences(text: str) -> List[str]:
    """Sentences of *text* whose word count lies within [6, 60]."""
    return [s for s in sent_split(text) if 6 <= len(s.split()) <= 60]
273
+
274
def mmr_select_sentences(
    question: str,
    hits: pd.DataFrame,
    top_n: int = 4,
    pool_per_chunk: int = 6,
    lambda_div: float = 0.7,
) -> List[Dict[str, Any]]:
    """Pick up to *top_n* sentences from *hits* by Maximal Marginal Relevance.

    Relevance/similarity come from dense embeddings when available, otherwise
    from a TF-IDF model fit on the candidate pool; *lambda_div* trades
    relevance (1.0) against diversity (0.0). Each returned item is a dict
    with keys "sent", "doc", "page".
    """
    # Candidate pool: up to pool_per_chunk sentences from each retrieved chunk.
    pool: List[Dict[str, Any]] = []
    per_chunk = max(1, int(pool_per_chunk))
    for _, row in hits.iterrows():
        sentences = split_sentences(row["text"])
        if not sentences:
            continue
        doc_name = Path(row["doc_path"]).name
        page_no = _extract_page(row["text"])
        pool.extend(
            {"sent": s, "doc": doc_name, "page": page_no}
            for s in sentences[:per_chunk]
        )
    if not pool:
        return []

    sent_texts = [item["sent"] for item in pool]

    # Relevance vector + pairwise similarity; fall back to uniform relevance
    # and zero similarity if either scoring path blows up.
    try:
        if USE_DENSE and st_query_model is not None:
            from sklearn.preprocessing import normalize as sk_normalize
            enc = st_query_model.encode([question] + sent_texts, convert_to_numpy=True)
            q_vec = sk_normalize(enc[:1])[0]
            S = sk_normalize(enc[1:])
            rel = S @ q_vec
            def sim_fn(i, j):
                return float(S[i] @ S[j])
        else:
            from sklearn.feature_extraction.text import TfidfVectorizer
            vect = TfidfVectorizer().fit(sent_texts + [question])
            Q = vect.transform([question])
            S = vect.transform(sent_texts)
            rel = (S @ Q.T).toarray().ravel()
            def sim_fn(i, j):
                prod = S[i] @ S[j].T
                return float(prod.toarray()[0, 0]) if hasattr(prod, "toarray") else float(prod)
    except Exception:
        rel = np.ones(len(sent_texts))
        def sim_fn(i, j):
            return 0.0

    lambda_div = float(np.clip(lambda_div, 0.0, 1.0))

    # Greedy MMR: seed with the single most relevant sentence, then repeatedly
    # add the candidate maximizing (relevance minus redundancy).
    chosen_idx = [int(np.argmax(rel))]
    candidates = [i for i in range(len(pool)) if i != chosen_idx[0]]
    budget = min(int(top_n), len(pool))

    while len(chosen_idx) < budget and candidates:
        def mmr_score(i):
            redundancy = max(sim_fn(i, j) for j in chosen_idx)
            return lambda_div * float(rel[i]) - (1.0 - lambda_div) * redundancy
        # Tie-break on the larger index, matching a descending (score, i) sort.
        best = max(candidates, key=lambda i: (mmr_score(i), i))
        chosen_idx.append(best)
        candidates.remove(best)

    return [pool[i] for i in chosen_idx]
338
+
339
def compose_extractive(selected: List[Dict[str, Any]]) -> str:
    """Join selected sentences into one inline-cited paragraph ("" if empty)."""
    cited = [f"{item['sent']} ({item['doc']}, p.{item['page']})" for item in selected]
    return " ".join(cited)
343
+
344
+ # ---------- logging helpers ----------
345
+ OPENAI_IN_COST_PER_1K = float(os.getenv("OPENAI_COST_IN_PER_1K", "0"))
346
+ OPENAI_OUT_COST_PER_1K = float(os.getenv("OPENAI_COST_OUT_PER_1K", "0"))
347
+
348
+ def _safe_write_jsonl(path: Path, record: dict):
349
+ try:
350
+ with open(path, "a", encoding="utf-8") as f:
351
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
352
+ except Exception as e:
353
+ print("[Log] write failed:", e)
354
+
355
+ def _calc_cost_usd(prompt_toks, completion_toks):
356
+ if prompt_toks is None or completion_toks is None:
357
+ return None
358
+ return (prompt_toks / 1000.0) * OPENAI_IN_COST_PER_1K + (
359
+ completion_toks / 1000.0
360
+ ) * OPENAI_OUT_COST_PER_1K
361
+
362
# ---------- optional LLM synthesis ----------
def synthesize_with_llm(
    question: str,
    sentence_lines: List[str],
    model: Optional[str] = None,
    temperature: float = 0.2,
):
    """Condense *sentence_lines* into an answer via the OpenAI Responses API.

    Returns ``(text, usage_dict)`` on success, or ``(None, None)`` whenever
    the LLM is unavailable or the request fails for any reason.
    """
    if not LLM_AVAILABLE:
        return None, None

    client = OpenAI(api_key=OPENAI_API_KEY)
    chosen_model = model or OPENAI_MODEL

    system_msg = (
        "You are a scientific assistant for self-sensing cementitious materials.\n"
        "Answer STRICTLY using the provided sentences.\n"
        "Do not invent facts. Keep it concise (3–6 sentences).\n"
        "Retain inline citations like (Doc.pdf, p.X) exactly as given."
    )
    user_msg = (
        f"Question: {question}\n\n"
        "Use ONLY these sentences to answer; keep their inline citations:\n"
        + "\n".join(f"- {s}" for s in sentence_lines)
    )

    try:
        resp = client.responses.create(
            model=chosen_model,
            input=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=temperature,
        )
        answer = getattr(resp, "output_text", None) or str(resp)

        # Token usage may be an object or a dict depending on SDK version.
        usage = None
        try:
            u = getattr(resp, "usage", None)
            if u:
                def _field(name):
                    if hasattr(u, name):
                        return getattr(u, name, None)
                    return u.get(name, None)
                usage = {
                    "prompt_tokens": _field("prompt_tokens"),
                    "completion_tokens": _field("completion_tokens"),
                }
        except Exception:
            usage = None
        return answer, usage
    except Exception:
        return None, None
408
 
409
# ---------- main RAG reply (with config_id) ----------
def rag_reply(
    question: str,
    k: int = 8,
    n_sentences: int = 4,
    include_passages: bool = False,
    use_llm: bool = False,
    model: Optional[str] = None,
    temperature: float = 0.2,
    strict_quotes_only: bool = False,
    w_tfidf: float = W_TFIDF_DEFAULT,
    w_bm25: float = W_BM25_DEFAULT,
    w_emb: float = W_EMB_DEFAULT,
    config_id: Optional[str] = None,
) -> str:
    """Answer *question* via hybrid retrieval + MMR sentence selection.

    Every invocation appends one JSONL record to LOG_PATH capturing inputs
    (including *config_id*, used by grid evaluation), retrieved hits, the
    final answer, and latency/cost accounting. Returns the markdown answer.
    """
    run_id = str(uuid.uuid4())
    started = time.time()

    # ---- retrieval ----
    retrieval_start = time.time()
    hits = hybrid_search(
        question,
        k=int(k),
        w_tfidf=float(w_tfidf),
        w_bm25=float(w_bm25),
        w_emb=float(w_emb),
    )
    latency_ms_retriever = int((time.time() - retrieval_start) * 1000)

    def _log(final_answer, hit_rows, used, llm_flag, llm_model,
             llm_ms=None, openai_info=None, include_llm_latency=False):
        # One place builds every log record so all three exit paths agree.
        rec = {
            "run_id": run_id,
            "ts": int(time.time() * 1000),
            "inputs": {
                "question": question,
                "top_k": int(k),
                "n_sentences": int(n_sentences),
                "w_tfidf": float(w_tfidf),
                "w_bm25": float(w_bm25),
                "w_emb": float(w_emb),
                "use_llm": llm_flag,
                "model": llm_model,
                "temperature": float(temperature),
                "config_id": config_id,
            },
            "retrieval": {"hits": hit_rows, "latency_ms_retriever": latency_ms_retriever},
            "output": {"final_answer": final_answer, "used_sentences": used},
            "latency_ms_total": int((time.time() - started) * 1000),
        }
        if include_llm_latency:
            rec["latency_ms_llm"] = llm_ms
        rec["openai"] = openai_info
        _safe_write_jsonl(LOG_PATH, rec)

    # Nothing indexed at all → short-circuit with an explanatory answer.
    if hits is None or hits.empty:
        final = "No indexed PDFs found."
        _log(final, [], [], bool(use_llm), model)
        return final

    selected = mmr_select_sentences(
        question, hits, top_n=int(n_sentences), pool_per_chunk=6, lambda_div=0.7
    )
    header_cites = "; ".join(
        f"{Path(r['doc_path']).name} (p.{_extract_page(r['text'])})"
        for _, r in hits.head(6).iterrows()
    )

    unique_docs = {Path(r["doc_path"]).name for _, r in hits.iterrows()}
    if len(unique_docs) >= 3:
        coverage_note = ""
    else:
        coverage_note = (
            f"\n\n> Note: Only {len(unique_docs)} unique source(s). "
            "Add more PDFs or increase Top-K."
        )

    retr_list = [
        {
            "doc": Path(r["doc_path"]).name,
            "page": _extract_page(r["text"]),
            "score_tfidf": float(r.get("score_tfidf", 0.0)),
            "score_bm25": float(r.get("score_bm25", 0.0)),
            "score_dense": float(r.get("score_dense", 0.0)),
            "combo_score": float(r.get("score", 0.0)),
        }
        for _, r in hits.iterrows()
    ]

    def _passages_block():
        return "\n\n---\n" + "\n\n".join(hits["text"].tolist()[:2])

    def _used():
        return [{"sent": s["sent"], "doc": s["doc"], "page": s["page"]} for s in selected]

    # Retrieval-only / strict quotations mode (used in grid search).
    if strict_quotes_only:
        if selected:
            quoted = "\n- ".join(
                f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected
            )
            final = "**Quoted Passages:**\n- " + quoted
            final += f"\n\n**Citations:** {header_cites}{coverage_note}"
            if include_passages:
                final += _passages_block()
        else:
            final = (
                f"**Quoted Passages:**\n\n---\n"
                + "\n\n".join(hits["text"].tolist()[:2])
                + f"\n\n**Citations:** {header_cites}{coverage_note}"
            )
        _log(final, retr_list, _used(), False, None)
        return final

    # Extractive answer / optional LLM synthesis.
    extractive = compose_extractive(selected)

    def _extractive_final():
        # Shared fallback used whenever there is no LLM answer.
        if not extractive:
            return (
                f"**Answer:** Here are relevant passages.\n\n"
                f"**Citations:** {header_cites}{coverage_note}\n\n---\n"
                + "\n\n".join(hits["text"].tolist()[:2])
            )
        text = (
            f"**Answer:** {extractive}\n\n"
            f"**Citations:** {header_cites}{coverage_note}"
        )
        if include_passages:
            text += _passages_block()
        return text

    llm_usage = None
    llm_latency_ms = None

    if use_llm and selected:
        cite_lines = [f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected]
        llm_start = time.time()
        llm_text, llm_usage = synthesize_with_llm(
            question, cite_lines, model=model, temperature=temperature
        )
        llm_latency_ms = int((time.time() - llm_start) * 1000)

        if llm_text:
            final = (
                f"**Answer (LLM synthesis):** {llm_text}\n\n"
                f"**Citations:** {header_cites}{coverage_note}"
            )
            if include_passages:
                final += _passages_block()
        else:
            final = _extractive_final()
    else:
        final = _extractive_final()

    prompt_toks = llm_usage.get("prompt_tokens") if llm_usage else None
    completion_toks = llm_usage.get("completion_tokens") if llm_usage else None
    cost_usd = _calc_cost_usd(prompt_toks, completion_toks)

    openai_info = (
        {
            "prompt_tokens": prompt_toks,
            "completion_tokens": completion_toks,
            "cost_usd": cost_usd,
        }
        if use_llm
        else None
    )
    _log(
        final,
        retr_list,
        _used(),
        bool(use_llm),
        model,
        llm_ms=llm_latency_ms,
        openai_info=openai_info,
        include_llm_latency=True,
    )
    return final
626
+
627
# ---------- automated grid evaluation over weights ----------
def run_weight_grid_eval(
    gold_csv: str,
    weight_grid: List[Dict[str, float]],
    k: int = 8,
    n_sentences: int = 4,
) -> None:
    """Evaluate every (w_tfidf, w_bm25, w_emb) combination on the gold set.

    Reads the 'question' column from *gold_csv*; for each config dict in
    *weight_grid* (keys w_tfidf / w_bm25 / w_emb, optional 'id'), runs every
    question through rag_reply with use_llm=False and strict_quotes_only=True
    so each run lands in rag_logs.jsonl tagged with its config_id and weights.

    Raises ValueError when the CSV lacks a 'question' column.
    """
    gold_df = pd.read_csv(gold_csv)
    if "question" not in gold_df.columns:
        raise ValueError("gold_csv must contain a 'question' column.")
    questions = gold_df["question"].astype(str).tolist()

    for cfg in weight_grid:
        weights = {
            key: float(cfg.get(key, 0.0))
            for key in ("w_tfidf", "w_bm25", "w_emb")
        }
        cid = cfg.get("id") or (
            f"tfidf{weights['w_tfidf']}"
            f"_bm25{weights['w_bm25']}"
            f"_emb{weights['w_emb']}"
        )

        print(
            f"\n[GridEval] Running config {cid} "
            f"(w_tfidf={weights['w_tfidf']}, w_bm25={weights['w_bm25']}, "
            f"w_emb={weights['w_emb']}, k={k})"
        )

        for q in questions:
            _ = rag_reply(
                question=q,
                k=int(k),
                n_sentences=int(n_sentences),
                include_passages=False,
                use_llm=False,
                model=None,
                temperature=0.0,
                strict_quotes_only=True,
                w_tfidf=weights["w_tfidf"],
                w_bm25=weights["w_bm25"],
                w_emb=weights["w_emb"],
                config_id=cid,
            )

print("✅ RAG core + grid evaluation helpers loaded.")
677
+
678
+
679
+ # ===================== GRADIO UI WRAPPER =====================
680
+
681
+ # --- Chat wrapper ---
682
  def rag_chat_fn(
683
  message,
684
  history,
 
703
  w_tfidf=float(w_tfidf),
704
  w_bm25=float(w_bm25),
705
  w_emb=float(w_emb),
706
+ config_id=None,
707
  )
708
 
709
+ # --- Evaluation wrapper: returns log + file paths ---
 
710
  def run_eval_ui(gold_file, k):
 
 
 
 
 
 
711
  if gold_file is None:
712
  gold_path = Path("gold.csv")
713
  if not gold_path.exists():
714
  return (
715
  "**No gold.csv provided or found in the working directory.**\n"
716
+ "Upload a file or place gold.csv next to app.py.",
717
+ None,
718
+ None,
719
+ None,
720
  )
721
  gold_csv = str(gold_path)
722
  else:
 
725
  logs_jsonl = str(LOG_PATH)
726
  out_dir = str(ARTIFACT_DIR)
727
 
728
+ # run evaluation and write artifacts
729
  evaluate_rag(
730
  gold_csv,
731
  logs_jsonl,
 
734
  group_by_weights=True,
735
  )
736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
  base = ARTIFACT_DIR
738
  perq = base / "metrics_per_question.csv"
739
  agg = base / "metrics_aggregate.json"
 
743
  agg_path = str(agg) if agg.exists() else None
744
  surf_path = str(surf) if surf.exists() else None
745
 
746
+ log_msg = (
747
+ "✅ Evaluation finished.\n\n"
748
+ f"- Per-question metrics: `{perq}`\n"
749
+ f"- Aggregate metrics: `{agg}`\n"
750
+ f"- Config surface: `{surf}`\n\n"
751
+ "Use the download links below to save the files."
752
+ )
753
+
754
+ return log_msg, perq_path, agg_path, surf_path
755
 
756
 
757
  # ------------- Build Gradio UI -----------------
 
768
  top_k = gr.Slider(3, 15, value=8, step=1, label="Top-K chunks")
769
  n_sentences = gr.Slider(2, 8, value=4, step=1, label="Answer length (sentences)")
770
  include_passages = gr.Checkbox(
771
+ value=False, label="Include supporting passages"
 
772
  )
 
773
  with gr.Row():
774
+ w_tfidf = gr.Slider(0.0, 1.0, value=W_TFIDF_DEFAULT, step=0.05, label="TF-IDF weight")
775
+ w_bm25 = gr.Slider(0.0, 1.0, value=W_BM25_DEFAULT, step=0.05, label="BM25 weight")
776
+ w_emb = gr.Slider(0.0, 1.0, value=W_EMB_DEFAULT, step=0.05, label="Dense weight")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
777
 
778
  gr.ChatInterface(
779
  fn=rag_chat_fn,
780
+ additional_inputs=[top_k, n_sentences, include_passages, w_tfidf, w_bm25, w_emb],
 
 
 
 
 
 
 
781
  title="Hybrid RAG Q&A",
782
+ description="Hybrid BM25 + TF-IDF + dense retrieval with MMR sentence selection."
783
  )
784
 
785
  # --------- Evaluation tab ---------
 
787
  gr.Markdown(
788
  "Upload **gold.csv** and compute retrieval metrics against "
789
  "`rag_artifacts/rag_logs.jsonl`.\n\n"
790
+ "When you click **Run Evaluation**, the CSV/JSON files will be created "
791
+ "and appear as download links below."
792
  )
 
793
  gold_file = gr.File(label="gold.csv", file_types=[".csv"])
794
  k_slider = gr.Slider(3, 15, value=8, step=1, label="k for Hit/Recall/nDCG")
795
 
796
+ btn_eval = gr.Button("Run Evaluation", variant="primary")
797
+ eval_out = gr.Markdown(label="Evaluation log")
798
+
799
+ perq_file = gr.File(label="Per-question metrics (CSV)")
800
+ agg_file = gr.File(label="Aggregate metrics (JSON)")
801
+ surf_file = gr.File(label="Config surface (CSV)")
802
 
803
  btn_eval.click(
804
  fn=run_eval_ui,
805
  inputs=[gold_file, k_slider],
806
+ outputs=[eval_out, perq_file, agg_file, surf_file],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
  )
808
 
 
809
  # ------------- Launch app -----------------
810
  if __name__ == "__main__":
811
  demo.queue().launch(