Sarisha Das commited on
Commit
681ec3c
·
1 Parent(s): 0bcbce0

update for rag

Browse files
src/streamlit_app.py CHANGED
@@ -15,6 +15,11 @@ os.environ["TRANSFORMERS_CACHE"] = str(ROOT / ".hf_cache" / "transformers")
15
  from utils.retrieval_helpers import enrich_search_results, enrich_bm25_search_results
16
  from utils.bm25 import load
17
  from utils.semantic import load_vector_store
 
 
 
 
 
18
 
19
  from dotenv import load_dotenv
20
  load_dotenv()
@@ -31,83 +36,22 @@ st.set_page_config(
31
  )
32
 
33
  # ─── Paths ────────────────────────────────────────────────────────────────────
 
34
  FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
35
  FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
36
 
 
 
37
  HF_TOKEN = os.getenv('HF_TOKEN')
38
 
39
  from datasets import load_dataset
40
  from huggingface_hub import snapshot_download, login
41
 
42
  # ─── Custom CSS ───────────────────────────────────────────────────────────────
43
- st.markdown(
44
- """
45
- <style>
46
- @import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@600&family=Source+Sans+3:wght@400;600&display=swap');
47
-
48
- html, body, [class*="css"] {
49
- font-family: 'Source Sans 3', sans-serif;
50
- }
51
- h1, h2, h3 { font-family: 'Playfair Display', serif; }
52
-
53
- .banner {
54
- background: linear-gradient(135deg, #2d4a22 0%, #4a7c3f 60%, #7aab5c 100%);
55
- border-radius: 12px;
56
- padding: 2rem 2.5rem;
57
- margin-bottom: 1.5rem;
58
- color: #f5f0e8;
59
- }
60
- .banner h1 { margin: 0; font-size: 2.4rem; color: #f5f0e8; }
61
- .banner p { margin: 0.3rem 0 0; font-size: 1.05rem; opacity: 0.85; }
62
-
63
- /* Product card (outer) */
64
- .product-card {
65
- background: #fffdf7;
66
- border: 1px solid #e2d9c8;
67
- border-left: 4px solid #4a7c3f;
68
- border-radius: 8px;
69
- padding: 1rem 1.2rem 0.6rem;
70
- margin-bottom: 0.4rem;
71
- box-shadow: 0 1px 4px rgba(0,0,0,0.06);
72
- }
73
- .product-card h4 { margin: 0 0 0.2rem; color: #1e3318; font-size: 1.05rem; }
74
-
75
- /* Review snippet inside expander */
76
- .review-snippet {
77
- background: #f7f4ee;
78
- border-radius: 6px;
79
- padding: 0.6rem 0.9rem;
80
- margin-bottom: 0.5rem;
81
- font-size: 0.87rem;
82
- color: #444;
83
- line-height: 1.55;
84
- }
85
- .score-badge {
86
- display: inline-block;
87
- background: #eaf3e6;
88
- color: #2d5a20;
89
- border-radius: 20px;
90
- padding: 2px 10px;
91
- font-size: 0.78rem;
92
- font-weight: 600;
93
- margin-right: 6px;
94
- }
95
- .stars { color: #e6a817; }
96
-
97
- .placeholder-badge {
98
- background: #fff3cd;
99
- border: 1px solid #ffc107;
100
- border-radius: 6px;
101
- padding: 0.4rem 0.8rem;
102
- font-size: 0.82rem;
103
- color: #7a5800;
104
- display: inline-block;
105
- margin-bottom: 1rem;
106
- }
107
- </style>
108
- """,
109
- unsafe_allow_html=True,
110
- )
111
 
112
  @st.cache_resource
113
  def load_hf_dataset():
@@ -183,9 +127,23 @@ def semantic_search(query: str, top_k: int = 3) -> list[dict]:
183
  return retriever.search(query, top_k=top_k)
184
  Returns top_k review-level results (scores are cosine similarities, 0–1).
185
  """
186
- results = enrich_search_results(vector_store, query, top_k, HF_DATASET["full"])
 
187
  return results
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  # ─── Helpers ──────────────────────────────────────────────────────────────────
191
  def stars(rating: float) -> str:
@@ -212,82 +170,85 @@ def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> Non
212
  "vote": vote,
213
  })
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def render_results(results: list[dict], mode: str, query: str) -> None:
217
  if not results:
218
  st.info("No results returned.")
219
  return
220
-
221
- grouped = results
222
-
223
- for ind, item in enumerate(grouped):
224
- reviews = item["reviews"]
225
- title = item["title"]
226
- avg_rating = item["average_rating"]
227
- n_reviews = len(reviews)
228
- # total_reviews = item.get('total_reviews', n_reviews)
229
- rating_number = item.get('rating_number', 0)
230
- asin = item['parent_asin']
231
- review_word = "review" if n_reviews == 1 else "reviews"
232
- large_images = item.get('images', {}).get('large', [])
233
- image_html = f'<img src="{large_images[0]}" style="width:100%;max-width:200px;border-radius:8px;margin-bottom:8px;" />' if large_images else ''
234
- raw_price = item.get('price')
235
- try:
236
- price_val = float(str(raw_price).replace('$', '').replace(',', '').strip())
237
- price_html = f'<span style="color:#2ecc71;font-weight:600">${price_val:.2f}</span>'
238
- except (TypeError, ValueError):
239
- price_html = ''
240
-
241
-
242
- # ── Product card header ───────────────────────────────────────────
243
- st.markdown(
244
- f"""
245
- <div class="product-card">
246
- {image_html}
247
- <h4>#{ind + 1} &nbsp; {title}</h4>
248
- <span class="stars">{stars(avg_rating)}</span>
249
- &nbsp;<small style="color:#888">{avg_rating:.1f}/5 avg ({rating_number:,} ratings)</small>
250
- &nbsp;&nbsp;
251
- <span class="score-badge">similarity score: {item['score']}</span>
252
- {"&nbsp;&nbsp;" + price_html if price_html else ""}
253
- </div>
254
- """,
255
- unsafe_allow_html=True,
256
- )
257
-
258
- # ── Reviews in collapsible expander ───────────────────────────────
259
- expander_label = f"📖 Viewing top {n_reviews} {review_word} "
260
- with st.expander(expander_label, expanded=(n_reviews == 1)):
261
- for j, rev in enumerate(reviews):
262
- st.markdown(
263
- f"""
264
- <div class="review-snippet">
265
- <strong>{rev['title']}</strong>
266
- &nbsp;·&nbsp;
267
- <span class="stars">{stars(rev['rating'])}</span>
268
- <span style="color:#888; font-size:0.8rem"> {rev['rating']}/5</span>
269
- &nbsp;·&nbsp;
270
- <br><br>
271
- {rev['text'][:300]}{'…' if len(rev['text']) > 300 else ''}
272
- </div>
273
- """,
274
- unsafe_allow_html=True,
275
- )
276
-
277
- # ── Feedback buttons (per product) ────────────────────────────────
278
- col_up, col_dn, _ = st.columns([1, 1, 10])
279
- with col_up:
280
- if st.button("👍", key=f"up_{mode}_{asin}_{ind}"):
281
- log_feedback(query, mode, asin, title, "up")
282
- st.toast("Thanks! 👍")
283
- with col_dn:
284
- if st.button("👎", key=f"dn_{mode}_{asin}_{ind}"):
285
- log_feedback(query, mode, asin, title, "down")
286
- st.toast("Noted! 👎")
287
-
288
- st.markdown("<hr style='border:none;border-top:1px solid #e8e0d0;margin:0.5rem 0 1rem'>", unsafe_allow_html=True)
289
-
290
-
291
  # ─── App layout ───────────────────────────────────────────────────────────────
292
  st.markdown(
293
  """
