rishadaz commited on
Commit
a80240d
Β·
verified Β·
1 Parent(s): 08b4e10

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +314 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,316 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv, sys, os
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
  import streamlit as st
6
 
7
+ # ─── Repo root is the working directory on HF Spaces ─────────────────────────
8
+ ROOT = Path(__file__).resolve().parent # app.py lives at repo root
9
+ sys.path.append(str(ROOT))
10
+
11
+ from src.retrieval_helpers import enrich_search_results
12
+ from src.semantic import load_vector_store
13
+
14
+ import warnings
15
+ warnings.filterwarnings("ignore", category=UserWarning)
16
+
17
+ # ─── Page config (must be first Streamlit call) ───────────────────────────────
18
+ st.set_page_config(
19
+ page_title="Groceries & Gourmet Food Search",
20
+ page_icon="πŸ₯•",
21
+ layout="wide",
22
+ initial_sidebar_state="collapsed",
23
+ )
24
+
25
+ # ─── Paths ────────────────────────────────────────────────────────────────────
26
+ FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
27
+ FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
28
+
29
+ # ─── Load HF dataset (cached so it only runs once) ───────────────────────────
30
+ from datasets import load_dataset
31
+
32
+ @st.cache_resource
33
+ def load_hf_dataset():
34
+ return load_dataset(
35
+ "McAuley-Lab/Amazon-Reviews-2023",
36
+ "raw_meta_Grocery_and_Gourmet_Food",
37
+ trust_remote_code=True,
38
+ )
39
+
40
+ HF_DATASET = load_hf_dataset()
41
+
42
+ # ─── Download vector store from your HF dataset repo ─────────────────────────
43
+ from huggingface_hub import hf_hub_download, snapshot_download
44
+
45
+ VECTOR_STORE_DIR = ROOT / "embeddings" / "semantic_vector_store"
46
+
47
+ VECTOR_STORE_DIR = Path("/data/embeddings/semantic_vector_store")
48
+
49
+ @st.cache_resource
50
+ def load_vector_store_cached():
51
+ return load_vector_store(VECTOR_STORE_DIR)
52
+
53
+ # ─── Custom CSS ───────────────────────────────────────────────────────────────
54
+ st.markdown(
55
+ """
56
+ <style>
57
+ @import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@600&family=Source+Sans+3:wght@400;600&display=swap');
58
+
59
+ html, body, [class*="css"] {
60
+ font-family: 'Source Sans 3', sans-serif;
61
+ }
62
+ h1, h2, h3 { font-family: 'Playfair Display', serif; }
63
+
64
+ .banner {
65
+ background: linear-gradient(135deg, #2d4a22 0%, #4a7c3f 60%, #7aab5c 100%);
66
+ border-radius: 12px;
67
+ padding: 2rem 2.5rem;
68
+ margin-bottom: 1.5rem;
69
+ color: #f5f0e8;
70
+ }
71
+ .banner h1 { margin: 0; font-size: 2.4rem; color: #f5f0e8; }
72
+ .banner p { margin: 0.3rem 0 0; font-size: 1.05rem; opacity: 0.85; }
73
+
74
+ /* Product card (outer) */
75
+ .product-card {
76
+ background: #fffdf7;
77
+ border: 1px solid #e2d9c8;
78
+ border-left: 4px solid #4a7c3f;
79
+ border-radius: 8px;
80
+ padding: 1rem 1.2rem 0.6rem;
81
+ margin-bottom: 0.4rem;
82
+ box-shadow: 0 1px 4px rgba(0,0,0,0.06);
83
+ }
84
+ .product-card h4 { margin: 0 0 0.2rem; color: #1e3318; font-size: 1.05rem; }
85
+
86
+ /* Review snippet inside expander */
87
+ .review-snippet {
88
+ background: #f7f4ee;
89
+ border-radius: 6px;
90
+ padding: 0.6rem 0.9rem;
91
+ margin-bottom: 0.5rem;
92
+ font-size: 0.87rem;
93
+ color: #444;
94
+ line-height: 1.55;
95
+ }
96
+ .score-badge {
97
+ display: inline-block;
98
+ background: #eaf3e6;
99
+ color: #2d5a20;
100
+ border-radius: 20px;
101
+ padding: 2px 10px;
102
+ font-size: 0.78rem;
103
+ font-weight: 600;
104
+ margin-right: 6px;
105
+ }
106
+ .stars { color: #e6a817; }
107
+
108
+ .placeholder-badge {
109
+ background: #fff3cd;
110
+ border: 1px solid #ffc107;
111
+ border-radius: 6px;
112
+ padding: 0.4rem 0.8rem;
113
+ font-size: 0.82rem;
114
+ color: #7a5800;
115
+ display: inline-block;
116
+ margin-bottom: 1rem;
117
+ }
118
+ </style>
119
+ """,
120
+ unsafe_allow_html=True,
121
+ )
122
+
123
+ # ─── Placeholder retrieval functions ──────────────────────────────────────────
124
+ # TODO: Replace with real imports once src/bm25.py and src/semantic.py are ready:
125
+ # from src.bm25 import BM25Retriever
126
+ # from src.semantic import SemanticRetriever
127
+ #
128
+ # Expected return format β€” list of dicts with keys:
129
+ # asin (str), title (str), text (str), rating (float), score (float)
130
+
131
+ DUMMY_RESULTS = {}
132
+
133
+
134
+ def bm25_search(query: str, top_k: int = 3) -> list[dict]:
135
+ """
136
+ PLACEHOLDER β€” swap with real BM25Retriever call, e.g.:
137
+ retriever = BM25Retriever.load('data/processed/bm25_index.pkl')
138
+ return retriever.search(query, top_k=top_k)
139
+ Returns top_k review-level results (may include multiple reviews per ASIN).
140
+ """
141
+ return [r.copy() for r in DUMMY_RESULTS[:top_k]]
142
+
143
+
144
+ def semantic_search(query: str, top_k: int = 3) -> list[dict]:
145
+ """
146
+ PLACEHOLDER β€” swap with real SemanticRetriever call, e.g.:
147
+ retriever = SemanticRetriever.load('data/processed/faiss_index')
148
+ return retriever.search(query, top_k=top_k)
149
+ Returns top_k review-level results (scores are cosine similarities, 0–1).
150
+ """
151
+ vector_store = load_vector_store_cached()
152
+ results = enrich_search_results(vector_store, query, top_k, HF_DATASET["full"])
153
+ return results
154
+
155
+
156
+ # ─── Helpers ──────────────────────────────────────────────────────────────────
157
+ def stars(rating: float) -> str:
158
+ full = int(rating)
159
+ half = 1 if (rating - full) >= 0.5 else 0
160
+ empty = 5 - full - half
161
+ return "β˜…" * full + "Β½" * half + "β˜†" * empty
162
+
163
+
164
+ def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
165
+ file_exists = FEEDBACK_CSV.exists()
166
+ with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
167
+ writer = csv.DictWriter(
168
+ f, fieldnames=["timestamp", "query", "mode", "asin", "title", "vote"]
169
+ )
170
+ if not file_exists:
171
+ writer.writeheader()
172
+ writer.writerow({
173
+ "timestamp": datetime.now().isoformat(),
174
+ "query": query,
175
+ "mode": mode,
176
+ "asin": asin,
177
+ "title": title,
178
+ "vote": vote,
179
+ })
180
+
181
+
182
+ def render_results(results: list[dict], mode: str, query: str) -> None:
183
+ if not results:
184
+ st.info("No results returned.")
185
+ return
186
+
187
+ grouped = results
188
+
189
+ for ind, item in enumerate(grouped):
190
+ reviews = item["reviews"]
191
+ title = item["title"]
192
+ avg_rating = item["average_rating"]
193
+ n_reviews = len(reviews)
194
+ total_reviews = item.get('total_reviews', n_reviews)
195
+ rating_number = item.get('rating_number', 0)
196
+ asin = item['parent_asin']
197
+ review_word = "review" if n_reviews == 1 else "reviews"
198
+ large_images = item.get('images', {}).get('large', [])
199
+ image_html = f'<img src="{large_images[0]}" style="width:100%;max-width:200px;border-radius:8px;margin-bottom:8px;" />' if large_images else ''
200
+ raw_price = item.get('price')
201
+ try:
202
+ price_val = float(str(raw_price).replace('$', '').replace(',', '').strip())
203
+ price_html = f'<span style="color:#2ecc71;font-weight:600">${price_val:.2f}</span>'
204
+ except (TypeError, ValueError):
205
+ price_html = ''
206
+
207
+
208
+ # ── Product card header ───────────────────────────────────────────
209
+ st.markdown(
210
+ f"""
211
+ <div class="product-card">
212
+ {image_html}
213
+ <h4>#{ind + 1} &nbsp; {title}</h4>
214
+ <span class="stars">{stars(avg_rating)}</span>
215
+ &nbsp;<small style="color:#888">{avg_rating:.1f}/5 avg ({rating_number:,} ratings)</small>
216
+ &nbsp;&nbsp;
217
+ <span class="score-badge">similarity score: {item['score']}</span>
218
+ {"&nbsp;&nbsp;" + price_html if price_html else ""}
219
+ </div>
220
+ """,
221
+ unsafe_allow_html=True,
222
+ )
223
+
224
+ # ── Reviews in collapsible expander ───────────────────────────────
225
+ expander_label = f"πŸ“– View {n_reviews} of total {total_reviews} {review_word} "
226
+ with st.expander(expander_label, expanded=(n_reviews == 1)):
227
+ for j, rev in enumerate(reviews):
228
+ st.markdown(
229
+ f"""
230
+ <div class="review-snippet">
231
+ <strong>{rev['title']}</strong>
232
+ &nbsp;Β·&nbsp;
233
+ <span class="stars">{stars(rev['rating'])}</span>
234
+ <span style="color:#888; font-size:0.8rem"> {rev['rating']}/5</span>
235
+ &nbsp;Β·&nbsp;
236
+ <br><br>
237
+ {rev['text'][:300]}{'…' if len(rev['text']) > 300 else ''}
238
+ </div>
239
+ """,
240
+ unsafe_allow_html=True,
241
+ )
242
+
243
+ # ── Feedback buttons (per product) ────────────────────────────────
244
+ col_up, col_dn, _ = st.columns([1, 1, 10])
245
+ with col_up:
246
+ if st.button("πŸ‘", key=f"up_{mode}_{asin}_{ind}"):
247
+ log_feedback(query, mode, asin, title, "up")
248
+ st.toast("Thanks! πŸ‘")
249
+ with col_dn:
250
+ if st.button("πŸ‘Ž", key=f"dn_{mode}_{asin}_{ind}"):
251
+ log_feedback(query, mode, asin, title, "down")
252
+ st.toast("Noted! πŸ‘Ž")
253
+
254
+ st.markdown("<hr style='border:none;border-top:1px solid #e8e0d0;margin:0.5rem 0 1rem'>", unsafe_allow_html=True)
255
+
256
+
257
+ # ─── App layout ───────────────────────────────────────────────────────────────
258
+ st.markdown(
259
+ """
260
+ <div class="banner">
261
+ <h1>πŸ₯•πŸ§€ Groceries & Gourmet Food Search</h1>
262
+ <p>Amazon Products & Reviews Β· Groceries & Gourmet Food </p>
263
+ </div>
264
+ """,
265
+ unsafe_allow_html=True,
266
+ )
267
+
268
+ st.markdown(
269
+ '<div class="placeholder-badge">⚠️ Placeholder mode β€” real BM25 / Semantic indices not yet loaded</div>',
270
+ unsafe_allow_html=True,
271
+ )
272
+
273
+ # ─── Search bar ───────────────────────────────────────────────────────────────
274
+ query = st.text_input(
275
+ "Search for a product or describe what you're looking for",
276
+ placeholder="e.g. something sweet for a cheese board...",
277
+ )
278
+
279
+ # ─── Mode radio ───────────────────────────────────────────────────────────────
280
+ mode = st.radio(
281
+ "Search mode",
282
+ options=["BM25", "Semantic"],
283
+ index=0, # BM25 shown by default
284
+ horizontal=True,
285
+ help="BM25 = keyword matching Β· Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
286
+ )
287
+
288
+ # ─── Run & render ─────────────────────────────────────────────────────────────
289
+ TOP_K = 5 # fixed per milestone requirement
290
+
291
+ if query.strip():
292
+ st.markdown(f"#### Top {TOP_K} results β€” {mode}")
293
+
294
+ results = bm25_search(query, top_k=TOP_K) if mode == "BM25" else semantic_search(query, top_k=TOP_K)
295
+ render_results(results, mode=mode.lower(), query=query)
296
+ else:
297
+ st.markdown(
298
+ "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to see results.</p>",
299
+ unsafe_allow_html=True,
300
+ )
301
+
302
+ # ─── Sidebar: feedback log ────────────────────────────────────────────────────
303
+ with st.sidebar:
304
+ st.header("πŸ“‹ Feedback Log")
305
+ if FEEDBACK_CSV.exists():
306
+ import pandas as pd
307
+ df = pd.read_csv(FEEDBACK_CSV)
308
+ st.dataframe(df.tail(20), use_container_width=True)
309
+ st.download_button(
310
+ "⬇️ Download feedback.csv",
311
+ data=df.to_csv(index=False),
312
+ file_name="feedback.csv",
313
+ mime="text/csv",
314
+ )
315
+ else:
316
+ st.info("No feedback yet β€” use πŸ‘/πŸ‘Ž on results.")