github-actions[bot] commited on
Commit
a8a94d1
·
1 Parent(s): e51a05a

chore: sync app/ and src/ from GitHub

Browse files
app/app/app.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv, sys
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import streamlit as st
6
+ import markdown
7
+
8
+ ROOT_FOLDER = Path(__file__).resolve().parent.parent
9
+
10
+ sys.path.append(str(ROOT_FOLDER))
11
+ import sys
12
+ import os
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
14
+ from src.retrieval_helpers import enrich_search_results,enrich_bm25_search_results
15
+ from src.semantic import load_vector_store
16
+ from src.rag_pipeline import run_rag
17
+ from src.bm25 import load
18
+ from src.hybrid import HybridRetriever
19
+
20
+ from dotenv import load_dotenv
21
+ load_dotenv()
22
+
23
+ import warnings
24
+ warnings.filterwarnings("ignore", category=UserWarning)
25
+
26
+ # ─── Page config (must be first Streamlit call) ───────────────────────────────
27
+ st.set_page_config(
28
+ page_title="Groceries & Gourmet Food Search",
29
+ page_icon="🥕",
30
+ layout="wide",
31
+ initial_sidebar_state="collapsed",
32
+ )
33
+
34
+ # ─── Paths ────────────────────────────────────────────────────────────────────
35
+ ROOT = Path(__file__).resolve().parent.parent
36
+ FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
37
+ FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
38
+
39
+ TOP_K = 5
40
+
41
+ HF_TOKEN = os.getenv('HF_TOKEN')
42
+
43
+ from datasets import load_dataset
44
+ from huggingface_hub import snapshot_download, login
45
+
46
+ # ─── Custom CSS ───────────────────────────────────────────────────────────────
47
+ with open('./app/styles.css', "r") as f:
48
+ css = f.read()
49
+
50
+ st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
51
+
52
+ @st.cache_resource
53
+ def load_hf_dataset():
54
+ return load_dataset(
55
+ "McAuley-Lab/Amazon-Reviews-2023",
56
+ "raw_meta_Grocery_and_Gourmet_Food",
57
+ trust_remote_code=True,
58
+ token=HF_TOKEN
59
+ )
60
+
61
+ VECTOR_STORE_DIR = ROOT / "data" / "processed"
62
+
63
+ @st.cache_resource
64
+ def load_vector_store_cached():
65
+ login(token=HF_TOKEN, add_to_git_credential=False)
66
+ VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
67
+
68
+ snapshot_path = snapshot_download(
69
+ repo_id="rishadaz/amazon_retriever-storage",
70
+ repo_type="dataset",
71
+ local_dir=str(VECTOR_STORE_DIR),
72
+ token=HF_TOKEN,
73
+ )
74
+
75
+ mini_index_path = Path(snapshot_path) / "tokenisation" / "bm25_index_mini.pkl"
76
+ embeddings_dir = Path(snapshot_path) / "embeddings"
77
+
78
+ vector_store = load_vector_store(embeddings_dir)
79
+ bm25_retriever = load(mini_index_path)
80
+
81
+ return vector_store, bm25_retriever
82
+
83
+ # ─── Get Data ──────────────────────────────────────────────────────────────
84
+ # local tag will read from your local directory as a default it will
85
+ # read the mini versions of the files we have provided in the repo
86
+
87
+ data_source = "remote" #"remote" or "local"
88
+
89
+ # note: remote has the full generated corpus and
90
+ # embeddings which can take a long time to download and
91
+ # the app might become heavy too and slow down
92
+ # processing. For development pls use the smaller "local" corpus
93
+
94
+ HF_DATASET = load_hf_dataset()
95
+
96
+ if data_source == 'local':
97
+ MINI_INDEX_PATH = ROOT / "data" / "processed" / "tokenisation" / "bm25_index_mini.pkl"
98
+
99
+ vector_store = load_vector_store(ROOT_FOLDER / 'data' / 'processed' / 'embeddings')
100
+ retriever = load(MINI_INDEX_PATH)
101
+ else:
102
+
103
+ vector_store, retriever = load_vector_store_cached()
104
+
105
+
106
+
107
+ def bm25_search(query: str, top_k: int = 3) -> list[dict]:
108
+ """
109
+ PLACEHOLDER — swap with real BM25Retriever call, e.g.:
110
+ retriever = BM25Retriever.load('data/processed/bm25_index.pkl')
111
+ return retriever.search(query, top_k=top_k)
112
+ Returns top_k review-level results (may include multiple reviews per ASIN).
113
+ """
114
+
115
+ results = enrich_bm25_search_results(retriever, query, top_k, HF_DATASET['full'])
116
+ return results
117
+
118
+
119
+ def semantic_search(query: str, top_k: int = 3) -> list[dict]:
120
+ """
121
+ PLACEHOLDER — swap with real SemanticRetriever call, e.g.:
122
+ retriever = SemanticRetriever.load('data/processed/faiss_index')
123
+ return retriever.search(query, top_k=top_k)
124
+ Returns top_k review-level results (scores are cosine similarities, 0–1).
125
+ """
126
+
127
+ results = enrich_search_results(vector_store, query, top_k, HF_DATASET['full'])
128
+ return results
129
+
130
+ hybrid_retriever = HybridRetriever(
131
+ bm25_retriever=retriever,
132
+ semantic_store=vector_store,
133
+ k=TOP_K,
134
+ bm25_weight=0.5,
135
+ semantic_weight=0.5,
136
+ )
137
+
138
+ def llm_retriever(query: str, top_k: int = 5):
139
+ retriever = hybrid_retriever
140
+ answer, docs = run_rag(retriever, query=query, hf_dataset=HF_DATASET['full'])
141
+ return answer, docs
142
+
143
+
144
+ # ─── Helpers ──��───────────────────────────────────────────────────────────────
145
+ def stars(rating: float) -> str:
146
+ full = int(rating)
147
+ half = 1 if (rating - full) >= 0.5 else 0
148
+ empty = 5 - full - half
149
+ return "★" * full + "½" * half + "☆" * empty
150
+
151
+
152
+ def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
153
+ file_exists = FEEDBACK_CSV.exists()
154
+ with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
155
+ writer = csv.DictWriter(
156
+ f, fieldnames=["timestamp", "query", "mode", "asin", "title", "vote"]
157
+ )
158
+ if not file_exists:
159
+ writer.writeheader()
160
+ writer.writerow({
161
+ "timestamp": datetime.now().isoformat(),
162
+ "query": query,
163
+ "mode": mode,
164
+ "asin": asin,
165
+ "title": title,
166
+ "vote": vote,
167
+ })
168
+
169
+ def render_product(ind, item):
170
+ reviews = item.get("reviews",{})
171
+ title = item["title"]
172
+ avg_rating = item["average_rating"]
173
+ n_reviews = len(reviews)
174
+ # total_reviews = item.get('total_reviews', n_reviews)
175
+ rating_number = item.get('rating_number', 0)
176
+ asin = item['parent_asin']
177
+ review_word = "review" if n_reviews == 1 else "reviews"
178
+ large_images = item.get('images', {}).get('large', [])
179
+ image_html = f'<img src="{large_images[0]}" style="width:100%;max-width:200px;border-radius:8px;margin-bottom:8px;" />' if large_images else ''
180
+ raw_price = item.get('price')
181
+ try:
182
+ price_val = float(str(raw_price).replace('$', '').replace(',', '').strip())
183
+ price_html = f'<span style="color:#2ecc71;font-weight:600">${price_val:.2f}</span>'
184
+ except (TypeError, ValueError):
185
+ price_html = ''
186
+
187
+
188
+ # ── Product card header ───────────────────────────────────────────
189
+ score_badge = f'<span class="score-badge">similarity score: {float(item["score"]):.2f}</span>' if 'score' in item else "<span/>"
190
+
191
+ st.markdown(
192
+ f"""
193
+ <div class="product-card" id="{asin}">
194
+ {image_html}
195
+ <h4>#{ind + 1} &nbsp; {title}</h4>
196
+ <span class="stars">{stars(avg_rating)}</span>
197
+ &nbsp;<small style="color:#888">{avg_rating:.1f}/5 avg ({rating_number:,} ratings)</small>
198
+ &nbsp;&nbsp;
199
+ {score_badge}
200
+ {"&nbsp;&nbsp;" + price_html if price_html else ""}
201
+ </div>
202
+ """,
203
+ unsafe_allow_html=True,
204
+ )
205
+
206
+ # ── Reviews in collapsible expander ───────────────────────────────
207
+ expander_label = f"📖 Viewing top {n_reviews} {review_word} "
208
+ with st.expander(expander_label, expanded=(n_reviews == 1)):
209
+ for j, rev in enumerate(reviews):
210
+ st.markdown(
211
+ f"""
212
+ <div class="review-snippet">
213
+ <strong>{rev['title']}</strong>
214
+ &nbsp;·&nbsp;
215
+ <span class="stars">{stars(rev['rating'])}</span>
216
+ <span style="color:#888; font-size:0.8rem"> {rev['rating']}/5</span>
217
+ &nbsp;·&nbsp;
218
+ <br><br>
219
+ {rev['text'][:300]}{'…' if len(rev['text']) > 300 else ''}
220
+ </div>
221
+ """,
222
+ unsafe_allow_html=True,
223
+ )
224
+
225
+ # ── Feedback buttons (per product) ────────────────────────────────
226
+ col_up, col_dn, _ = st.columns([1, 1, 10])
227
+ with col_up:
228
+ if st.button("👍", key=f"up_{mode}_{asin}_{ind}"):
229
+ log_feedback(query, mode, asin, title, "up")
230
+ st.toast("Thanks! 👍")
231
+ with col_dn:
232
+ if st.button("👎", key=f"dn_{mode}_{asin}_{ind}"):
233
+ log_feedback(query, mode, asin, title, "down")
234
+ st.toast("Noted! 👎")
235
+
236
+ st.markdown("<hr style='border:none;border-top:1px solid #e8e0d0;margin:0.5rem 0 1rem'>", unsafe_allow_html=True)
237
+
238
+
239
+
240
+ def render_results(results: list[dict], mode: str, query: str) -> None:
241
+ if not results:
242
+ st.info("No results returned.")
243
+ return
244
+
245
+ for ind, item in enumerate(results):
246
+ render_product(ind,item)
247
+
248
+ # ─── App layout ───────────────────────────────────────────────────────────────
249
+ st.markdown(
250
+ """
251
+ <div class="banner">
252
+ <h1>🥕🧀 Groceries & Gourmet Food Search</h1>
253
+ <p>Amazon Products & Reviews · Groceries & Gourmet Food </p>
254
+ </div>
255
+ """,
256
+ unsafe_allow_html=True,
257
+ )
258
+
259
+ # ─── Search bar ───────────────────────────────────────────────────────────────
260
+ query = st.text_input(
261
+ "Search for a product or describe what you're looking for",
262
+ placeholder="e.g. something sweet for a cheese board...",
263
+ )
264
+ # ─── Run searches only when query changes ─────────────────────────────────────
265
+ if query.strip() and query != st.session_state.get("last_query"):
266
+ st.session_state.last_query = query
267
+
268
+ with st.spinner("Searching..."):
269
+ st.session_state.bm25_results = bm25_search(query, top_k=TOP_K)
270
+ st.session_state.semantic_results = semantic_search(query, top_k=TOP_K)
271
+
272
+ with st.spinner("Asking AI..."):
273
+ try:
274
+ answer, docs = llm_retriever(query, top_k=TOP_K)
275
+ st.session_state.llm_result = answer
276
+ st.session_state.llm_docs = docs
277
+ except Exception as e:
278
+ st.session_state.llm_result = f"**Error:** {e}"
279
+ st.session_state.llm_docs = []
280
+
281
+ elif not query.strip():
282
+ # Clear results when input is emptied
283
+ for key in ("last_query", "bm25_results", "semantic_results", "llm_result"):
284
+ st.session_state.pop(key, None)
285
+
286
+ # ─── Tabs ─────────────────────────────────────────────────────────────────────
287
+ tab_search, tab_llm = st.tabs(["🔍 Search", "🤖 AI Assistant"])
288
+
289
+ # ─── Search Tab ───────────────────────────────────────────────────────────────
290
+ with tab_search:
291
+ mode = st.radio(
292
+ "Search mode",
293
+ options=["BM25", "Semantic"],
294
+ index=0,
295
+ horizontal=True,
296
+ help="BM25 = keyword matching · Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
297
+ )
298
+
299
+ if "last_query" not in st.session_state:
300
+ st.markdown(
301
+ "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to see results.</p>",
302
+ unsafe_allow_html=True,
303
+ )
304
+ else:
305
+ st.markdown(f"#### Top {TOP_K} results — {mode}")
306
+ results = (
307
+ st.session_state.bm25_results
308
+ if mode == "BM25"
309
+ else st.session_state.semantic_results
310
+ )
311
+ render_results(results, mode=mode.lower(), query=st.session_state.last_query)
312
+
313
+ # ─── LLM Tab ──────────────────────────────────────────────────────────────────
314
+ with tab_llm:
315
+ if "llm_result" not in st.session_state:
316
+ st.markdown(
317
+ "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to get AI-powered recommendations.</p>",
318
+ unsafe_allow_html=True,
319
+ )
320
+ else:
321
+ st.markdown(f"#### 🤖 AI Answer — *\"{st.session_state.last_query}\"*")
322
+ st.caption("⚠️ AI responses may contain errors - please verify before relying on them.")
323
+ html_response = markdown.markdown(
324
+ st.session_state.llm_result,
325
+ extensions=["tables", "fenced_code", "nl2br"],
326
+ )
327
+ st.markdown(
328
+ f"<div class='llm-response'>{html_response}</div>",
329
+ unsafe_allow_html=True,
330
+ )
331
+
332
+ st.markdown("#### 📦 Retrieved Products")
333
+ docs = st.session_state.get("llm_docs", [])
334
+ if docs:
335
+ # Build scrollable card list in one HTML block
336
+ cards_html = "<div class='doc-sidebar'>"
337
+ for i, doc in enumerate(docs, 1):
338
+ render_product(i,doc)
339
+ cards_html += "</div>"
340
+ st.markdown(cards_html, unsafe_allow_html=True)
341
+ else:
342
+ st.markdown("<p style='color:#aaa;'>No documents retrieved.</p>", unsafe_allow_html=True)
343
+
344
+ # ─── Sidebar: feedback log ────────────────────────────────────────────────────
345
+ with st.sidebar:
346
+ st.header("📋 Feedback Log")
347
+ if FEEDBACK_CSV.exists():
348
+ import pandas as pd
349
+ df = pd.read_csv(FEEDBACK_CSV)
350
+ st.dataframe(df.tail(20), use_container_width=True)
351
+ st.download_button(
352
+ "⬇️ Download feedback.csv",
353
+ data=df.to_csv(index=False),
354
+ file_name="feedback.csv",
355
+ mime="text/csv",
356
+ )
357
+ else:
358
+ st.info("No feedback yet — use 👍/👎 on results.")
app/app/styles.css ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@600&family=Source+Sans+3:wght@400;600&display=swap');
2
+
3
+ html, body, [class*="css"] {
4
+ font-family: 'Source Sans 3', sans-serif;
5
+ }
6
+ h1, h2, h3 { font-family: 'Playfair Display', serif; }
7
+
8
+ .banner {
9
+ background: linear-gradient(135deg, #2d4a22 0%, #4a7c3f 60%, #7aab5c 100%);
10
+ border-radius: 12px;
11
+ padding: 2rem 2.5rem;
12
+ margin-bottom: 1.5rem;
13
+ color: #f5f0e8;
14
+ }
15
+ .banner h1 { margin: 0; font-size: 2.4rem; color: #f5f0e8; }
16
+ .banner p { margin: 0.3rem 0 0; font-size: 1.05rem; opacity: 0.85; }
17
+
18
+ /* Product card (outer) */
19
+ .product-card {
20
+ background: #fffdf7;
21
+ border: 1px solid #e2d9c8;
22
+ border-left: 4px solid #4a7c3f;
23
+ border-radius: 8px;
24
+ padding: 1rem 1.2rem 0.6rem;
25
+ margin-bottom: 0.4rem;
26
+ box-shadow: 0 1px 4px rgba(0,0,0,0.06);
27
+ }
28
+ .product-card h4 { margin: 0 0 0.2rem; color: #1e3318; font-size: 1.05rem; }
29
+
30
+ /* Review snippet inside expander */
31
+ .review-snippet {
32
+ background: #f7f4ee;
33
+ border-radius: 6px;
34
+ padding: 0.6rem 0.9rem;
35
+ margin-bottom: 0.5rem;
36
+ font-size: 0.87rem;
37
+ color: #444;
38
+ line-height: 1.55;
39
+ }
40
+ .score-badge {
41
+ display: inline-block;
42
+ background: #eaf3e6;
43
+ color: #2d5a20;
44
+ border-radius: 20px;
45
+ padding: 2px 10px;
46
+ font-size: 0.78rem;
47
+ font-weight: 600;
48
+ margin-right: 6px;
49
+ }
50
+ .stars { color: #e6a817; }
51
+
52
+ .placeholder-badge {
53
+ background: #fff3cd;
54
+ border: 1px solid #ffc107;
55
+ border-radius: 6px;
56
+ padding: 0.4rem 0.8rem;
57
+ font-size: 0.82rem;
58
+ color: #7a5800;
59
+ display: inline-block;
60
+ margin-bottom: 1rem;
61
+ }
62
+
63
+ .doc-sidebar {
64
+ max-height: 600px;
65
+ overflow-y: auto;
66
+ padding-right: 4px;
67
+ }
68
+ .doc-card {
69
+ background: #1e1e2e;
70
+ border: 1px solid #333;
71
+ border-radius: 8px;
72
+ padding: 0.75rem;
73
+ margin-bottom: 0.65rem;
74
+ }
75
+ .doc-title {
76
+ font-weight: 600;
77
+ font-size: 0.85rem;
78
+ margin-bottom: 0.3rem;
79
+ color: #f0f0f0;
80
+ line-height: 1.3;
81
+ }
82
+ .doc-meta {
83
+ font-size: 0.78rem;
84
+ margin-bottom: 0.3rem;
85
+ display: flex;
86
+ gap: 0.5rem;
87
+ }
88
+ .doc-rating { color: #f5c518; }
89
+ .doc-price { color: #5cb85c; }
90
+ .doc-snippet {
91
+ font-size: 0.75rem;
92
+ color: #999;
93
+ line-height: 1.4;
94
+ }
src/src/__init__.py ADDED
File without changes
src/src/bm25.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/bm25.py — BM25 keyword retrieval
3
+ Uses LangChain's BM25Retriever with the custom tokenizer from utils.py.
4
+
5
+ Document schema (one LangChain Document per product):
6
+ page_content : text BM25 scores against =
7
+ title + features + description + categories +
8
+ details (flattened) + store + top-k review titles & texts
9
+ metadata : structured fields for display in app.py
10
+ (parent_asin, title, main_category, price, store,
11
+ categories, features, description, details, top_reviews)
12
+
13
+ Data source expected: HuggingFace Dataset objects as loaded in
14
+ milestone1_exploration.ipynb via load_dataset("McAuley-Lab/Amazon-Reviews-2023", ...)
15
+ OR the saved .jsonl subsets in data/raw/.
16
+ """
17
+
18
+ import json
19
+ import pickle
20
+ from pathlib import Path
21
+ from typing import Any
22
+ import sys
23
+ from datasets import Dataset
24
+ from langchain_community.retrievers import BM25Retriever
25
+ from langchain_core.documents import Document
26
+ ROOT_FOLDER = Path(__file__).resolve().parent.parent
27
+
28
+ sys.path.append(str(ROOT_FOLDER))
29
+ from src.utils import simple_tokenize
30
+ from src.eda_helpers import get_best_reviews
31
+
32
+
33
+ # ── field helpers ─────────────────────────────────────────────────────────────
34
+
35
+ def _coerce_str(value: Any) -> str:
36
+ """Safely flatten any metadata field to a plain string."""
37
+ if value is None:
38
+ return ""
39
+ if isinstance(value, list):
40
+ return " ".join(_coerce_str(v) for v in value)
41
+ if isinstance(value, dict):
42
+ return " ".join(f"{k} {_coerce_str(v)}" for k, v in value.items())
43
+ s = str(value)
44
+ # treat the literal string "None" as empty
45
+ return "" if s.strip().lower() == "none" else s
46
+
47
+
48
+ def _parse_details(details: Any) -> dict:
49
+ """
50
+ 'details' in this dataset is stored as a JSON string, e.g.:
51
+ '{"Brand": "Luzianne", "Item Form": "Ground", ...}'
52
+ Parse it safely; return an empty dict on failure.
53
+ """
54
+ if not details:
55
+ return {}
56
+ if isinstance(details, dict):
57
+ return details
58
+ try:
59
+ return json.loads(str(details))
60
+ except (json.JSONDecodeError, TypeError):
61
+ return {}
62
+
63
+
64
+ def _parse_price(price: Any) -> float | None:
65
+ """price can be a float, an int, or the string 'None'."""
66
+ if price is None:
67
+ return None
68
+ try:
69
+ v = float(price)
70
+ return None if v != v else v # NaN guard
71
+ except (ValueError, TypeError):
72
+ return None
73
+
74
+
75
+ # ── review selection ──────────────────────────────────────────────────────────
76
+
77
+ def get_top_reviews(
78
+ reviews_dataset_dict,
79
+ parent_asin: str,
80
+ k: int = 5,
81
+ ) -> list[dict]:
82
+ """
83
+ Select the top-k reviews for a product using get_best_reviews() from
84
+ eda_helpers.py (weighted score: helpful_vote 50%, verified_purchase 30%,
85
+ rating extremity 20%).
86
+
87
+ Parameters
88
+ ----------
89
+ reviews_dataset_dict : the full reviews DatasetDict (raw_reviews) —
90
+ NOT the pre-selected 'full' split, because
91
+ get_best_reviews() selects 'full' internally.
92
+ parent_asin : product identifier
93
+ k : number of reviews to return
94
+
95
+ Returns
96
+ -------
97
+ List of dicts with keys: title, text, rating, helpful_vote
98
+ """
99
+ result = get_best_reviews(reviews_dataset_dict, parent_asin, top_k=k)
100
+
101
+ # get_best_reviews returns (total_count, Dataset) when top_k is set,
102
+ # or a bare Dataset with 0 rows when no reviews are found.
103
+ if isinstance(result, tuple):
104
+ _, matched = result
105
+ else:
106
+ matched = result
107
+
108
+ if len(matched) == 0:
109
+ return []
110
+
111
+ return [
112
+ {
113
+ "title": row.get("title", "") or "",
114
+ "text": row.get("text", "") or "",
115
+ "rating": row.get("rating"),
116
+ "helpful_vote": row.get("helpful_vote", 0),
117
+ }
118
+ for row in matched
119
+ ]
120
+
121
+
122
+ # ── document construction ─────────────────────────────────────────────────────
123
+
124
+ def format_review(review: dict) -> str:
125
+ """Format a single review the same way as in the notebook."""
126
+ return (
127
+ f"Review (Rating: {review['rating']}): "
128
+ f"{review['title']}. "
129
+ f"{review['text']}\n "
130
+ )
131
+
132
+
133
+ def build_page_content(product: dict, top_reviews: list[dict]) -> str:
134
+ """
135
+ Build the page_content string that BM25 will index.
136
+ Mirrors the create_document() structure in milestone1_exploration.ipynb.
137
+ """
138
+ title = _coerce_str(product.get("title"))
139
+ description = " ".join(product.get("description") or [])
140
+ features = "\n".join(product.get("features") or [])
141
+ categories = " > ".join(product.get("categories") or [])
142
+ store = _coerce_str(product.get("store"))
143
+ details = _parse_details(product.get("details"))
144
+ details_str = " ".join(f"{k}: {v}" for k, v in details.items())
145
+
146
+ review_lines = "".join(format_review(r) for r in top_reviews)
147
+ n_reviews = len(top_reviews)
148
+
149
+ return f"""Product: {title}
150
+ Category: {categories}
151
+ Store: {store}
152
+
153
+ Features:
154
+ {features}
155
+
156
+ Description:
157
+ {description}
158
+
159
+ Details:
160
+ {details_str}
161
+
162
+ Top Reviews (showing {n_reviews}):
163
+ {review_lines}"""
164
+
165
+
166
+ def _extract_image_url(images: Any) -> str:
167
+ """
168
+ Extract the best available image URL from the images field.
169
+ The field is a dict with keys: thumb, large, hi_res, variant — each a list.
170
+ Prefers 'large', falls back to 'thumb', then 'hi_res'. Returns "" if none found.
171
+ """
172
+ if not images or not isinstance(images, dict):
173
+ return ""
174
+ for key in ("large", "thumb", "hi_res"):
175
+ urls = images.get(key)
176
+ if isinstance(urls, list) and urls and urls[0]:
177
+ return urls[0]
178
+ return ""
179
+
180
+
181
+ def build_document(product: dict, top_reviews: list[dict]) -> Document | None:
182
+ """
183
+ Build one LangChain Document for a single product row from the metadata Dataset.
184
+ Returns None if there is no indexable text.
185
+ """
186
+ page_content = build_page_content(product, top_reviews)
187
+ if not page_content.strip():
188
+ return None
189
+
190
+ details_dict = _parse_details(product.get("details"))
191
+
192
+ metadata = {
193
+ "parent_asin": product.get("parent_asin", ""),
194
+ "title": _coerce_str(product.get("title")),
195
+ "main_category": _coerce_str(product.get("main_category")),
196
+ "price": _parse_price(product.get("price")),
197
+ "store": _coerce_str(product.get("store")),
198
+ "categories": _coerce_str(product.get("categories")),
199
+ "features": _coerce_str(product.get("features")),
200
+ "description": _coerce_str(product.get("description")),
201
+ "details": details_dict,
202
+ "average_rating": product.get("average_rating"),
203
+ "rating_number": product.get("rating_number"),
204
+ "image_url": _extract_image_url(product.get("images")),
205
+ "top_reviews": top_reviews,
206
+ }
207
+
208
+ return Document(page_content=page_content, metadata=metadata)
209
+
210
+
211
+ def pregroup_reviews(
212
+ reviews_dataset_dict,
213
+ max_reviews_per_product: int = 5,
214
+ ) -> dict:
215
+ """
216
+ Pre-group top-k reviews per product using DuckDB for efficient scoring
217
+ and ranking — never loads all 14M reviews into Python memory at once.
218
+
219
+ Uses a single SQL query with ROW_NUMBER() to rank reviews per product
220
+ by the same weighted score as eda_helpers.get_best_reviews():
221
+ helpful_vote 50% (log-scaled) + verified_purchase 30% + rating extremity 20%
222
+ """
223
+ import duckdb
224
+
225
+ print("Pre-grouping reviews via DuckDB (memory-efficient) ...")
226
+ arrow_table = reviews_dataset_dict["full"].data.table
227
+
228
+ k = max_reviews_per_product
229
+ query = f"""
230
+ WITH scored AS (
231
+ SELECT
232
+ parent_asin,
233
+ title,
234
+ text,
235
+ rating,
236
+ helpful_vote,
237
+ verified_purchase,
238
+ (
239
+ 0.5 * (LN(1 + GREATEST(COALESCE(helpful_vote, 0), 0)))
240
+ + 0.3 * (CASE WHEN verified_purchase THEN 1.0 ELSE 0.0 END)
241
+ + 0.2 * (ABS(COALESCE(rating, 3.0) - 3.0) / 2.0)
242
+ ) AS score
243
+ FROM arrow_table
244
+ WHERE parent_asin IS NOT NULL AND parent_asin != ''
245
+ ),
246
+ ranked AS (
247
+ SELECT *,
248
+ ROW_NUMBER() OVER (
249
+ PARTITION BY parent_asin
250
+ ORDER BY score DESC
251
+ ) AS rn
252
+ FROM scored
253
+ )
254
+ SELECT parent_asin, title, text, rating, helpful_vote
255
+ FROM ranked
256
+ WHERE rn <= {k}
257
+ ORDER BY parent_asin, rn
258
+ """
259
+
260
+ rows = duckdb.query(query).fetchall()
261
+ cols = ["parent_asin", "title", "text", "rating", "helpful_vote"]
262
+
263
+ result = {}
264
+ for row in rows:
265
+ r = dict(zip(cols, row))
266
+ asin = r.pop("parent_asin")
267
+ result.setdefault(asin, []).append(r)
268
+
269
+ print(f" {len(result):,} unique parent_asins grouped")
270
+ print(" pre-grouping done")
271
+ return result
272
+
273
+
274
+ def build_documents(
275
+ metadata_dataset: Dataset,
276
+ reviews_dataset_dict,
277
+ max_products: int | None = None,
278
+ max_reviews_per_product: int = 5,
279
+ reviews_lookup: dict | None = None,
280
+ ) -> list[Document]:
281
+ """
282
+ Build one LangChain Document per product.
283
+
284
+ Pass reviews_lookup (from pregroup_reviews) to skip per-product DuckDB
285
+ queries entirely — much faster for large datasets.
286
+ """
287
+ total = len(metadata_dataset)
288
+ n = min(total, max_products) if max_products else total
289
+ print(f"Building documents for {n} products ...")
290
+
291
+ docs = []
292
+ for i in range(n):
293
+ product = metadata_dataset[i]
294
+ parent_asin = product.get("parent_asin", "")
295
+
296
+ if reviews_lookup is not None:
297
+ top_reviews = reviews_lookup.get(parent_asin, [])[:max_reviews_per_product]
298
+ else:
299
+ top_reviews = get_top_reviews(
300
+ reviews_dataset_dict, parent_asin, k=max_reviews_per_product
301
+ )
302
+
303
+ doc = build_document(product, top_reviews)
304
+ if doc:
305
+ docs.append(doc)
306
+
307
+ if (i + 1) % 500 == 0:
308
+ print(f" ... {i + 1}/{n} products processed")
309
+
310
+ print(f" -> {len(docs)} documents built (skipped {n - len(docs)} empty)")
311
+ return docs
312
+
313
+
314
+ # ── index build & persist ─────────────────────────────────────────────────────
315
+
316
+ def build_and_save(
317
+ documents: list[Document],
318
+ index_path: str | Path = "data/processed/bm25_index.pkl",
319
+ corpus_path: str | Path = "data/processed/bm25_corpus.pkl",
320
+ ) -> BM25Retriever:
321
+ """
322
+ Build a BM25Retriever from documents, then pickle both the
323
+ tokenized corpus and the retriever to disk.
324
+
325
+ Parameters
326
+ ----------
327
+ documents : output of build_documents()
328
+ index_path : e.g. 'data/processed/bm25_index.pkl'
329
+ corpus_path : e.g. 'data/processed/bm25_corpus.pkl'
330
+
331
+ Returns
332
+ -------
333
+ The fitted BM25Retriever instance.
334
+ """
335
+ index_path = Path(index_path)
336
+ corpus_path = Path(corpus_path)
337
+ index_path.parent.mkdir(parents=True, exist_ok=True)
338
+
339
+ print(f"Fitting BM25 index over {len(documents)} documents …")
340
+ retriever = BM25Retriever.from_documents(
341
+ documents,
342
+ preprocess_func=simple_tokenize,
343
+ )
344
+
345
+ # Save tokenized corpus separately — useful for inspection in the notebook
346
+ tokenized_corpus = [simple_tokenize(doc.page_content) for doc in documents]
347
+ with open(corpus_path, "wb") as f:
348
+ pickle.dump(tokenized_corpus, f)
349
+ print(f"Tokenized corpus saved → {corpus_path}")
350
+
351
+ with open(index_path, "wb") as f:
352
+ pickle.dump(retriever, f)
353
+ print(f"BM25 index saved → {index_path}")
354
+
355
+ return retriever
356
+
357
+
358
+ # ── load ──────────────────────────────────────────────────────────────────────
359
+
360
+ def load(index_path: str | Path = "data/processed/bm25_index.pkl") -> BM25Retriever:
361
+ """
362
+ Load a previously saved BM25Retriever from disk.
363
+ Call this in app.py instead of rebuilding every time.
364
+ """
365
+ index_path = Path(index_path)
366
+ if not index_path.exists():
367
+ raise FileNotFoundError(
368
+ f"BM25 index not found at '{index_path}'.\n"
369
+ "Run build_and_save() from your notebook first."
370
+ )
371
+ with open(index_path, "rb") as f:
372
+ retriever = pickle.load(f)
373
+ print(f"BM25 index loaded ← {index_path}")
374
+ return retriever
375
+
376
+
377
+ # ── search ────────────────────────────────────────────────────────────────────
378
+
379
+ def search(
380
+ retriever: BM25Retriever,
381
+ query: str,
382
+ top_k: int = 3,
383
+ ) -> list[dict]:
384
+ retriever.k = top_k
385
+
386
+ # Tokenize query the same way the index was built
387
+ tokenized_query = simple_tokenize(query)
388
+
389
+ # Get raw BM25 scores for ALL documents
390
+ scores = retriever.vectorizer.get_scores(tokenized_query) # np.ndarray, len = n_docs
391
+
392
+ # Get top-k doc indices by score
393
+ top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
394
+
395
+ results = []
396
+ for idx in top_indices:
397
+ doc = retriever.docs[idx] # retriever.docs holds the original Document list
398
+ m = doc.metadata
399
+ top_reviews = m.get("top_reviews", [])
400
+
401
+ rated = [r["rating"] for r in top_reviews if r.get("rating") is not None]
402
+ avg_rating = round(sum(rated) / len(rated), 1) if rated else 0.0
403
+
404
+ if top_reviews and top_reviews[0].get("text"):
405
+ snippet = top_reviews[0]["text"][:300]
406
+ else:
407
+ snippet = m.get("description", "")[:300]
408
+
409
+ results.append({
410
+ "asin": m.get("parent_asin", ""),
411
+ "title": m.get("title", ""),
412
+ "text": snippet,
413
+ "rating": avg_rating,
414
+ "score": float(scores[idx]),
415
+ "top_reviews": top_reviews,
416
+ })
417
+
418
+ return results
419
+
420
+
421
+ # ── notebook entry point ──────────────────────────────────────────────────────
422
+
423
+ def build_from_hf_datasets(
424
+ metadata_dataset: Dataset,
425
+ reviews_dataset_dict,
426
+ index_path: str | Path = "data/processed/tokenisation/bm25_index.pkl",
427
+ corpus_path: str | Path = "data/processed/tokenisation/bm25_corpus.pkl",
428
+ max_products: int | None = None,
429
+ max_reviews_per_product: int = 5,
430
+ ) -> BM25Retriever:
431
+ """
432
+ End-to-end helper to call from milestone1_exploration.ipynb.
433
+
434
+ Example usage in the notebook:
435
+ --------------------------------
436
+ from src.bm25 import build_from_hf_datasets, load, search
437
+
438
+ retriever = build_from_hf_datasets(
439
+ metadata_dataset=raw_metadata['full'],
440
+ reviews_dataset_dict=raw_reviews,
441
+ max_products=500,
442
+ )
443
+
444
+ # Later in app.py — just load the saved index:
445
+ # retriever = load("data/processed/bm25_index.pkl")
446
+ # results = search(retriever, "something sweet for a cheese board")
447
+ """
448
+ reviews_lookup = pregroup_reviews(reviews_dataset_dict, max_reviews_per_product)
449
+ docs = build_documents(
450
+ metadata_dataset,
451
+ reviews_dataset_dict,
452
+ max_products=max_products,
453
+ max_reviews_per_product=max_reviews_per_product,
454
+ reviews_lookup=reviews_lookup,
455
+ )
456
+ return build_and_save(docs, index_path=index_path, corpus_path=corpus_path)
457
+
458
+
459
+ def build_from_hf_datasets_batched(
460
+ metadata_dataset: Dataset,
461
+ reviews_dataset_dict,
462
+ index_path: str | Path = "data/processed/tokenisation/bm25_index.pkl",
463
+ corpus_path: str | Path = "data/processed/tokenisation/bm25_corpus.pkl",
464
+ batch_size: int = 2000,
465
+ max_reviews_per_product: int = 5,
466
+ max_products: int | None = None,
467
+ ) -> BM25Retriever:
468
+ """
469
+ Memory-safe version of build_from_hf_datasets — builds documents in
470
+ batches to avoid OOM kernel crashes on large datasets.
471
+
472
+ Checkpoints completed batches to data/processed/checkpoints/ after each
473
+ batch, so if the kernel dies mid-run you can resume from the last
474
+ completed batch instead of starting over.
475
+
476
+ Example usage in the notebook:
477
+ --------------------------------
478
+ retriever = build_from_hf_datasets_batched(
479
+ metadata_dataset=raw_metadata['full'],
480
+ reviews_dataset_dict=raw_reviews,
481
+ batch_size=5000,
482
+ max_reviews_per_product=3,
483
+ max_products=60000, # None = use all
484
+ )
485
+ """
486
+ index_path = Path(index_path)
487
+ corpus_path = Path(corpus_path)
488
+
489
+ # checkpoint folder lives next to the index
490
+ checkpoint_dir = index_path.parent / "checkpoints"
491
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
492
+
493
+ total = min(len(metadata_dataset), max_products) if max_products else len(metadata_dataset)
494
+
495
+ # find resume point — checkpoints named docs_0.pkl, docs_2000.pkl, ...
496
+ existing = sorted(checkpoint_dir.glob("docs_*.pkl"))
497
+ if existing:
498
+ last_ckpt = existing[-1]
499
+ resume_start = int(last_ckpt.stem.split("_")[1]) + batch_size
500
+ print(f"Resuming from product {resume_start} "
501
+ f"({len(existing)} checkpoint(s) found)")
502
+ all_docs = []
503
+ for ckpt in existing:
504
+ with open(ckpt, "rb") as f:
505
+ all_docs.extend(pickle.load(f))
506
+ print(f" loaded {len(all_docs)} docs from checkpoints")
507
+ else:
508
+ resume_start = 0
509
+ all_docs = []
510
+ print(f"Starting fresh — {total} products to process")
511
+
512
+ # pre-group all reviews once
513
+ reviews_lookup = pregroup_reviews(reviews_dataset_dict, max_reviews_per_product)
514
+
515
+ # batch loop
516
+ for start in range(resume_start, total, batch_size):
517
+ end = min(start + batch_size, total)
518
+ print(f"\nBatch {start}-{end} of {total} ...")
519
+
520
+ batch = metadata_dataset.select(range(start, end))
521
+ batch_docs = build_documents(
522
+ batch,
523
+ reviews_dataset_dict,
524
+ max_products=None,
525
+ max_reviews_per_product=max_reviews_per_product,
526
+ reviews_lookup=reviews_lookup,
527
+ )
528
+ all_docs.extend(batch_docs)
529
+
530
+ # save checkpoint for this batch
531
+ ckpt_path = checkpoint_dir / f"docs_{start}.pkl"
532
+ with open(ckpt_path, "wb") as f:
533
+ pickle.dump(batch_docs, f)
534
+ print(f" checkpoint saved -> {ckpt_path.name}")
535
+ print(f" cumulative docs : {len(all_docs)}")
536
+
537
+ # build final index
538
+ print(f"\nAll batches done - {len(all_docs)} total documents.")
539
+ retriever = build_and_save(all_docs, index_path=index_path, corpus_path=corpus_path)
540
+
541
+ # clean up checkpoints now that final index is safely written
542
+ for ckpt in checkpoint_dir.glob("docs_*.pkl"):
543
+ ckpt.unlink()
544
+ print("Checkpoints cleaned up.")
545
+
546
+ return retriever
src/src/eda_helpers.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset
2
+ import duckdb
3
+
4
+ def dataset_overview(dataset_dict) -> None:
5
+ """Print a concise overview of a DatasetDict: splits, features, row counts."""
6
+ print(f"\n{'='*60}")
7
+ print(f" Overview")
8
+ print(f"{'='*60}")
9
+ for split, ds in dataset_dict.items():
10
+ print(f"\n Split : {split!r} ({ds.num_rows:,} rows)")
11
+ print(f" {'Field':<30} {'dtype'}")
12
+ print(f" {'-'*45}")
13
+ for feat, ftype in ds.features.items():
14
+ print(f" {feat:<30} {ftype}")
15
+ print()
16
+
17
+ def get_reviews_by_asin(
18
+ reviews_dataset,
19
+ parent_asin: str,
20
+ ):
21
+ """
22
+ Retrieve all reviews matching a given parent_asin.
23
+
24
+ Parameters
25
+ ----------
26
+ reviews_dataset : DatasetDict (the full reviews DatasetDict)
27
+ parent_asin : the product ASIN to filter by
28
+ split : which split to search in (default: "full")
29
+
30
+ Returns
31
+ -------
32
+ HuggingFace Dataset containing only rows matching the given parent_asin
33
+ """
34
+ if not parent_asin or not isinstance(parent_asin,str):
35
+ raise TypeError("Invalid parent_asin passed")
36
+
37
+ ds = reviews_dataset["full"]
38
+
39
+ arrow_table = ds.data.table
40
+
41
+ matched_arrow = duckdb.query(
42
+ f"SELECT * FROM arrow_table WHERE parent_asin = '{parent_asin}'"
43
+ ).fetch_arrow_table()
44
+
45
+ return Dataset(matched_arrow)
46
+
47
+ def get_best_reviews(
48
+ reviews_dataset,
49
+ parent_asin: str,
50
+ top_k: int = None,
51
+ ):
52
+ """
53
+ Retrieve reviews matching a given parent_asin, optionally returning
54
+ only the top-k highest quality reviews.
55
+
56
+ Ranking score (all components normalized to [0, 1]):
57
+ - helpful_vote : 50% weight (log-scaled to reduce outlier dominance)
58
+ - verified_purchase : 30% weight (bool → 1.0 or 0.0)
59
+ - rating : 20% weight (how extreme the rating is — 1 or 5
60
+ are more informative than a neutral 3)
61
+
62
+ Parameters
63
+ ----------
64
+ reviews_dataset : DatasetDict
65
+ parent_asin : product ASIN to filter by
66
+ top_k : number of top reviews to return (None = return all, sorted)
67
+ split : which split to use
68
+
69
+ Returns
70
+ -------
71
+ HuggingFace Dataset
72
+ """
73
+ import math
74
+
75
+ matched = get_reviews_by_asin(reviews_dataset,parent_asin)
76
+ tot=matched.num_rows
77
+
78
+ if tot == 0:
79
+ return 0, matched
80
+
81
+ if top_k is None:
82
+ return 0, matched
83
+
84
+ # Step 2: compute scores
85
+ helpful_votes = matched["helpful_vote"]
86
+ verified = matched["verified_purchase"]
87
+ ratings = matched["rating"]
88
+
89
+ # Log-scale helpful votes: log(1 + x), then normalize to [0, 1]
90
+ log_votes = [math.log1p(v if v is not None else 0) for v in helpful_votes]
91
+ max_log = max(log_votes) if max(log_votes) > 0 else 1.0
92
+ norm_votes = [v / max_log for v in log_votes]
93
+
94
+ # Verified purchase: 1.0 if True, 0.0 otherwise
95
+ norm_verified = [1.0 if v else 0.0 for v in verified]
96
+
97
+ # Rating extremity: reviews at 1 or 5 are more informative than 3
98
+ # score = 1 - |rating - 3| / 2 → inverted so extreme ratings score higher
99
+ norm_rating = [abs((r if r is not None else 3.0) - 3.0) / 2.0 for r in ratings]
100
+
101
+ # Weighted sum
102
+ scores = [
103
+ 0.50 * nv + 0.30 * ver + 0.20 * nr
104
+ for nv, ver, nr in zip(norm_votes, norm_verified, norm_rating)
105
+ ]
106
+
107
+ # Step 3: select top-k indices by score
108
+ k = min(top_k, matched.num_rows)
109
+ top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
110
+ top_indices_sorted = sorted(top_indices) # preserve original row order
111
+
112
+ return tot, matched.select(top_indices_sorted)
src/src/hybrid.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/hybrid.py
3
+ -------------
4
+ Hybrid retriever combining BM25 keyword search and FAISS semantic search,
5
+ fused with Reciprocal Rank Fusion (RRF).
6
+
7
+ Designed to plug into the existing run_rag() pipeline in rag_pipeline.py
8
+ as a drop-in replacement for the semantic retriever:
9
+
10
+ hybrid_retriever = load_hybrid_retriever(
11
+ bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl",
12
+ faiss_store_path="data/processed/embeddings",
13
+ k=5,
14
+ )
15
+ answer = run_rag(hybrid_retriever, "Best coffee beans for espresso")
16
+
17
+ The HybridRetriever class extends LangChain's BaseRetriever so it is fully
18
+ compatible with the | (pipe) operator used in rag_pipeline.py:
19
+
20
+ rag_chain = (
21
+ {
22
+ "context": hybrid_retriever | RunnableLambda(build_context),
23
+ "question": RunnablePassthrough(),
24
+ }
25
+ | prompt_template
26
+ | llm
27
+ | StrOutputParser()
28
+ )
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import logging
34
+ from typing import Any
35
+
36
+ from langchain_community.retrievers import BM25Retriever
37
+ from langchain_community.vectorstores import FAISS
38
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
39
+ from langchain_core.documents import Document
40
+ from langchain_core.retrievers import BaseRetriever
41
+ from pydantic import Field
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # HybridRetriever
48
+ # ---------------------------------------------------------------------------
49
+
50
+ class HybridRetriever(BaseRetriever):
51
+ """
52
+ Combines BM25 keyword retrieval and FAISS semantic retrieval using
53
+ Reciprocal Rank Fusion (RRF) to produce a unified ranked document list.
54
+
55
+ RRF score for document d across retriever r:
56
+ score(d) = weight_r * (1 / (rrf_c + rank(d, r)))
57
+
58
+ Documents appearing in both retrievers accumulate scores from both,
59
+ naturally promoting results that are relevant by both keyword and meaning.
60
+
61
+ Parameters
62
+ ----------
63
+ bm25_retriever : Fitted LangChain BM25Retriever (from bm25.load())
64
+ semantic_store : Loaded FAISS vectorstore (from semantic.load_vector_store())
65
+ k : Number of final documents to return
66
+ rrf_c : RRF constant — dampens the impact of rank differences.
67
+ Standard value is 60; lower = top ranks matter more.
68
+ bm25_weight : RRF weight for BM25 results (keyword signal)
69
+ semantic_weight : RRF weight for semantic results (meaning signal)
70
+ fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
71
+ More candidates = better fusion quality. Default: 3.
72
+ """
73
+
74
+ bm25_retriever: Any = Field(...)
75
+ semantic_store: Any = Field(...)
76
+ k: int = Field(default=5)
77
+ rrf_c: int = Field(default=60)
78
+ bm25_weight: float = Field(default=0.5)
79
+ semantic_weight: float = Field(default=0.5)
80
+ fetch_multiplier: int = Field(default=3)
81
+
82
+ def _get_relevant_documents(
83
+ self,
84
+ query: str,
85
+ *,
86
+ run_manager: CallbackManagerForRetrieverRun,
87
+ ) -> list[Document]:
88
+ """
89
+ Core retrieval logic called by LangChain when the retriever is invoked.
90
+
91
+ Steps
92
+ -----
93
+ 1. Fetch candidates from BM25 and FAISS independently
94
+ 2. Assign RRF scores weighted by retriever confidence
95
+ 3. Deduplicate by parent_asin, accumulating scores for shared hits
96
+ 4. Sort by fused RRF score and return top-k Documents
97
+ """
98
+ fetch_k = self.k * self.fetch_multiplier
99
+
100
+ # ── 1. BM25 retrieval ────────────────────────────────────────────────
101
+ self.bm25_retriever.k = fetch_k
102
+ try:
103
+ bm25_docs: list[Document] = self.bm25_retriever.invoke(query)
104
+ logger.debug("BM25 returned %d docs for query: %r", len(bm25_docs), query)
105
+ except Exception as exc:
106
+ logger.warning("BM25 retrieval failed: %s — using empty list.", exc)
107
+ bm25_docs = []
108
+
109
+ # ── 2. Semantic retrieval ────────────────────────────────────────────
110
+ # similarity_search returns list[Document] (no scores needed — rank is enough for RRF)
111
+ try:
112
+ semantic_docs: list[Document] = self.semantic_store.similarity_search(
113
+ query, k=fetch_k
114
+ )
115
+ logger.debug(
116
+ "Semantic returned %d docs for query: %r", len(semantic_docs), query
117
+ )
118
+ except Exception as exc:
119
+ logger.warning("Semantic retrieval failed: %s — using empty list.", exc)
120
+ semantic_docs = []
121
+
122
+ # ── 3. RRF fusion ────────────────────────────────────────────────────
123
+ rrf_scores: dict[str, float] = {}
124
+ doc_map: dict[str, Document] = {}
125
+
126
+ def _asin_key(doc: Document, fallback: str) -> str:
127
+ """Use parent_asin as the dedup key; fall back to a content prefix."""
128
+ return doc.metadata.get("parent_asin") or fallback
129
+
130
+ for rank, doc in enumerate(bm25_docs):
131
+ key = _asin_key(doc, f"bm25_{rank}")
132
+ score = self.bm25_weight / (self.rrf_c + rank + 1)
133
+ rrf_scores[key] = rrf_scores.get(key, 0.0) + score
134
+ doc_map[key] = doc # BM25 docs have richer metadata (top_reviews etc.)
135
+
136
+ for rank, doc in enumerate(semantic_docs):
137
+ key = _asin_key(doc, f"sem_{rank}")
138
+ score = self.semantic_weight / (self.rrf_c + rank + 1)
139
+ rrf_scores[key] = rrf_scores.get(key, 0.0) + score
140
+ # Only add to doc_map if BM25 didn't already supply this product
141
+ # (BM25 metadata is richer — has top_reviews, image_url, etc.)
142
+ if key not in doc_map:
143
+ doc_map[key] = doc
144
+
145
+ # ── 4. Sort and truncate ─────────────────────────────────────────────
146
+ ranked_keys = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)
147
+ top_docs = [doc_map[key] for key in ranked_keys[: self.k]]
148
+
149
+ # Attach fused score to metadata — useful for app display
150
+ for key, doc in zip(ranked_keys, top_docs):
151
+ doc.metadata["hybrid_score"] = round(rrf_scores[key], 6)
152
+ # Record which retriever(s) contributed to this result
153
+ in_bm25 = any(
154
+ _asin_key(d, f"bm25_{i}") == key for i, d in enumerate(bm25_docs)
155
+ )
156
+ in_sem = any(
157
+ _asin_key(d, f"sem_{i}") == key for i, d in enumerate(semantic_docs)
158
+ )
159
+ if in_bm25 and in_sem:
160
+ doc.metadata["retrieval_source"] = "hybrid"
161
+ elif in_bm25:
162
+ doc.metadata["retrieval_source"] = "bm25"
163
+ else:
164
+ doc.metadata["retrieval_source"] = "semantic"
165
+
166
+ logger.info(
167
+ "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
168
+ len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
169
+ )
170
+ return top_docs
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # Convenience loader
175
+ # ---------------------------------------------------------------------------
176
+
177
+ def load_hybrid_retriever(
178
+ bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
179
+ faiss_store_path: str = "data/processed/embeddings",
180
+ k: int = 5,
181
+ bm25_weight: float = 0.5,
182
+ semantic_weight: float = 0.5,
183
+ rrf_c: int = 60,
184
+ fetch_multiplier: int = 3,
185
+ ) -> HybridRetriever:
186
+ """
187
+ Load both indexes from disk and return a ready-to-use HybridRetriever.
188
+
189
+ Call this once in your notebook or app.py, then pass the result to run_rag().
190
+
191
+ Parameters
192
+ ----------
193
+ bm25_index_path : Path to the pickled BM25Retriever (from bm25.build_and_save())
194
+ faiss_store_path : Directory containing index.faiss + index.pkl
195
+ (from semantic.build_and_save_vector_store())
196
+ k : Number of documents to return per query
197
+ bm25_weight : RRF weight for BM25 (keyword signal). Default 0.5.
198
+ semantic_weight : RRF weight for semantic (meaning signal). Default 0.5.
199
+ Weights don't need to sum to 1 but relative scale matters.
200
+ rrf_c : RRF rank-dampening constant. Default 60 (standard).
201
+ fetch_multiplier : Candidates to fetch per retriever = k * fetch_multiplier.
202
+
203
+ Returns
204
+ -------
205
+ HybridRetriever
206
+ A LangChain-compatible retriever pipeable with |.
207
+
208
+ Example
209
+ -------
210
+ >>> from src.hybrid import load_hybrid_retriever
211
+ >>> from src.rag_pipeline import run_rag
212
+ >>>
213
+ >>> hybrid = load_hybrid_retriever(k=5)
214
+ >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
215
+ >>> print(answer)
216
+ """
217
+ # Import here to avoid circular imports when used from rag_pipeline.py
218
+ from src.bm25 import load as load_bm25
219
+ from src.semantic import load_vector_store
220
+
221
+ print(f"Loading BM25 index from: {bm25_index_path}")
222
+ bm25_ret: BM25Retriever = load_bm25(bm25_index_path)
223
+
224
+ print(f"Loading FAISS store from: {faiss_store_path}")
225
+ faiss_store: FAISS = load_vector_store(faiss_store_path)
226
+
227
+ retriever = HybridRetriever(
228
+ bm25_retriever=bm25_ret,
229
+ semantic_store=faiss_store,
230
+ k=k,
231
+ bm25_weight=bm25_weight,
232
+ semantic_weight=semantic_weight,
233
+ rrf_c=rrf_c,
234
+ fetch_multiplier=fetch_multiplier,
235
+ )
236
+ print(
237
+ f"HybridRetriever ready — k={k}, "
238
+ f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
239
+ )
240
+ return retriever
src/src/rag_pipeline.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ rag_chain.py
3
+ ------------
4
+ Amazon product RAG (Retrieval-Augmented Generation) pipeline using
5
+ LangChain + HuggingFace Inference Endpoints.
6
+
7
+ Typical usage
8
+ -------------
9
+ >>> from rag_chain import run_rag
10
+ >>> answer = run_rag(retriever, "Moisturizing shampoo for thick curly hair")
11
+ >>> print(answer)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from typing import Any
18
+
19
+ from langchain_core.documents import Document
20
+ from langchain_core.output_parsers import StrOutputParser
21
+ from langchain_core.prompts import ChatPromptTemplate
22
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
23
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
24
+ from src.retrieval_helpers import _format_docs
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Logging
28
+ # ---------------------------------------------------------------------------
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Constants
33
+ # ---------------------------------------------------------------------------
34
+ DEFAULT_REPO_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
35
+ DEFAULT_MAX_NEW_TOKENS = 512
36
+ DEFAULT_TOP_K = 5
37
+
38
+ DEFAULT_SYSTEM_PROMPT = (
39
+ "You are a helpful Amazon grocery shopping assistant.\n\n"
40
+ "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
41
+ "Your response must follow this exact structure:\n\n"
42
+ "---\n\n"
43
+ "## 🛒 Recommended Products\n"
44
+ "For each product, write a numbered list entry, mentioning products by title "
45
+ "followed by 1–2 sentences describing the product and why it suits the query.\n\n"
46
+ "## 💡 Tips & Recipe Ideas\n"
47
+ "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above "
48
+ "(do NOT write out full recipes — keep each idea to 1–2 sentences)."
49
+ "Add food emojis if relevant.\n\n"
50
+ "---\n\n"
51
+ "Rules:\n"
52
+ "- Do not invent products. Only recommend products from the provided list.\n"
53
+ "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
54
+ "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
55
+ "- Format the entire response in Markdown.\n"
56
+ "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
57
+ )
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Helper functions
61
+ # ---------------------------------------------------------------------------
62
+
63
+ import logging
64
+ from langchain_core.runnables import RunnableLambda
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+ def _make_verbose_tap(label: str, verbose: bool):
69
+ """
70
+ Returns a passthrough RunnableLambda that logs *value* when verbose=True.
71
+ Works for any chain step — docs, prompt messages, or raw strings.
72
+ """
73
+ def _tap(value):
74
+ if verbose:
75
+ if hasattr(value, "messages"): # ChatPromptValue
76
+ rendered = "\n".join(
77
+ f"[{m.type.upper()}]: {m.content}"
78
+ for m in value.messages
79
+ )
80
+ elif isinstance(value, list): # list of Documents
81
+ rendered = "\n".join(str(d) for d in value)
82
+ else:
83
+ rendered = str(value)
84
+
85
+ print(f"\n{'='*60}\n{label}\n{'='*60}\n{rendered}\n")
86
+ logger.debug("%s\n%s", label, rendered)
87
+ return value
88
+ return RunnableLambda(_tap)
89
+
90
+ def build_context(docs: list[Document]) -> str:
91
+ """
92
+ Concatenate a list of retrieved LangChain Documents into a single
93
+ context string that the LLM can reason over.
94
+
95
+ Each entry includes the product's ``parent_asin`` (falling back to its
96
+ position index), its page content, and its full metadata dict.
97
+
98
+ Parameters
99
+ ----------
100
+ docs:
101
+ List of ``langchain_core.documents.Document`` objects returned by
102
+ the retriever.
103
+
104
+ Returns
105
+ -------
106
+ str
107
+ A newline-separated block of product descriptions ready for prompt
108
+ injection. Returns an empty string when *docs* is empty.
109
+
110
+ Raises
111
+ ------
112
+ TypeError
113
+ If *docs* is not a list, or any element is not a ``Document``.
114
+ """
115
+ if not isinstance(docs, list):
116
+ raise TypeError(
117
+ f"'docs' must be a list of Document objects, got {type(docs).__name__}."
118
+ )
119
+ for i, doc in enumerate(docs):
120
+ if not isinstance(doc, Document):
121
+ raise TypeError(
122
+ f"Element at index {i} is not a Document; got {type(doc).__name__}."
123
+ )
124
+
125
+ if not docs:
126
+ logger.warning("build_context received an empty document list.")
127
+ return ""
128
+
129
+ return "\n\n".join(
130
+ f"ASIN {doc.metadata.get('parent_asin', n)} Description: {doc.page_content}\n"
131
+ f"Metadata: {doc.metadata}"
132
+ for n, doc in enumerate(docs)
133
+ )
134
+
135
+
136
+ def _build_llm(
137
+ repo_id: str,
138
+ max_new_tokens: int,
139
+ provider: str,
140
+ ) -> ChatHuggingFace:
141
+ """
142
+ Instantiate and return a ``ChatHuggingFace`` model backed by a
143
+ HuggingFace Inference Endpoint.
144
+
145
+ Parameters
146
+ ----------
147
+ repo_id:
148
+ HuggingFace Hub model identifier (e.g.
149
+ ``"meta-llama/Meta-Llama-3-8B-Instruct"``).
150
+ max_new_tokens:
151
+ Maximum number of tokens the model may generate per call.
152
+ provider:
153
+ Inference provider passed to ``HuggingFaceEndpoint``
154
+ (``"auto"``, ``"novita"``, etc.).
155
+
156
+ Returns
157
+ -------
158
+ ChatHuggingFace
159
+ A chat-compatible wrapper around the endpoint.
160
+ """
161
+ endpoint = HuggingFaceEndpoint(
162
+ repo_id=repo_id,
163
+ task="text-generation",
164
+ max_new_tokens=max_new_tokens,
165
+ provider=provider,
166
+ )
167
+ return ChatHuggingFace(llm=endpoint)
168
+
169
+
170
+ def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
171
+ """
172
+ Create a ``ChatPromptTemplate`` with a system message and a human
173
+ turn that injects ``{context}`` and ``{question}`` placeholders.
174
+
175
+ Parameters
176
+ ----------
177
+ system_prompt:
178
+ The system-level instruction string.
179
+
180
+ Returns
181
+ -------
182
+ ChatPromptTemplate
183
+ """
184
+ return ChatPromptTemplate.from_messages([
185
+ ("system", system_prompt),
186
+ (
187
+ "human",
188
+ "context:\n{context}\n\nquestion:\n{question}\n\n"
189
+ "Answer based on the Amazon datasets:",
190
+ ),
191
+ ])
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Public API
196
+ # ---------------------------------------------------------------------------
197
+
198
+ def run_rag(
199
+ retriever: Any,
200
+ query: str,
201
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
202
+ repo_id: str = DEFAULT_REPO_ID,
203
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
204
+ provider: str = "auto",
205
+ verbose: bool = False,
206
+ hf_dataset = None
207
+ ) -> str:
208
+ """
209
+ Execute a full RAG pipeline and return the model's answer.
210
+
211
+ The pipeline follows the steps below:
212
+
213
+ 1. **Retrieve** - *retriever* fetches the *k* most relevant documents
214
+ for *query*.
215
+ 2. **Format context** - :func:`build_context` serialises the documents
216
+ into a single string.
217
+ 3. **Prompt** - the context and query are injected into the chat prompt
218
+ template.
219
+ 4. **Generate** - the LLM produces an answer grounded in the context.
220
+ 5. **Parse** - the raw chat message is unwrapped to a plain string.
221
+
222
+ Parameters
223
+ ----------
224
+ retriever:
225
+ A LangChain-compatible retriever (must expose ``.invoke()`` and be
226
+ pipeable with ``|``). Typically created via
227
+ ``vectorstore.as_retriever(...)``.
228
+ query:
229
+ Natural-language question to answer (non-empty string).
230
+ system_prompt:
231
+ System-level instruction for the assistant. Defaults to
232
+ :data:`DEFAULT_SYSTEM_PROMPT`.
233
+ repo_id:
234
+ HuggingFace Hub model identifier. Defaults to
235
+ ``"meta-llama/Meta-Llama-3-8B-Instruct"``.
236
+ max_new_tokens:
237
+ Upper bound on generated tokens. Must be a positive integer.
238
+ Defaults to ``100``.
239
+ provider:
240
+ HuggingFace inference provider (e.g. ``"auto"``, ``"novita"``).
241
+ Defaults to ``"auto"``.
242
+
243
+ Returns
244
+ -------
245
+ str
246
+ The model's answer as a plain string.
247
+
248
+ Raises
249
+ ------
250
+ TypeError
251
+ If *retriever* is ``None``, *query* is not a string, or
252
+ *system_prompt* is not a string.
253
+ ValueError
254
+ If *query* is blank, *max_new_tokens* is not a positive integer,
255
+ or *repo_id* / *provider* are blank strings.
256
+
257
+ Examples
258
+ --------
259
+ >>> answer = run_rag(retriever, "Best waterproof mascara under $20")
260
+ >>> print(answer)
261
+ """
262
+ # ------------------------------------------------------------------
263
+ # Build chain components
264
+ # ------------------------------------------------------------------
265
+
266
+ logger.info("Initialising LLM endpoint: %s", repo_id)
267
+ llm = _build_llm(repo_id, max_new_tokens, provider)
268
+ prompt_template = _build_prompt_template(system_prompt)
269
+
270
+ retrieved_docs: list[Document] = [] # ← capture target
271
+
272
+ def _retrieve_and_capture(query: str) -> list[Document]:
273
+ """Invoke the retriever and snapshot the results for the caller."""
274
+ docs = retriever.invoke(query)
275
+ retrieved_docs.extend(docs) # ← populate closure variable
276
+ return docs # ← pass through to build_context
277
+
278
+ rag_chain = (
279
+ {
280
+ "context": RunnableLambda(_retrieve_and_capture)
281
+ | RunnableLambda(build_context)
282
+ | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
283
+ "question": RunnablePassthrough(),
284
+ }
285
+ | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
286
+ | prompt_template
287
+ | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose) # ← shows exact prompt
288
+ | llm
289
+ | StrOutputParser()
290
+ )
291
+
292
+ # ------------------------------------------------------------------
293
+ # Run
294
+ # ------------------------------------------------------------------
295
+ logger.info("Invoking RAG chain for query: %r", query)
296
+ answer: str = rag_chain.invoke(query)
297
+ logger.debug("RAG answer: %s", answer)
298
+
299
+ if hf_dataset:
300
+ docs = _format_docs(retrieved_docs, hf_dataset)
301
+ else:
302
+ docs = retrieved_docs
303
+
304
+ return answer, docs
src/src/retrieval_helpers.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+ import json, sys
3
+ import re
4
+ from pathlib import Path
5
+ ROOT_FOLDER = Path(__file__).resolve().parent.parent
6
+
7
+ sys.path.append(str(ROOT_FOLDER))
8
+ from src.semantic import semantic_search
9
+
10
+ def decode_ratings(page_content):
11
+ block_pattern = r'\[\d\.0★\].*'
12
+ matches = re.findall(block_pattern, page_content)
13
+ if matches:
14
+ pattern = r'\[(\d\.0)★\]\s*(.*?)\s*—\s*(.*)'
15
+ parsed = []
16
+
17
+ for r in matches[:3]:
18
+ match = re.match(pattern, r)
19
+ if match:
20
+ rating, title, text = match.groups()
21
+ parsed.append({
22
+ 'rating': float(rating),
23
+ 'title': title.strip(),
24
+ 'text': text.strip()
25
+ })
26
+
27
+ return(parsed)
28
+ else:
29
+ return {}
30
+
31
+ def enrich_search_results(vector_store, query: str, k: int, hf_dataset):
32
+ """
33
+ Perform similarity search and enrich results with HuggingFace dataset metadata.
34
+
35
+ Args:
36
+ vector_store: LangChain vector store instance
37
+ query: Search query string
38
+ k: Number of results to return
39
+ filter: Filter dict for similarity search
40
+ hf_dataset: HuggingFace Arrow dataset (datasets.Dataset)
41
+
42
+ Returns:
43
+ List of enriched metadata objects as dicts
44
+ """
45
+ results = semantic_search(query, vector_store, k=k)
46
+
47
+ # 1. Extract parent_asins from metadata
48
+ parent_asins = [doc.metadata.get("parent_asin") for doc, score in results]
49
+
50
+ # 2. Query HuggingFace dataset via DuckDB
51
+ con = duckdb.connect()
52
+ arrow_table = hf_dataset.data.table # Get underlying PyArrow table
53
+ con.register("hf_table", arrow_table)
54
+
55
+ asin_list = ", ".join(f"'{asin}'" for asin in parent_asins if asin)
56
+ query_sql = f"SELECT * FROM hf_table WHERE parent_asin IN ({asin_list})"
57
+ hf_rows = con.execute(query_sql).fetchdf()
58
+
59
+ # Build lookup: parent_asin -> metadata dict
60
+ asin_to_metadata = {
61
+ row["parent_asin"]: row.to_dict()
62
+ for _, row in hf_rows.iterrows()
63
+ }
64
+
65
+ enriched_results = []
66
+
67
+ for doc, score in results:
68
+ parent_asin = doc.metadata.get("parent_asin")
69
+ total_reviews = doc.metadata.get("total_reviews")
70
+ metadata_object = asin_to_metadata.get(parent_asin, {}).copy()
71
+ metadata_object['score'] = score
72
+ metadata_object['total_reviews'] = total_reviews
73
+
74
+ # 3. Extract 3 lines after "Top Reviews\n" from page_content
75
+ page_content = doc.page_content
76
+ metadata_object["reviews"] = decode_ratings(page_content)
77
+
78
+ enriched_results.append(metadata_object)
79
+
80
+ con.close()
81
+
82
+ # 4. Return JSON metadata objects
83
+ return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
84
+
85
+
86
+ def enrich_bm25_search_results(retriever, query: str, k: int, hf_dataset):
87
+ """
88
+ Perform BM25 search and enrich results with HuggingFace dataset metadata.
89
+
90
+ Args:
91
+ retriever: LangChain BM25Retriever instance
92
+ query: Search query string
93
+ k: Number of results to return
94
+ hf_dataset: HuggingFace Arrow dataset (datasets.Dataset)
95
+
96
+ Returns:
97
+ List of enriched metadata objects as dicts
98
+ """
99
+ # Get BM25 scores via underlying rank_bm25 library
100
+ query_tokens = query.split()
101
+ scores = retriever.vectorizer.get_scores(query_tokens) # numpy array
102
+
103
+ top_k_indices = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:k]
104
+ results = [(retriever.docs[i], score) for i, score in top_k_indices]
105
+
106
+ # 1. Extract parent_asins from metadata
107
+ parent_asins = [doc.metadata.get("parent_asin") for doc, score in results]
108
+
109
+ # 2. Query HuggingFace dataset via DuckDB
110
+ con = duckdb.connect()
111
+ arrow_table = hf_dataset.data.table
112
+ con.register("hf_table", arrow_table)
113
+
114
+ asin_list = ", ".join(f"'{asin}'" for asin in parent_asins if asin)
115
+ query_sql = f"SELECT * FROM hf_table WHERE parent_asin IN ({asin_list})"
116
+ hf_rows = con.execute(query_sql).fetchdf()
117
+
118
+ # Build lookup: parent_asin -> metadata dict
119
+ asin_to_metadata = {
120
+ row["parent_asin"]: row.to_dict()
121
+ for _, row in hf_rows.iterrows()
122
+ }
123
+
124
+ enriched_results = []
125
+
126
+ for doc, score in results:
127
+ parent_asin = doc.metadata.get("parent_asin")
128
+
129
+ metadata_object = {
130
+ **doc.metadata,
131
+ **asin_to_metadata.get(parent_asin, {}),
132
+ "score": score,
133
+ }
134
+
135
+ metadata_object['reviews'] = metadata_object.pop('top_reviews', {}) or {}
136
+
137
+ enriched_results.append(metadata_object)
138
+
139
+ con.close()
140
+
141
+ # 4. Return JSON metadata objects
142
+ return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
143
+
144
+ def _format_docs(results, hf_dataset):
145
+ """
146
+ Perform similarity search and enrich results with HuggingFace dataset metadata.
147
+
148
+ Args:
149
+ vector_store: LangChain vector store instance
150
+ query: Search query string
151
+ k: Number of results to return
152
+ filter: Filter dict for similarity search
153
+ hf_dataset: HuggingFace Arrow dataset (datasets.Dataset)
154
+
155
+ Returns:
156
+ List of enriched metadata objects as dicts
157
+ """
158
+
159
+ # 1. Extract parent_asins from metadata
160
+ parent_asins = [doc.metadata.get("parent_asin") for doc in results]
161
+
162
+ # 2. Query HuggingFace dataset via DuckDB
163
+ con = duckdb.connect()
164
+ arrow_table = hf_dataset.data.table # Get underlying PyArrow table
165
+ con.register("hf_table", arrow_table)
166
+
167
+ asin_list = ", ".join(f"'{asin}'" for asin in parent_asins if asin)
168
+ query_sql = f"SELECT * FROM hf_table WHERE parent_asin IN ({asin_list})"
169
+ hf_rows = con.execute(query_sql).fetchdf()
170
+
171
+ # Build lookup: parent_asin -> metadata dict
172
+ asin_to_metadata = {
173
+ row["parent_asin"]: row.to_dict()
174
+ for _, row in hf_rows.iterrows()
175
+ }
176
+
177
+ enriched_results = []
178
+
179
+ for doc in results:
180
+ parent_asin = doc.metadata.get("parent_asin")
181
+ total_reviews = doc.metadata.get("total_reviews")
182
+ metadata_object = asin_to_metadata.get(parent_asin, {}).copy()
183
+ metadata_object['total_reviews'] = total_reviews
184
+
185
+ # 3. Extract 3 lines after "Top Reviews\n" from page_content
186
+ page_content = doc.page_content
187
+ metadata_object["reviews"] = decode_ratings(page_content)
188
+
189
+ enriched_results.append(metadata_object)
190
+
191
+ con.close()
192
+
193
+ # 4. Return JSON metadata objects
194
+ return [json.loads(json.dumps(obj, default=str)) for obj in enriched_results]
src/src/semantic.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ semantic_search.py
3
+ ------------------
4
+ Semantic search over an Amazon product catalogue using FAISS + HuggingFace embeddings.
5
+
6
+ Expected inputs
7
+ ---------------
8
+ - metadata_dataset : datasets.Dataset — one row per product (raw_metadata["full"])
9
+ - reviews_dataset : datasets.Dataset — passed to get_best_reviews(reviews, asin, k)
10
+
11
+ Typical usage
12
+ -------------
13
+ docs = build_documents(raw_metadata["full"], raw_reviews, n=100)
14
+ store = build_vector_store(docs)
15
+ results = semantic_search("noise cancelling headphones", store, k=5)
16
+ """
17
+
18
+ import logging
19
+ from typing import Any
20
+ import torch
21
+ import json, os, sys
22
+ from pathlib import Path
23
+
24
+ import faiss
25
+ from datasets import Dataset
26
+ from langchain_community.docstore.in_memory import InMemoryDocstore
27
+ from langchain_community.vectorstores import FAISS
28
+ from langchain_core.documents import Document
29
+ from langchain_huggingface import HuggingFaceEmbeddings
30
+ ROOT_FOLDER = Path(__file__).resolve().parent.parent
31
+
32
+ sys.path.append(str(ROOT_FOLDER))
33
+ from src.eda_helpers import get_best_reviews
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Constants
39
+ # ---------------------------------------------------------------------------
40
+
41
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
42
+ DEFAULT_TOP_REVIEWS = 5
43
+ DEFAULT_TOP_K = 5
44
+
45
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
46
+ EMBEDDINGS = HuggingFaceEmbeddings(
47
+ model_name=DEFAULT_EMBEDDING_MODEL,
48
+ model_kwargs={
49
+ "device": DEVICE,
50
+ "model_kwargs": {"torch_dtype": torch.float16},
51
+ },
52
+ encode_kwargs={
53
+ "batch_size": 128 if DEVICE == 'cpu' else 512,
54
+ "normalize_embeddings": True,
55
+ },
56
+ )
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Document construction
60
+ # ---------------------------------------------------------------------------
61
+
62
+ def _format_review(review) -> str:
63
+ """Return a concise single-line string for one review."""
64
+ rating = review.get("rating", "?")
65
+ title = (review.get("title") or "").strip()
66
+ text = (review.get("text") or "").strip()
67
+ return f"[{rating}★] {title} — {text}"
68
+
69
+
70
+ def _build_reviews_block(
71
+ reviews: Dataset,
72
+ parent_asin: str,
73
+ k: int = DEFAULT_TOP_REVIEWS,
74
+ ) -> str:
75
+ """
76
+ Fetch top-k reviews for *parent_asin* and return a formatted text block.
77
+ Returns an empty string when no reviews are found.
78
+ """
79
+ total, product_reviews = get_best_reviews(reviews, parent_asin, k)
80
+ if not product_reviews:
81
+ return 0, ""
82
+ lines = "\n ".join(_format_review(r) for r in product_reviews)
83
+ return total, f"{lines}"
84
+
85
+
86
+ def _build_page_content(product, review_block: str) -> str:
87
+ """Assemble the text that will be embedded. Empty sections are omitted."""
88
+ title = (product.get("title") or "").strip()
89
+ main_category = (product.get("main_category") or "").strip()
90
+ categories = main_category +" >> " + " > ".join(product.get("categories") or [])
91
+ features = "\n ".join(product.get("features") or [])
92
+ description = " ".join(product.get("description") or [])
93
+ details = (product.get("details") or "").strip()
94
+
95
+ parts = [f"Product: {title}"]
96
+ if categories:
97
+ parts.append(f"Category Path: {categories}")
98
+ if features:
99
+ parts.append(f"Features:\n {features}")
100
+ if description:
101
+ parts.append(f"Description:\n {description}")
102
+ if review_block:
103
+ parts.append(f"Top Reviews:\n {review_block}")
104
+ if details:
105
+ parts.append(f"Details:\n {details}")
106
+
107
+ return "\n".join(parts)
108
+
109
+
110
+ def create_document(product, reviews: Dataset) -> Document | None:
111
+ """
112
+ Build a :class:`~langchain_core.documents.Document` from one product row.
113
+
114
+ Args:
115
+ product: A single row from a HuggingFace metadata Dataset (dict-like).
116
+ reviews: The full reviews Dataset, forwarded to ``get_best_reviews``.
117
+
118
+ Returns:
119
+ A Document, or ``None`` if the row has no ``parent_asin``.
120
+
121
+ Notes:
122
+ *page_content* contains only the text that influences embeddings.
123
+ *metadata* stores structured scalars used for filtering and display
124
+ after retrieval — values are kept flat and JSON-serialisable so FAISS
125
+ filter expressions work correctly.
126
+ """
127
+ parent_asin = product.get("parent_asin")
128
+ if not parent_asin:
129
+ logger.warning("Skipping product with missing parent_asin: %s", product.get("title"))
130
+ return None
131
+
132
+ tot, review_block = _build_reviews_block(reviews, parent_asin)
133
+ page_content = _build_page_content(product, review_block)
134
+
135
+ metadata = {
136
+ # --- identifiers ---
137
+ "parent_asin": parent_asin,
138
+ # --- numeric (filterable / rankable) ---
139
+ "price": product.get("price"),
140
+ "average_rating": product.get("average_rating"),
141
+ "rating_number": product.get("rating_number"),
142
+ # --- categorical (filterable) ---
143
+ "main_category": product.get("main_category", ""),
144
+ "categories": product.get("categories") or [],
145
+ # --- free-form (display only; coerce to str for FAISS compatibility) ---
146
+ "details": str(product.get("details") or ""),
147
+ "total_reviews": tot
148
+ }
149
+
150
+ return Document(page_content=page_content, metadata=metadata)
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Vector store
155
+ # ---------------------------------------------------------------------------
156
+
157
+ # Case when we want to create embeddings at once
158
+ def build_vector_store(
159
+ docs: list[Document],
160
+ existing_store: FAISS | None = None,
161
+ ) -> FAISS:
162
+ """
163
+ Embed *docs* and return (or update) a FAISS vector store.
164
+
165
+ If ``existing_store`` is provided, documents are added to it.
166
+ Otherwise, a new FAISS store is created.
167
+
168
+ Document IDs are set to ``parent_asin``.
169
+ """
170
+ if not docs:
171
+ raise ValueError("Cannot build a vector store from an empty document list.")
172
+
173
+ logger.info("Embedding on %s", DEVICE)
174
+
175
+ # --- Create new store if needed ---
176
+ if existing_store is None:
177
+ dim = len(EMBEDDINGS.embed_query("probe"))
178
+ index = faiss.IndexFlatL2(dim)
179
+
180
+ vector_store = FAISS(
181
+ embedding_function=EMBEDDINGS,
182
+ index=index,
183
+ docstore=InMemoryDocstore(),
184
+ index_to_docstore_id={},
185
+ )
186
+ else:
187
+ vector_store = existing_store
188
+
189
+ # --- Add documents ---
190
+ uuids = [doc.metadata["parent_asin"] for doc in docs]
191
+ vector_store.add_documents(documents=docs, ids=uuids)
192
+
193
+ logger.info("Indexed %d documents into FAISS.", len(docs))
194
+ return vector_store
195
+
196
+ # Running the above function in batches and saving
197
+ def build_and_save_vector_store(
198
+ metadata_dataset: Dataset,
199
+ reviews: Dataset,
200
+ save_path: str,
201
+ batch_size: int = 500,
202
+ ) -> FAISS:
203
+
204
+ # --- Resume / initialize ---
205
+ if os.path.exists(os.path.join(save_path, "index.faiss")):
206
+ vector_store = FAISS.load_local(
207
+ save_path, EMBEDDINGS, allow_dangerous_deserialization=True
208
+ )
209
+ already_indexed = set(vector_store.index_to_docstore_id.values())
210
+ print(f"Resuming — {len(already_indexed)} docs already indexed.")
211
+ else:
212
+ os.makedirs(save_path, exist_ok=True)
213
+ vector_store = None # let helper create it
214
+ already_indexed = set()
215
+
216
+ progress_file = os.path.join(save_path, "progress.json")
217
+
218
+ # --- Resume progress ---
219
+ if os.path.exists(progress_file):
220
+ with open(progress_file) as f:
221
+ resume_start = json.load(f).get("next_start", 0)
222
+ print(f"Resuming from row {resume_start}.")
223
+ else:
224
+ resume_start = 0
225
+
226
+ total = len(metadata_dataset)
227
+
228
+ for start in range(resume_start, total, batch_size):
229
+ batch = metadata_dataset.select(range(start, min(start + batch_size, total)))
230
+
231
+ docs = []
232
+ for row in batch:
233
+ doc = create_document(row, reviews)
234
+ if doc is not None and doc.metadata["parent_asin"] not in already_indexed:
235
+ docs.append(doc)
236
+
237
+ if docs:
238
+ vector_store = build_vector_store(
239
+ docs=docs,
240
+ existing_store=vector_store,
241
+ )
242
+ already_indexed.update(doc.metadata["parent_asin"] for doc in docs)
243
+
244
+ # --- Save after each batch ---
245
+ vector_store.save_local(save_path)
246
+ with open(progress_file, "w") as f:
247
+ json.dump({"next_start": min(start + batch_size, total)}, f)
248
+
249
+ print(f"Indexed {min(start + batch_size, total)} / {total} rows")
250
+
251
+ if os.path.exists(progress_file):
252
+ os.remove(progress_file)
253
+
254
+ return vector_store
255
+
256
+ # ---------------------------------------------------------------------------
257
+ # Search
258
+ # ---------------------------------------------------------------------------
259
+
260
+ def semantic_search(
261
+ query: str,
262
+ vector_store: FAISS,
263
+ k: int = DEFAULT_TOP_K,
264
+ filter = None,
265
+ ) -> list[Document]:
266
+ """
267
+ Run a semantic similarity search against a pre-built *vector_store*.
268
+
269
+ Args:
270
+ query: Natural-language search query.
271
+ vector_store: A FAISS store built with :func:`build_vector_store`.
272
+ k: Number of results to return.
273
+ filter: Optional metadata filter dict, e.g.
274
+ ``{"main_category": "Electronics"}``.
275
+
276
+ Returns:
277
+ Ordered list of the *k* most relevant Documents.
278
+ """
279
+ results = vector_store.similarity_search_with_score(query, k=k, filter=filter)
280
+ logger.info("'%s' -> %d results", query, len(results))
281
+ return results
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Read existing vector store
285
+ # ---------------------------------------------------------------------------
286
+
287
+ def load_vector_store(
288
+ load_path: str,
289
+ ) -> FAISS:
290
+
291
+ return FAISS.load_local(
292
+ load_path,
293
+ embeddings=EMBEDDINGS,
294
+ allow_dangerous_deserialization=True,
295
+ )
src/src/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import nltk
3
+ from nltk.corpus import stopwords
4
+
5
+ # Download stopwords if not already downloaded
6
+ nltk.download('stopwords', quiet=True)
7
+
8
+ # Define a set of English stopwords for filtering out common words
9
+ STOPWORDS = set(stopwords.words('english'))
10
+
11
+ # Tokenizer
12
+ def simple_tokenize(text):
13
+ if not text:
14
+ return []
15
+ text = text.lower()
16
+ text = re.sub(r"-", " ", text)
17
+ text = re.sub(r"[^a-z0-9\s]", "", text)
18
+ tokens = text.split()
19
+ tokens = [t for t in tokens if t not in STOPWORDS]
20
+ return tokens