@@ -304,29 +265,85 @@ query = st.text_input(
304
  "Search for a product or describe what you're looking for",
305
  placeholder="e.g. something sweet for a cheese board...",
306
  )
 
 
 
307
 
308
- # ─── Mode radio ───────────────────────────────────────────────────────────────
309
- mode = st.radio(
310
- "Search mode",
311
- options=["BM25", "Semantic"],
312
- index=0, # BM25 shown by default
313
- horizontal=True,
314
- help="BM25 = keyword matching · Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
315
- )
316
 
317
- # ─── Run & render ─────────────────────────────────────────────────────────────
318
- TOP_K = 5 # fixed per milestone requirement
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- if query.strip():
321
- st.markdown(f"#### Top {TOP_K} results — {mode}")
 
 
 
 
 
 
 
 
 
 
 
322
 
323
- results = bm25_search(query, top_k=TOP_K) if mode == "BM25" else semantic_search(query, top_k=TOP_K)
324
- render_results(results, mode=mode.lower(), query=query)
325
- else:
326
- st.markdown(
327
- "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to see results.</p>",
328
- unsafe_allow_html=True,
329
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  # ─── Sidebar: feedback log ────────────────────────────────────────────────────
332
  with st.sidebar:
 
15
  from utils.retrieval_helpers import enrich_search_results, enrich_bm25_search_results
16
  from utils.bm25 import load
17
  from utils.semantic import load_vector_store
18
+ from utils.rag_pipeline import run_rag
19
+ from utils.bm25 import load
20
+ from utils.hybrid import HybridRetriever
21
+
22
+ import markdown
23
 
24
  from dotenv import load_dotenv
25
  load_dotenv()
 
36
  )
37
 
38
  # ─── Paths ────────────────────────────────────────────────────────────────────
39
+ ROOT = Path(__file__).resolve().parent.parent
40
  FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
41
  FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
42
 
43
+ TOP_K = 5
44
+
45
  HF_TOKEN = os.getenv('HF_TOKEN')
46
 
47
  from datasets import load_dataset
48
  from huggingface_hub import snapshot_download, login
49
 
50
  # ─── Custom CSS ───────────────────────────────────────────────────────────────
51
+ with open('./src/styles.css', "r") as f:
52
+ css = f.read()
53
+
54
+ st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  @st.cache_resource
57
  def load_hf_dataset():
 
127
  return retriever.search(query, top_k=top_k)
128
  Returns top_k review-level results (scores are cosine similarities, 0–1).
129
  """
130
+
131
+ results = enrich_search_results(vector_store, query, top_k, HF_DATASET['full'])
132
  return results
133
 
134
+ hybrid_retriever = HybridRetriever(
135
+ bm25_retriever=retriever,
136
+ semantic_store=vector_store,
137
+ k=TOP_K,
138
+ bm25_weight=0.5,
139
+ semantic_weight=0.5,
140
+ )
141
+
142
+ def llm_retriever(query: str, top_k: int = 5):
143
+ retriever = hybrid_retriever
144
+ answer, docs = run_rag(retriever, query=query, hf_dataset=HF_DATASET['full'])
145
+ return answer, docs
146
+
147
 
148
  # ─── Helpers ──────────────────────────────────────────────────────────────────
149
  def stars(rating: float) -> str:
 
170
  "vote": vote,
171
  })
172
 
173
+ def render_product(ind, item):
174
+ reviews = item.get("reviews",{})
175
+ title = item["title"]
176
+ avg_rating = item["average_rating"]
177
+ n_reviews = len(reviews)
178
+ # total_reviews = item.get('total_reviews', n_reviews)
179
+ rating_number = item.get('rating_number', 0)
180
+ asin = item['parent_asin']
181
+ review_word = "review" if n_reviews == 1 else "reviews"
182
+ large_images = item.get('images', {}).get('large', [])
183
+ image_html = f'<img src="{large_images[0]}" style="width:100%;max-width:200px;border-radius:8px;margin-bottom:8px;" />' if large_images else ''
184
+ raw_price = item.get('price')
185
+ try:
186
+ price_val = float(str(raw_price).replace('$', '').replace(',', '').strip())
187
+ price_html = f'<span style="color:#2ecc71;font-weight:600">${price_val:.2f}</span>'
188
+ except (TypeError, ValueError):
189
+ price_html = ''
190
+
191
+
192
+ # ── Product card header ───────────────────────────────────────────
193
+ score_badge = f'<span class="score-badge">similarity score: {float(item["score"]):.2f}</span>' if 'score' in item else "<span/>"
194
+
195
+ st.markdown(
196
+ f"""
197
+ <div class="product-card" id="{asin}">
198
+ {image_html}
199
+ <h4>#{ind + 1} &nbsp; {title}</h4>
200
+ <span class="stars">{stars(avg_rating)}</span>
201
+ &nbsp;<small style="color:#888">{avg_rating:.1f}/5 avg ({rating_number:,} ratings)</small>
202
+ &nbsp;&nbsp;
203
+ {score_badge}
204
+ {"&nbsp;&nbsp;" + price_html if price_html else ""}
205
+ </div>
206
+ """,
207
+ unsafe_allow_html=True,
208
+ )
209
+
210
+ # ── Reviews in collapsible expander ───────────────────────────────
211
+ expander_label = f"📖 Viewing top {n_reviews} {review_word} "
212
+ with st.expander(expander_label, expanded=(n_reviews == 1)):
213
+ for j, rev in enumerate(reviews):
214
+ st.markdown(
215
+ f"""
216
+ <div class="review-snippet">
217
+ <strong>{rev['title']}</strong>
218
+ &nbsp;·&nbsp;
219
+ <span class="stars">{stars(rev['rating'])}</span>
220
+ <span style="color:#888; font-size:0.8rem"> {rev['rating']}/5</span>
221
+ &nbsp;·&nbsp;
222
+ <br><br>
223
+ {rev['text'][:300]}{'…' if len(rev['text']) > 300 else ''}
224
+ </div>
225
+ """,
226
+ unsafe_allow_html=True,
227
+ )
228
+
229
+ # ── Feedback buttons (per product) ────────────────────────────────
230
+ col_up, col_dn, _ = st.columns([1, 1, 10])
231
+ with col_up:
232
+ if st.button("👍", key=f"up_{mode}_{asin}_{ind}"):
233
+ log_feedback(query, mode, asin, title, "up")
234
+ st.toast("Thanks! 👍")
235
+ with col_dn:
236
+ if st.button("👎", key=f"dn_{mode}_{asin}_{ind}"):
237
+ log_feedback(query, mode, asin, title, "down")
238
+ st.toast("Noted! 👎")
239
+
240
+ st.markdown("<hr style='border:none;border-top:1px solid #e8e0d0;margin:0.5rem 0 1rem'>", unsafe_allow_html=True)
241
+
242
+
243
 
244
  def render_results(results: list[dict], mode: str, query: str) -> None:
245
  if not results:
246
  st.info("No results returned.")
247
  return
248
+
249
+ for ind, item in enumerate(results):
250
+ render_product(ind,item)
251
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  # ─── App layout ───────────────────────────────────────────────────────────────
253
  st.markdown(
254
  """
 
265
  "Search for a product or describe what you're looking for",
266
  placeholder="e.g. something sweet for a cheese board...",
267
  )
268
+ # ─── Run searches only when query changes ─────────────────────────────────────
269
+ if query.strip() and query != st.session_state.get("last_query"):
270
+ st.session_state.last_query = query
271
 
272
+ with st.spinner("Searching..."):
273
+ st.session_state.bm25_results = bm25_search(query, top_k=TOP_K)
274
+ st.session_state.semantic_results = semantic_search(query, top_k=TOP_K)
 
 
 
 
 
275
 
276
+ with st.spinner("Asking AI..."):
277
+ try:
278
+ answer, docs = llm_retriever(query, top_k=TOP_K)
279
+ st.session_state.llm_result = answer
280
+ st.session_state.llm_docs = docs
281
+ except Exception as e:
282
+ st.session_state.llm_result = f"**Error:** {e}"
283
+ st.session_state.llm_docs = []
284
+
285
+ elif not query.strip():
286
+ # Clear results when input is emptied
287
+ for key in ("last_query", "bm25_results", "semantic_results", "llm_result"):
288
+ st.session_state.pop(key, None)
289
+
290
+ # ─── Tabs ─────────────────────────────────────────────────────────────────────
291
+ tab_search, tab_llm = st.tabs(["🔍 Search", "🤖 AI Assistant"])
292
+
293
+ # ─── Search Tab ───────────────────────────────────────────────────────────────
294
+ with tab_search:
295
+ mode = st.radio(
296
+ "Search mode",
297
+ options=["BM25", "Semantic"],
298
+ index=0,
299
+ horizontal=True,
300
+ help="BM25 = keyword matching · Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
301
+ )
302
 
303
+ if "last_query" not in st.session_state:
304
+ st.markdown(
305
+ "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to see results.</p>",
306
+ unsafe_allow_html=True,
307
+ )
308
+ else:
309
+ st.markdown(f"#### Top {TOP_K} results — {mode}")
310
+ results = (
311
+ st.session_state.bm25_results
312
+ if mode == "BM25"
313
+ else st.session_state.semantic_results
314
+ )
315
+ render_results(results, mode=mode.lower(), query=st.session_state.last_query)
316
 
317
+ # ─── LLM Tab ──────────────────────────────────────────────────────────────────
318
+ with tab_llm:
319
+ if "llm_result" not in st.session_state:
320
+ st.markdown(
321
+ "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to get AI-powered recommendations.</p>",
322
+ unsafe_allow_html=True,
323
+ )
324
+ else:
325
+ st.markdown(f"#### 🤖 AI Answer — *\"{st.session_state.last_query}\"*")
326
+ st.caption("⚠️ AI responses may contain errors - please verify before relying on them.")
327
+ html_response = markdown.markdown(
328
+ st.session_state.llm_result,
329
+ extensions=["tables", "fenced_code", "nl2br"],
330
+ )
331
+ st.markdown(
332
+ f"<div class='llm-response'>{html_response}</div>",
333
+ unsafe_allow_html=True,
334
+ )
335
+
336
+ st.markdown("#### 📦 Retrieved Products")
337
+ docs = st.session_state.get("llm_docs", [])
338
+ if docs:
339
+ # Build scrollable card list in one HTML block
340
+ cards_html = "<div class='doc-sidebar'>"
341
+ for i, doc in enumerate(docs, 1):
342
+ render_product(i,doc)
343
+ cards_html += "</div>"
344
+ st.markdown(cards_html, unsafe_allow_html=True)
345
+ else:
346
+ st.markdown("<p style='color:#aaa;'>No documents retrieved.</p>", unsafe_allow_html=True)
347
 
348
  # ─── Sidebar: feedback log ────────────────────────────────────────────────────
349
  with st.sidebar:
src/styles.css ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@600&family=Source+Sans+3:wght@400;600&display=swap');
2
+
3
+ html, body, [class*="css"] {
4
+ font-family: 'Source Sans 3', sans-serif;
5
+ }
6
+ h1, h2, h3 { font-family: 'Playfair Display', serif; }
7
+
8
+ .banner {
9
+ background: linear-gradient(135deg, #2d4a22 0%, #4a7c3f 60%, #7aab5c 100%);
10
+ border-radius: 12px;
11
+ padding: 2rem 2.5rem;
12
+ margin-bottom: 1.5rem;
13
+ color: #f5f0e8;
14
+ }
15
+ .banner h1 { margin: 0; font-size: 2.4rem; color: #f5f0e8; }
16
+ .banner p { margin: 0.3rem 0 0; font-size: 1.05rem; opacity: 0.85; }
17
+
18
+ /* Product card (outer) */
19
+ .product-card {
20
+ background: #fffdf7;
21
+ border: 1px solid #e2d9c8;
22
+ border-left: 4px solid #4a7c3f;
23
+ border-radius: 8px;
24
+ padding: 1rem 1.2rem 0.6rem;
25
+ margin-bottom: 0.4rem;
26
+ box-shadow: 0 1px 4px rgba(0,0,0,0.06);
27
+ }
28
+ .product-card h4 { margin: 0 0 0.2rem; color: #1e3318; font-size: 1.05rem; }
29
+
30
+ /* Review snippet inside expander */
31
+ .review-snippet {
32
+ background: #f7f4ee;
33
+ border-radius: 6px;
34
+ padding: 0.6rem 0.9rem;
35
+ margin-bottom: 0.5rem;
36
+ font-size: 0.87rem;
37
+ color: #444;
38
+ line-height: 1.55;
39
+ }
40
+ .score-badge {
41
+ display: inline-block;
42
+ background: #eaf3e6;
43
+ color: #2d5a20;
44
+ border-radius: 20px;
45
+ padding: 2px 10px;
46
+ font-size: 0.78rem;
47
+ font-weight: 600;
48
+ margin-right: 6px;
49
+ }
50
+ .stars { color: #e6a817; }
51
+
52
+ .placeholder-badge {
53
+ background: #fff3cd;
54
+ border: 1px solid #ffc107;
55
+ border-radius: 6px;
56
+ padding: 0.4rem 0.8rem;
57
+ font-size: 0.82rem;
58
+ color: #7a5800;
59
+ display: inline-block;
60
+ margin-bottom: 1rem;
61
+ }
62
+
63
+ .doc-sidebar {
64
+ max-height: 600px;
65
+ overflow-y: auto;
66
+ padding-right: 4px;
67
+ }
68
+ .doc-card {
69
+ background: #1e1e2e;
70
+ border: 1px solid #333;
71
+ border-radius: 8px;
72
+ padding: 0.75rem;
73
+ margin-bottom: 0.65rem;
74
+ }
75
+ .doc-title {
76
+ font-weight: 600;
77
+ font-size: 0.85rem;
78
+ margin-bottom: 0.3rem;
79
+ color: #f0f0f0;
80
+ line-height: 1.3;
81
+ }
82
+ .doc-meta {
83
+ font-size: 0.78rem;
84
+ margin-bottom: 0.3rem;
85
+ display: flex;
86
+ gap: 0.5rem;
87
+ }
88
+ .doc-rating { color: #f5c518; }
89
+ .doc-price { color: #5cb85c; }
90
+ .doc-snippet {
91
+ font-size: 0.75rem;
92
+ color: #999;
93
+ line-height: 1.4;
94
+ }
utils/bm25.py CHANGED
@@ -368,21 +368,8 @@ def load(index_path: str | Path = "data/processed/bm25_index.pkl") -> BM25Retrie
368
  f"BM25 index not found at '{index_path}'.\n"
369
  "Run build_and_save() from your notebook first."
370
  )
371
-
372
- # Patch: pickle saved simple_tokenize under 'utils' top-level namespace,
373
- # but it now lives in utils.bm25 — register it where pickle expects it
374
- import sys
375
- import types
376
- from utils import bm25 as bm25_module
377
-
378
- if "utils" not in sys.modules or not hasattr(sys.modules["utils"], "simple_tokenize"):
379
- fake_utils = types.ModuleType("utils")
380
- fake_utils.simple_tokenize = bm25_module.simple_tokenize
381
- sys.modules["utils"] = fake_utils
382
-
383
  with open(index_path, "rb") as f:
384
  retriever = pickle.load(f)
385
-
386
  print(f"BM25 index loaded ← {index_path}")
387
  return retriever
388
 
@@ -410,18 +397,25 @@ def search(
410
  asin, title, text, rating, score, top_reviews
411
  """
412
  retriever.k = top_k
413
- docs = retriever.invoke(query)
414
-
 
 
 
 
 
 
 
 
415
  results = []
416
- for doc in docs:
417
- m = doc.metadata
 
418
  top_reviews = m.get("top_reviews", [])
419
 
420
- # Average rating across retrieved top reviews
421
  rated = [r["rating"] for r in top_reviews if r.get("rating") is not None]
422
  avg_rating = round(sum(rated) / len(rated), 1) if rated else 0.0
423
 
424
- # Snippet = first review text, falling back to description
425
  if top_reviews and top_reviews[0].get("text"):
426
  snippet = top_reviews[0]["text"][:300]
427
  else:
@@ -432,7 +426,7 @@ def search(
432
  "title": m.get("title", ""),
433
  "text": snippet,
434
  "rating": avg_rating,
435
- "score": 0.0, # LangChain BM25Retriever does not expose raw scores
436
  "top_reviews": top_reviews,
437
  })
438
 
@@ -444,8 +438,8 @@ def search(
444
  def build_from_hf_datasets(
445
  metadata_dataset: Dataset,
446
  reviews_dataset_dict,
447
- index_path: str | Path = "data/processed/bm25_index.pkl",
448
- corpus_path: str | Path = "data/processed/bm25_corpus.pkl",
449
  max_products: int | None = None,
450
  max_reviews_per_product: int = 5,
451
  ) -> BM25Retriever:
@@ -480,8 +474,8 @@ def build_from_hf_datasets(
480
  def build_from_hf_datasets_batched(
481
  metadata_dataset: Dataset,
482
  reviews_dataset_dict,
483
- index_path: str | Path = "data/processed/bm25_index.pkl",
484
- corpus_path: str | Path = "data/processed/bm25_corpus.pkl",
485
  batch_size: int = 2000,
486
  max_reviews_per_product: int = 5,
487
  max_products: int | None = None,
 
368
  f"BM25 index not found at '{index_path}'.\n"
369
  "Run build_and_save() from your notebook first."
370
  )
 
 
 
 
 
 
 
 
 
 
 
 
371
  with open(index_path, "rb") as f:
372
  retriever = pickle.load(f)
 
373
  print(f"BM25 index loaded ← {index_path}")
374
  return retriever
375
 
 
397
  asin, title, text, rating, score, top_reviews
398
  """
399
  retriever.k = top_k
400
+
401
+ # Tokenize query the same way the index was built
402
+ tokenized_query = simple_tokenize(query)
403
+
404
+ # Get raw BM25 scores for ALL documents
405
+ scores = retriever.vectorizer.get_scores(tokenized_query) # np.ndarray, len = n_docs
406
+
407
+ # Get top-k doc indices by score
408
+ top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
409
+
410
  results = []
411
+ for idx in top_indices:
412
+ doc = retriever.docs[idx] # retriever.docs holds the original Document list
413
+ m = doc.metadata
414
  top_reviews = m.get("top_reviews", [])
415
 
 
416
  rated = [r["rating"] for r in top_reviews if r.get("rating") is not None]
417
  avg_rating = round(sum(rated) / len(rated), 1) if rated else 0.0
418
 
 
419
  if top_reviews and top_reviews[0].get("text"):
420
  snippet = top_reviews[0]["text"][:300]
421
  else:
 
426
  "title": m.get("title", ""),
427
  "text": snippet,
428
  "rating": avg_rating,
429
+ "score": float(scores[idx]),
430
  "top_reviews": top_reviews,
431
  })
432
 
 
438
  def build_from_hf_datasets(
439
  metadata_dataset: Dataset,
440
  reviews_dataset_dict,
441
+ index_path: str | Path = "data/processed/tokenisation/bm25_index.pkl",
442
+ corpus_path: str | Path = "data/processed/tokenisation/bm25_corpus.pkl",
443
  max_products: int | None = None,
444
  max_reviews_per_product: int = 5,
445
  ) -> BM25Retriever:
 
474
  def build_from_hf_datasets_batched(
475
  metadata_dataset: Dataset,
476
  reviews_dataset_dict,
477
+ index_path: str | Path = "data/processed/tokenisation/bm25_index.pkl",
478
+ corpus_path: str | Path = "data/processed/tokenisation/bm25_corpus.pkl",
479
  batch_size: int = 2000,
480
  max_reviews_per_product: int = 5,
481
  max_products: int | None = None,
utils/hybrid.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/hybrid.py
3
+ -------------
4
+ Hybrid retriever combining BM25 keyword search and FAISS semantic search,
5
+ fused with Reciprocal Rank Fusion (RRF).
6
+
7
+ Designed to plug into the existing run_rag() pipeline in rag_pipeline.py
8
+ as a drop-in replacement for the semantic retriever:
9
+
10
+ hybrid_retriever = load_hybrid_retriever(
11
+ bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl",
12
+ faiss_store_path="data/processed/embeddings",
13
+ k=5,
14
+ )
15
+ answer = run_rag(hybrid_retriever, "Best coffee beans for espresso")
16
+
17
+ The HybridRetriever class extends LangChain's BaseRetriever so it is fully
18
+ compatible with the | (pipe) operator used in rag_pipeline.py:
19
+
20
+ rag_chain = (
21
+ {
22
+ "context": hybrid_retriever | RunnableLambda(build_context),
23
+ "question": RunnablePassthrough(),
24
+ }
25
+ | prompt_template
26
+ | llm
27
+ | StrOutputParser()
28
+ )
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import logging
34
+ from typing import Any
35
+
36
+ from langchain_community.retrievers import BM25Retriever
37
+ from langchain_community.vectorstores import FAISS
38
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
39
+ from langchain_core.documents import Document
40
+ from langchain_core.retrievers import BaseRetriever
41
+ from pydantic import Field
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # HybridRetriever
48
+ # ---------------------------------------------------------------------------
49
+
50
+ class HybridRetriever(BaseRetriever):
51
+ """
52
+ Combines BM25 keyword retrieval and FAISS semantic retrieval using
53
+ Reciprocal Rank Fusion (RRF) to produce a unified ranked document list.
54
+
55
+ RRF score for document d across retriever r:
56
+ score(d) = weight_r * (1 / (rrf_c + rank(d, r)))
57
+
58
+ Documents appearing in both retrievers accumulate scores from both,
59
+ naturally promoting results that are relevant by both keyword and meaning.
60
+
61
+ Parameters
62
+ ----------
63
+ bm25_retriever : Fitted LangChain BM25Retriever (from bm25.load())
64
+ semantic_store : Loaded FAISS vectorstore (from semantic.load_vector_store())
65
+ k : Number of final documents to return
66
+ rrf_c : RRF constant — dampens the impact of rank differences.
67
+ Standard value is 60; lower = top ranks matter more.
68
+ bm25_weight : RRF weight for BM25 results (keyword signal)
69
+ semantic_weight : RRF weight for semantic results (meaning signal)
70
+ fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
71
+ More candidates = better fusion quality. Default: 3.
72
+ """
73
+
74
+ bm25_retriever: Any = Field(...)
75
+ semantic_store: Any = Field(...)
76
+ k: int = Field(default=5)
77
+ rrf_c: int = Field(default=60)
78
+ bm25_weight: float = Field(default=0.5)
79
+ semantic_weight: float = Field(default=0.5)
80
+ fetch_multiplier: int = Field(default=3)
81
+
82
+ def _get_relevant_documents(
83
+ self,
84
+ query: str,
85
+ *,
86
+ run_manager: CallbackManagerForRetrieverRun,
87
+ ) -> list[Document]:
88
+ """
89
+ Core retrieval logic called by LangChain when the retriever is invoked.
90
+
91
+ Steps
92
+ -----
93
+ 1. Fetch candidates from BM25 and FAISS independently
94
+ 2. Assign RRF scores weighted by retriever confidence
95
+ 3. Deduplicate by parent_asin, accumulating scores for shared hits
96
+ 4. Sort by fused RRF score and return top-k Documents
97
+ """
98
+ fetch_k = self.k * self.fetch_multiplier
99
+
100
+ # ── 1. BM25 retrieval ────────────────────────────────────────────────
101
+ self.bm25_retriever.k = fetch_k
102
+ try:
103
+ bm25_docs: list[Document] = self.bm25_retriever.invoke(query)
104
+ logger.debug("BM25 returned %d docs for query: %r", len(bm25_docs), query)
105
+ except Exception as exc:
106
+ logger.warning("BM25 retrieval failed: %s — using empty list.", exc)
107
+ bm25_docs = []
108
+
109
+ # ── 2. Semantic retrieval ────────────────────────────────────────────
110
+ # similarity_search returns list[Document] (no scores needed — rank is enough for RRF)
111
+ try:
112
+ semantic_docs: list[Document] = self.semantic_store.similarity_search(
113
+ query, k=fetch_k
114
+ )
115
+ logger.debug(
116
+ "Semantic returned %d docs for query: %r", len(semantic_docs), query
117
+ )
118
+ except Exception as exc:
119
+ logger.warning("Semantic retrieval failed: %s — using empty list.", exc)
120
+ semantic_docs = []
121
+
122
+ # ── 3. RRF fusion ────────────────────────────────────────────────────
123
+ rrf_scores: dict[str, float] = {}
124
+ doc_map: dict[str, Document] = {}
125
+
126
+ def _asin_key(doc: Document, fallback: str) -> str:
127
+ """Use parent_asin as the dedup key; fall back to a content prefix."""
128
+ return doc.metadata.get("parent_asin") or fallback
129
+
130
+ for rank, doc in enumerate(bm25_docs):
131
+ key = _asin_key(doc, f"bm25_{rank}")
132
+ score = self.bm25_weight / (self.rrf_c + rank + 1)
133
+ rrf_scores[key] = rrf_scores.get(key, 0.0) + score
134
+ doc_map[key] = doc # BM25 docs have richer metadata (top_reviews etc.)
135
+
136
+ for rank, doc in enumerate(semantic_docs):
137
+ key = _asin_key(doc, f"sem_{rank}")
138
+ score = self.semantic_weight / (self.rrf_c + rank + 1)
139
+ rrf_scores[key] = rrf_scores.get(key, 0.0) + score
140
+ # Only add to doc_map if BM25 didn't already supply this product
141
+ # (BM25 metadata is richer — has top_reviews, image_url, etc.)
142
+ if key not in doc_map:
143
+ doc_map[key] = doc
144
+
145
+ # ── 4. Sort and truncate ─────────────────────────────────────────────
146
+ ranked_keys = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)
147
+ top_docs = [doc_map[key] for key in ranked_keys[: self.k]]
148
+
149
+ # Attach fused score to metadata — useful for app display
150
+ for key, doc in zip(ranked_keys, top_docs):
151
+ doc.metadata["hybrid_score"] = round(rrf_scores[key], 6)
152
+ # Record which retriever(s) contributed to this result
153
+ in_bm25 = any(
154
+ _asin_key(d, f"bm25_{i}") == key for i, d in enumerate(bm25_docs)
155
+ )
156
+ in_sem = any(
157
+ _asin_key(d, f"sem_{i}") == key for i, d in enumerate(semantic_docs)
158
+ )
159
+ if in_bm25 and in_sem:
160
+ doc.metadata["retrieval_source"] = "hybrid"
161
+ elif in_bm25:
162
+ doc.metadata["retrieval_source"] = "bm25"
163
+ else:
164
+ doc.metadata["retrieval_source"] = "semantic"
165
+
166
+ logger.info(
167
+ "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
168
+ len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
169
+ )
170
+ return top_docs
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # Convenience loader
175
+ # ---------------------------------------------------------------------------
176
+
177
+ def load_hybrid_retriever(
178
+ bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
179
+ faiss_store_path: str = "data/processed/embeddings",
180
+ k: int = 5,
181
+ bm25_weight: float = 0.5,
182
+ semantic_weight: float = 0.5,
183
+ rrf_c: int = 60,
184
+ fetch_multiplier: int = 3,
185
+ ) -> HybridRetriever:
186
+ """
187
+ Load both indexes from disk and return a ready-to-use HybridRetriever.
188
+
189
+ Call this once in your notebook or app.py, then pass the result to run_rag().
190
+
191
+ Parameters
192
+ ----------
193
+ bm25_index_path : Path to the pickled BM25Retriever (from bm25.build_and_save())
194
+ faiss_store_path : Directory containing index.faiss + index.pkl
195
+ (from semantic.build_and_save_vector_store())
196
+ k : Number of documents to return per query
197
+ bm25_weight : RRF weight for BM25 (keyword signal). Default 0.5.
198
+ semantic_weight : RRF weight for semantic (meaning signal). Default 0.5.
199
+ Weights don't need to sum to 1 but relative scale matters.
200
+ rrf_c : RRF rank-dampening constant. Default 60 (standard).
201
+ fetch_multiplier : Candidates to fetch per retriever = k * fetch_multiplier.
202
+
203
+ Returns
204
+ -------
205
+ HybridRetriever
206
+ A LangChain-compatible retriever pipeable with |.
207
+
208
+ Example
209
+ -------
210
+ >>> from utils.hybrid import load_hybrid_retriever
211
+ >>> from utils.rag_pipeline import run_rag
212
+ >>>
213
+ >>> hybrid = load_hybrid_retriever(k=5)
214
+ >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
215
+ >>> print(answer)
216
+ """
217
+ # Import here to avoid circular imports when used from rag_pipeline.py
218
+ from utils.bm25 import load as load_bm25
219
+ from utils.semantic import load_vector_store
220
+
221
+ print(f"Loading BM25 index from: {bm25_index_path}")
222
+ bm25_ret: BM25Retriever = load_bm25(bm25_index_path)
223
+
224
+ print(f"Loading FAISS store from: {faiss_store_path}")
225
+ faiss_store: FAISS = load_vector_store(faiss_store_path)
226
+
227
+ retriever = HybridRetriever(
228
+ bm25_retriever=bm25_ret,
229
+ semantic_store=faiss_store,
230
+ k=k,
231
+ bm25_weight=bm25_weight,
232
+ semantic_weight=semantic_weight,
233
+ rrf_c=rrf_c,
234
+ fetch_multiplier=fetch_multiplier,
235
+ )
236
+ print(
237
+ f"HybridRetriever ready — k={k}, "
238
+ f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
239
+ )
240
+ return retriever
utils/rag_pipeline.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ rag_chain.py
3
+ ------------
4
+ Amazon product RAG (Retrieval-Augmented Generation) pipeline using
5
+ LangChain + HuggingFace Inference Endpoints.
6
+
7
+ Typical usage
8
+ -------------
9
+ >>> from rag_chain import run_rag
10
+ >>> answer = run_rag(retriever, "Moisturizing shampoo for thick curly hair")
11
+ >>> print(answer)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from typing import Any
18
+
19
+ from langchain_core.documents import Document
20
+ from langchain_core.output_parsers import StrOutputParser
21
+ from langchain_core.prompts import ChatPromptTemplate
22
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
23
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
24
+ from utils.retrieval_helpers import _format_docs
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Logging
28
+ # ---------------------------------------------------------------------------
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Constants
33
+ # ---------------------------------------------------------------------------
34
+ DEFAULT_REPO_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
35
+ DEFAULT_MAX_NEW_TOKENS = 512
36
+ DEFAULT_TOP_K = 5
37
+
38
+ DEFAULT_SYSTEM_PROMPT = (
39
+ "You are a helpful Amazon grocery shopping assistant.\n\n"
40
+ "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
41
+ "Your response must follow this exact structure:\n\n"
42
+ "---\n\n"
43
+ "## 🛒 Recommended Products\n"
44
+ "For each product, write a numbered list entry, mentioning products by title "
45
+ "followed by 1–2 sentences describing the product and why it suits the query.\n\n"
46
+ "## 💡 Tips & Recipe Ideas\n"
47
+ "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above "
48
+ "(do NOT write out full recipes — keep each idea to 1–2 sentences)."
49
+ "Add food emojis if relevant.\n\n"
50
+ "---\n\n"
51
+ "Rules:\n"
52
+ "- Do not invent products. Only recommend products from the provided list.\n"
53
+ "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
54
+ "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
55
+ "- Format the entire response in Markdown.\n"
56
+ "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
57
+ )
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Helper functions
61
+ # ---------------------------------------------------------------------------
62
+
63
+ import logging
64
+ from langchain_core.runnables import RunnableLambda
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+ def _make_verbose_tap(label: str, verbose: bool):
69
+ """
70
+ Returns a passthrough RunnableLambda that logs *value* when verbose=True.
71
+ Works for any chain step — docs, prompt messages, or raw strings.
72
+ """
73
+ def _tap(value):
74
+ if verbose:
75
+ if hasattr(value, "messages"): # ChatPromptValue
76
+ rendered = "\n".join(
77
+ f"[{m.type.upper()}]: {m.content}"
78
+ for m in value.messages
79
+ )
80
+ elif isinstance(value, list): # list of Documents
81
+ rendered = "\n".join(str(d) for d in value)
82
+ else:
83
+ rendered = str(value)
84
+
85
+ print(f"\n{'='*60}\n{label}\n{'='*60}\n{rendered}\n")
86
+ logger.debug("%s\n%s", label, rendered)
87
+ return value
88
+ return RunnableLambda(_tap)
89
+
90
+ def build_context(docs: list[Document]) -> str:
91
+ """
92
+ Concatenate a list of retrieved LangChain Documents into a single
93
+ context string that the LLM can reason over.
94
+
95
+ Each entry includes the product's ``parent_asin`` (falling back to its
96
+ position index), its page content, and its full metadata dict.
97
+
98
+ Parameters
99
+ ----------
100
+ docs:
101
+ List of ``langchain_core.documents.Document`` objects returned by
102
+ the retriever.
103
+
104
+ Returns
105
+ -------
106
+ str
107
+ A newline-separated block of product descriptions ready for prompt
108
+ injection. Returns an empty string when *docs* is empty.
109
+
110
+ Raises
111
+ ------
112
+ TypeError
113
+ If *docs* is not a list, or any element is not a ``Document``.
114
+ """
115
+ if not isinstance(docs, list):
116
+ raise TypeError(
117
+ f"'docs' must be a list of Document objects, got {type(docs).__name__}."
118
+ )
119
+ for i, doc in enumerate(docs):
120
+ if not isinstance(doc, Document):
121
+ raise TypeError(
122
+ f"Element at index {i} is not a Document; got {type(doc).__name__}."
123
+ )
124
+
125
+ if not docs:
126
+ logger.warning("build_context received an empty document list.")
127
+ return ""
128
+
129
+ return "\n\n".join(
130
+ f"ASIN {doc.metadata.get('parent_asin', n)} Description: {doc.page_content}\n"
131
+ f"Metadata: {doc.metadata}"
132
+ for n, doc in enumerate(docs)
133
+ )
134
+
135
+
136
+ def _build_llm(
137
+ repo_id: str,
138
+ max_new_tokens: int,
139
+ provider: str,
140
+ ) -> ChatHuggingFace:
141
+ """
142
+ Instantiate and return a ``ChatHuggingFace`` model backed by a
143
+ HuggingFace Inference Endpoint.
144
+
145
+ Parameters
146
+ ----------
147
+ repo_id:
148
+ HuggingFace Hub model identifier (e.g.
149
+ ``"meta-llama/Meta-Llama-3-8B-Instruct"``).
150
+ max_new_tokens:
151
+ Maximum number of tokens the model may generate per call.
152
+ provider:
153
+ Inference provider passed to ``HuggingFaceEndpoint``
154
+ (``"auto"``, ``"novita"``, etc.).
155
+
156
+ Returns
157
+ -------
158
+ ChatHuggingFace
159
+ A chat-compatible wrapper around the endpoint.
160
+ """
161
+ endpoint = HuggingFaceEndpoint(
162
+ repo_id=repo_id,
163
+ task="text-generation",
164
+ max_new_tokens=max_new_tokens,
165
+ provider=provider,
166
+ )
167
+ return ChatHuggingFace(llm=endpoint)
168
+
169
+
170
+ def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
171
+ """
172
+ Create a ``ChatPromptTemplate`` with a system message and a human
173
+ turn that injects ``{context}`` and ``{question}`` placeholders.
174
+
175
+ Parameters
176
+ ----------
177
+ system_prompt:
178
+ The system-level instruction string.
179
+
180
+ Returns
181
+ -------
182
+ ChatPromptTemplate
183
+ """
184
+ return ChatPromptTemplate.from_messages([
185
+ ("system", system_prompt),
186
+ (
187
+ "human",
188
+ "context:\n{context}\n\nquestion:\n{question}\n\n"
189
+ "Answer based on the Amazon datasets:",
190
+ ),
191
+ ])
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Public API
196
+ # ---------------------------------------------------------------------------
197
+
198
+ def run_rag(
199
+ retriever: Any,
200
+ query: str,
201
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
202
+ repo_id: str = DEFAULT_REPO_ID,
203
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
204
+ provider: str = "auto",
205
+ verbose: bool = False,
206
+ hf_dataset = None
207
+ ) -> str:
208
+ """
209
+ Execute a full RAG pipeline and return the model's answer.
210
+
211
+ The pipeline follows the steps below:
212
+
213
+ 1. **Retrieve** - *retriever* fetches the *k* most relevant documents
214
+ for *query*.
215
+ 2. **Format context** - :func:`build_context` serialises the documents
216
+ into a single string.
217
+ 3. **Prompt** - the context and query are injected into the chat prompt
218
+ template.
219
+ 4. **Generate** - the LLM produces an answer grounded in the context.
220
+ 5. **Parse** - the raw chat message is unwrapped to a plain string.
221
+
222
+ Parameters
223
+ ----------
224
+ retriever:
225
+ A LangChain-compatible retriever (must expose ``.invoke()`` and be
226
+ pipeable with ``|``). Typically created via
227
+ ``vectorstore.as_retriever(...)``.
228
+ query:
229
+ Natural-language question to answer (non-empty string).
230
+ system_prompt:
231
+ System-level instruction for the assistant. Defaults to
232
+ :data:`DEFAULT_SYSTEM_PROMPT`.
233
+ repo_id:
234
+ HuggingFace Hub model identifier. Defaults to
235
+ ``"meta-llama/Meta-Llama-3-8B-Instruct"``.
236
+ max_new_tokens:
237
+ Upper bound on generated tokens. Must be a positive integer.
238
+ Defaults to ``100``.
239
+ provider:
240
+ HuggingFace inference provider (e.g. ``"auto"``, ``"novita"``).
241
+ Defaults to ``"auto"``.
242
+
243
+ Returns
244
+ -------
245
+ str
246
+ The model's answer as a plain string.
247
+
248
+ Raises
249
+ ------
250
+ TypeError
251
+ If *retriever* is ``None``, *query* is not a string, or
252
+ *system_prompt* is not a string.
253
+ ValueError
254
+ If *query* is blank, *max_new_tokens* is not a positive integer,
255
+ or *repo_id* / *provider* are blank strings.
256
+
257
+ Examples
258
+ --------
259
+ >>> answer = run_rag(retriever, "Best waterproof mascara under $20")
260
+ >>> print(answer)
261
+ """
262
+ # ------------------------------------------------------------------
263
+ # Build chain components
264
+ # ------------------------------------------------------------------
265
+
266
+ logger.info("Initialising LLM endpoint: %s", repo_id)
267
+ llm = _build_llm(repo_id, max_new_tokens, provider)
268
+ prompt_template = _build_prompt_template(system_prompt)
269
+
270
+ retrieved_docs: list[Document] = [] # ← capture target
271
+
272
+ def _retrieve_and_capture(query: str) -> list[Document]:
273
+ """Invoke the retriever and snapshot the results for the caller."""
274
+ docs = retriever.invoke(query)
275
+ retrieved_docs.extend(docs) # ← populate closure variable
276
+ return docs # ← pass through to build_context
277
+
278
+ rag_chain = (
279
+ {
280
+ "context": RunnableLambda(_retrieve_and_capture)
281
+ | RunnableLambda(build_context)
282
+ | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
283
+ "question": RunnablePassthrough(),
284
+ }
285
+ | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
286
+ | prompt_template
287
+ | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose) # ← shows exact prompt
288
+ | llm
289
+ | StrOutputParser()
290
+ )
291
+
292
+ # ------------------------------------------------------------------
293
+ # Run
294
+ # ------------------------------------------------------------------
295
+ logger.info("Invoking RAG chain for query: %r", query)
296
+ answer: str = rag_chain.invoke(query)
297
+ logger.debug("RAG answer: %s", answer)
298
+
299
+ if hf_dataset:
300
+ docs = _format_docs(retrieved_docs, hf_dataset)
301
+ else:
302
+ docs = retrieved_docs
303
+
304
+ return answer, docs
utils/retrieval_helpers.py CHANGED
@@ -82,6 +82,7 @@ def enrich_search_results(vector_store, query: str, k: int, hf_dataset):
82
  # 4. Return JSON metadata objects
83
  return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
84
 
 
85
  def enrich_bm25_search_results(retriever, query: str, k: int, hf_dataset):
86
  """
87
  Perform BM25 search and enrich results with HuggingFace dataset metadata.
@@ -137,5 +138,57 @@ def enrich_bm25_search_results(retriever, query: str, k: int, hf_dataset):
137
 
138
  con.close()
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  # 4. Return JSON metadata objects
141
  return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
 
82
  # 4. Return JSON metadata objects
83
  return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
84
 
85
+
86
  def enrich_bm25_search_results(retriever, query: str, k: int, hf_dataset):
87
  """
88
  Perform BM25 search and enrich results with HuggingFace dataset metadata.
 
138
 
139
  con.close()
140
 
141
+ # 4. Return JSON metadata objects
142
+ return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
143
+
144
+ def _format_docs(results, hf_dataset):
145
+ """
146
+ Perform similarity search and enrich results with HuggingFace dataset metadata.
147
+
148
+ Args:
149
+ vector_store: LangChain vector store instance
150
+ query: Search query string
151
+ k: Number of results to return
152
+ filter: Filter dict for similarity search
153
+ hf_dataset: HuggingFace Arrow dataset (datasets.Dataset)
154
+
155
+ Returns:
156
+ List of enriched metadata objects as dicts
157
+ """
158
+
159
+ # 1. Extract parent_asins from metadata
160
+ parent_asins = [doc.metadata.get("parent_asin") for doc in results]
161
+
162
+ # 2. Query HuggingFace dataset via DuckDB
163
+ con = duckdb.connect()
164
+ arrow_table = hf_dataset.data.table # Get underlying PyArrow table
165
+ con.register("hf_table", arrow_table)
166
+
167
+ asin_list = ", ".join(f"'{asin}'" for asin in parent_asins if asin)
168
+ query_sql = f"SELECT * FROM hf_table WHERE parent_asin IN ({asin_list})"
169
+ hf_rows = con.execute(query_sql).fetchdf()
170
+
171
+ # Build lookup: parent_asin -> metadata dict
172
+ asin_to_metadata = {
173
+ row["parent_asin"]: row.to_dict()
174
+ for _, row in hf_rows.iterrows()
175
+ }
176
+
177
+ enriched_results = []
178
+
179
+ for doc in results:
180
+ parent_asin = doc.metadata.get("parent_asin")
181
+ total_reviews = doc.metadata.get("total_reviews")
182
+ metadata_object = asin_to_metadata.get(parent_asin, {}).copy()
183
+ metadata_object['total_reviews'] = total_reviews
184
+
185
+ # 3. Extract 3 lines after "Top Reviews\n" from page_content
186
+ page_content = doc.page_content
187
+ metadata_object["reviews"] = decode_ratings(page_content)
188
+
189
+ enriched_results.append(metadata_object)
190
+
191
+ con.close()
192
+
193
  # 4. Return JSON metadata objects
194
  return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]