Spaces:
Sleeping
Sleeping
| import json | |
| import csv, sys | |
| from datetime import datetime | |
| from pathlib import Path | |
| import streamlit as st | |
| import markdown | |
| ROOT_FOLDER = Path(__file__).resolve().parent.parent | |
| sys.path.append(str(ROOT_FOLDER)) | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) | |
| from src.semantic import load_vector_store, enrich_search_results | |
| from src.rag_pipeline import run_rag | |
| from src.bm25 import load, search | |
| from src.hybrid import HybridRetriever | |
| from dotenv import load_dotenv | |
load_dotenv()  # populate os.environ from a .env file before HF_TOKEN / DATA_SOURCE are read
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # suppress every UserWarning process-wide

# ─── Page config (must be the first Streamlit call) ───────────────────────────
st.set_page_config(
    page_title="Groceries & Gourmet Food Search",
    page_icon="π₯",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# ─── Paths ────────────────────────────────────────────────────────────────────
ROOT = Path(__file__).resolve().parent.parent  # parent of this script's folder
FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
TOP_K = 5  # number of results requested from each retriever
HF_TOKEN = os.getenv('HF_TOKEN')  # None when no token is configured
from huggingface_hub import snapshot_download, login

# ─── Custom CSS ───────────────────────────────────────────────────────────────
# Resolve the stylesheet from the project root instead of the current working
# directory: './app/styles.css' only worked when Streamlit was launched from the
# repository root. A missing stylesheet is purely cosmetic, so report it and
# continue rather than aborting the whole app.
STYLES_CSS = ROOT / "app" / "styles.css"
if STYLES_CSS.is_file():
    css = STYLES_CSS.read_text(encoding="utf-8")
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
else:
    print(f"Custom stylesheet not found at {STYLES_CSS}; using default styling")
@st.cache_resource
def load_vector_store_cached():
    """
    Load the vector store and BM25 index, downloading them from Hugging Face if needed.

    The artefacts are fetched from the ``rishadaz/amazon_retriever-storage``
    dataset repo into ``VECTOR_STORE_DIR`` only when the files this function
    reads are not already there. The result is cached with
    ``st.cache_resource``, so Streamlit's script reruns (one per widget
    interaction) reuse the loaded objects instead of reloading them.

    Returns
    -------
    tuple
        (vector_store, bm25_retriever)
    """
    # Only authenticate when a token is configured: huggingface_hub.login()
    # without a token falls back to prompting for one.
    if HF_TOKEN:
        login(token=HF_TOKEN, add_to_git_credential=False)
    VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)

    bm25_index_path = VECTOR_STORE_DIR / "tokenisation" / "bm25_index.pkl"
    embeddings_dir = VECTOR_STORE_DIR / "embeddings"
    # Check for the exact artefacts rather than "directory is non-empty". With
    # the old check, a partial download or leftover local/mini files made the
    # download get skipped forever.
    if not (bm25_index_path.is_file() and embeddings_dir.is_dir()):
        snapshot_path = Path(snapshot_download(
            repo_id="rishadaz/amazon_retriever-storage",
            repo_type="dataset",
            local_dir=str(VECTOR_STORE_DIR),
            token=HF_TOKEN,
        ))
        bm25_index_path = snapshot_path / "tokenisation" / "bm25_index.pkl"
        embeddings_dir = snapshot_path / "embeddings"

    vector_store = load_vector_store(embeddings_dir)
    bm25_retriever = load(bm25_index_path)
    return vector_store, bm25_retriever
# ─── Get Data ─────────────────────────────────────────────────────────────────
# DATA_SOURCE="local" reads the mini index files shipped in the repo; any other
# value (including unset) loads the full corpus via load_vector_store_cached().
# Note: the remote corpus and embeddings are large. They can take a long time
# to download and make the app heavier and slower, so for development please
# use the smaller "local" corpus.
data_source = os.getenv('DATA_SOURCE')
print(f"Running with data source {data_source}")


@st.cache_resource
def _load_local_stores():
    """
    Load the bundled mini embeddings store and mini BM25 index.

    Cached with ``st.cache_resource`` so Streamlit's reruns (one per widget
    interaction) reuse the loaded objects instead of re-reading them from disk.
    Reads the module-level ``MINI_INDEX_PATH``, which is set just before the
    only call below.

    Returns
    -------
    tuple
        (vector_store, bm25_retriever)
    """
    local_store = load_vector_store(ROOT_FOLDER / 'data' / 'processed' / 'embeddings')
    local_retriever = load(MINI_INDEX_PATH)
    return local_store, local_retriever


if data_source == 'local':
    MINI_INDEX_PATH = ROOT / "data" / "processed" / "tokenisation" / "bm25_index_mini.pkl"
    vector_store, retriever = _load_local_stores()
else:
    vector_store, retriever = load_vector_store_cached()
def bm25_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Keyword (BM25) search over the module-level ``retriever`` index.

    Parameters
    ----------
    query : str
        Free-text search string.
    top_k : int
        Number of hits to request.

    Returns
    -------
    list[dict]
        The hits exactly as returned by ``search``.
    """
    return search(retriever, query, top_k)
def semantic_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Embedding-based search over the module-level ``vector_store``.

    Parameters
    ----------
    query : str
        Free-text search string.
    top_k : int
        Number of hits to request.

    Returns
    -------
    list[dict]
        The hits exactly as returned by ``enrich_search_results``.
    """
    return enrich_search_results(vector_store, query, top_k)
# Equal weights for the BM25 and semantic components handed to HybridRetriever.
_HYBRID_WEIGHTS = {"bm25_weight": 0.5, "semantic_weight": 0.5}

# Shared retriever for the RAG pipeline, built from the indexes loaded above
# and configured for TOP_K documents.
hybrid_retriever = HybridRetriever(
    bm25_retriever=retriever,
    semantic_store=vector_store,
    k=TOP_K,
    **_HYBRID_WEIGHTS,
)
def llm_retriever(query: str, top_k: int = 5):
    """
    Answer ``query`` with the RAG pipeline on top of ``hybrid_retriever``.

    Parameters
    ----------
    query : str
        The user's question.
    top_k : int
        Accepted for parity with the other search helpers but not forwarded.
        The number of retrieved documents is the ``k`` that
        ``hybrid_retriever`` was constructed with.

    Returns
    -------
    tuple
        (answer, retrieved_docs, web_sources) as produced by ``run_rag``.
    """
    # NOTE(review): top_k is currently ignored; see the docstring.
    rag_output = run_rag(hybrid_retriever, query=query)
    answer, retrieved_docs, web_sources = rag_output
    return answer, retrieved_docs, web_sources
# ─── Helpers ──────────────────────────────────────────────────────────────────
def stars(rating: float) -> str:
    """
    Convert a numeric rating into a five-glyph star string.

    Parameters
    ----------
    rating : float
        Rating on a 0-5 scale. Values above 5 or below 0 are clamped.
        ``None``, non-numeric and NaN values are treated as 0.

    Returns
    -------
    str
        Full stars, at most one half star, then empty stars. The result is
        always five glyphs in total (e.g. ``stars(4.5) == "★★★★½"``).
    """
    import math  # local import: only needed for the NaN check

    try:
        value = float(rating)
    except (TypeError, ValueError):
        value = 0.0
    if math.isnan(value):
        value = 0.0
    # Clamp so out-of-range input can never produce more or fewer than five glyphs.
    value = min(max(value, 0.0), 5.0)

    full = int(value)
    half = 1 if (value - full) >= 0.5 else 0
    empty = 5 - full - half
    return "★" * full + "½" * half + "☆" * empty
def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
    """
    Append one user-feedback row to ``FEEDBACK_CSV``.

    The header row is written whenever the file is empty, not only when it is
    missing. A zero-byte file (e.g. created by hand or left by an interrupted
    write) therefore still gets the column names that readers expect.

    Parameters
    ----------
    query : str
        Search query the feedback refers to.
    mode : str
        Retrieval mode label (the app passes "bm25", "semantic" or "hybrid").
    asin : str
        Product identifier.
    title : str
        Product title.
    vote : str
        "up" or "down".
    """
    fieldnames = ["timestamp", "query", "mode", "asin", "title", "vote"]
    with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        # An append-mode stream starts at end-of-file, so position 0 means the
        # file is empty. This also avoids the exists()/open() race of a separate check.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow({
            "timestamp": datetime.now().isoformat(),
            "query": query,
            "mode": mode,
            "asin": asin,
            "title": title,
            "vote": vote,
        })
def _safe_float(value, default=0.0):
    """
    Convert ``value`` to a finite ``float``.

    Returns ``default`` when ``value`` is missing, non-numeric, NaN or
    infinite, so card rendering never fails on a malformed field.
    """
    import math  # local import: only the rendering helpers need it

    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _review_html(rev):
    """
    Build the HTML snippet for one review.

    The snippet shows the escaped title, the star rating, and the first 300
    characters of the escaped text. ``rev`` is read as a mapping with
    ``title``, ``rating`` and ``text`` keys; missing values render as empty or zero.
    """
    from html import escape

    text = str(rev.get("text") or "")
    snippet = escape(text[:300]) + ("…" if len(text) > 300 else "")
    rating = _safe_float(rev.get("rating"))
    return (
        '<div class="review-snippet">'
        f"<strong>{escape(str(rev.get('title') or ''))}</strong>"
        " · "
        f'<span class="stars">{stars(rating)}</span>'
        f'<span style="color:#888; font-size:0.8rem"> {rating:g}/5</span>'
        " · "
        "<br><br>"
        f"{snippet}"
        "</div>"
    )


def render_product(ind, item, mode):
    """
    Render a single product card with its reviews and feedback buttons.

    Parameters
    ----------
    ind : int
        Zero-based position in the result list. Displayed as ``#ind+1`` and
        used as part of the feedback-button widget keys.
    item : Mapping
        Product record. Fields read: ``title``, ``average_rating``,
        ``rating_number``, ``parent_asin``, ``image``, ``price``, ``score``
        (or ``hybrid_score``), ``retrieval_source`` and ``reviews`` (or
        ``top_reviews``). Missing or non-numeric values degrade to empty or
        zero instead of raising.
    mode : str
        Retrieval mode label. Shown in the score badge, used in widget keys,
        and written to the feedback log.

    Notes
    -----
    Every dataset field is HTML-escaped before interpolation, because the card
    is rendered with ``unsafe_allow_html=True`` and titles and reviews are
    third-party text. Feedback is logged against the module-level ``query``
    from the search box.
    """
    from html import escape

    item = dict(item)
    # Reviews may be stored under "reviews" or "top_reviews"; the first key present wins.
    reviews = item.get("reviews") if "reviews" in item else item.get("top_reviews")
    if reviews is None:
        reviews = []
    n_reviews = len(reviews)
    review_word = "review" if n_reviews == 1 else "reviews"

    title = str(item.get("title") or "")
    avg_rating = _safe_float(item.get("average_rating"))
    rating_number = int(_safe_float(item.get("rating_number")))
    asin = str(item.get("parent_asin", ""))

    image_url = item.get("image") or ""
    image_html = (
        f'<img src="{escape(str(image_url))}" '
        'style="width:100%;max-width:200px;border-radius:8px;margin-bottom:8px;" />'
        if image_url else ""
    )

    # Prices may arrive as text such as "$1,299.00": strip the symbol and separators.
    price_val = _safe_float(
        str(item.get("price")).replace("$", "").replace(",", "").strip(), None
    )
    price_html = (
        f'<span style="color:#2ecc71;font-weight:600">${price_val:.2f}</span>'
        if price_val is not None else ""
    )

    # ── Product card header ───────────────────────────────────────────────
    raw_score = item.get("score") if "score" in item else item.get("hybrid_score")
    score = _safe_float(raw_score, None)
    # Compare against None so that a legitimate score of 0.0 is still shown.
    score_badge = (
        f'<span class="score-badge">{escape(str(mode))} score: {score:.2f}</span>'
        if score is not None else "<span/>"
    )
    # Bind to a local first: reusing the f-string's quote character inside a
    # replacement field is a SyntaxError before Python 3.12.
    if "retrieval_source" in item:
        source = escape(str(item["retrieval_source"]))
        source_badge = f'<span class="score-badge">Source: {source}</span>'
    else:
        source_badge = "<span />"

    st.markdown(
        f'<div class="product-card" id="{escape(asin)}">'
        f"{image_html}"
        f"<h4>#{ind + 1} {escape(title)}</h4>"
        f'<span class="stars">{stars(avg_rating)}</span> '
        f'<small style="color:#888">{avg_rating:.1f}/5 avg ({rating_number:,} ratings)</small>'
        "<br>"
        f"{score_badge} {source_badge}"
        f"{' ' + price_html if price_html else ''}"
        "</div>",
        unsafe_allow_html=True,
    )

    # ── Reviews in collapsible expander ───────────────────────────────────
    expander_label = f"π Viewing top {n_reviews} {review_word} "
    with st.expander(expander_label, expanded=(n_reviews == 1)):
        for rev in reviews:
            st.markdown(_review_html(rev), unsafe_allow_html=True)

    # ── Feedback buttons (per product) ────────────────────────────────────
    col_up, col_dn, _ = st.columns([1, 1, 10])
    with col_up:
        if st.button("π", key=f"up_{mode}_{asin}_{ind}"):
            log_feedback(query, mode, asin, title, "up")
            st.toast("Thanks! π")
    with col_dn:
        if st.button("π", key=f"dn_{mode}_{asin}_{ind}"):
            log_feedback(query, mode, asin, title, "down")
            st.toast("Noted! π")
    st.markdown("<hr style='border:none;border-top:1px solid #e8e0d0;margin:0.5rem 0 1rem'>", unsafe_allow_html=True)
def render_results(results: list[dict], mode: str) -> None:
    """
    Render every product in ``results`` as a card, in order.

    Shows an info box instead when ``results`` is empty.
    """
    if results:
        for position, product in enumerate(results):
            render_product(position, product, mode)
    else:
        st.info("No results returned.")
# ─── App layout ───────────────────────────────────────────────────────────────
# Page banner. The separator between the two subtitle parts is a middle dot
# (previously mis-encoded as "Β·").
st.markdown(
    """
<div class="banner">
<h1>π₯π§ Groceries & Gourmet Food Search</h1>
<p>Amazon Products & Reviews · Groceries & Gourmet Food </p>
</div>
""",
    unsafe_allow_html=True,
)
# ─── Search bar ───────────────────────────────────────────────────────────────
query = st.text_input(
    "Search for a product or describe what you're looking for",
    placeholder="e.g. something sweet for a cheese board...",
)

# Every session-state key derived from the current query. They are cleared
# together so that emptying the input leaves no stale results behind.
_QUERY_STATE_KEYS = (
    "last_query",
    "bm25_results",
    "semantic_results",
    "llm_result",
    "llm_docs",
    "web_sources",
)

# ─── Run searches only when query changes ─────────────────────────────────────
if query.strip() and query != st.session_state.get("last_query"):
    with st.spinner("Searching..."):
        st.session_state.bm25_results = bm25_search(query, top_k=TOP_K)
        st.session_state.semantic_results = semantic_search(query, top_k=TOP_K)
    with st.spinner("Asking AI..."):
        try:
            answer, docs, web_sources = llm_retriever(query, top_k=TOP_K)
        except Exception as e:
            # UI boundary: show any pipeline failure in the AI tab instead of
            # crashing the whole page.
            st.session_state.llm_result = f"**Error:** {e}"
            st.session_state.llm_docs = []
            st.session_state.web_sources = []
        else:
            st.session_state.llm_result = answer
            st.session_state.llm_docs = docs
            st.session_state.web_sources = web_sources
    # Mark the query as handled only after every search has completed. When it
    # was stored first and bm25_search/semantic_search raised, later reruns
    # skipped the search (same query) and the Search tab then read result keys
    # that had never been written.
    st.session_state.last_query = query
elif not query.strip():
    # Clear all query-derived results when the input is emptied.
    for key in _QUERY_STATE_KEYS:
        st.session_state.pop(key, None)
# ─── Tabs ─────────────────────────────────────────────────────────────────────
tab_search, tab_llm = st.tabs(["π Search", "π€ AI Assistant"])

# ─── Search Tab ───────────────────────────────────────────────────────────────
with tab_search:
    mode = st.radio(
        "Search mode",
        options=["BM25", "Semantic"],
        index=0,
        horizontal=True,
        help="BM25 = keyword matching · Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
    )
    if "last_query" not in st.session_state:
        st.markdown(
            "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to see results.</p>",
            unsafe_allow_html=True,
        )
    else:
        st.markdown(f"#### Top {TOP_K} results β {mode}")
        # Read with .get(): a result key can be missing (e.g. when a search call
        # raised on an earlier run), and attribute access would then raise
        # AttributeError. An empty list renders as "No results returned."
        results_key = "bm25_results" if mode == "BM25" else "semantic_results"
        results = st.session_state.get(results_key, [])
        render_results(results, mode=mode.lower())
# ─── LLM Tab ──────────────────────────────────────────────────────────────────
with tab_llm:
    if "llm_result" in st.session_state:
        st.markdown(f"#### π€ AI Answer β *\"{st.session_state.last_query}\"*")
        st.caption("β οΈ AI responses may contain errors - please verify before relying on them.")
        # Convert the answer from markdown to HTML so it can sit inside the styled container.
        html_response = markdown.markdown(
            st.session_state.llm_result,
            extensions=["tables", "fenced_code", "nl2br"],
        )
        st.markdown(f"<div class='llm-response'>{html_response}</div>", unsafe_allow_html=True)

        st.markdown("#### π¦ Retrieved Products")
        docs = st.session_state.get("llm_docs", [])
        if not docs:
            st.markdown("<p style='color:#aaa;'>No documents retrieved.</p>", unsafe_allow_html=True)
        else:
            # JSON round-trip with default=str turns any non-JSON value in the
            # document metadata into its string form before rendering.
            docs = [json.loads(json.dumps(doc.metadata, default=str)) for doc in docs]
            render_results(docs, mode='hybrid')

        # ── Web sources ──────────────────────────────────────────────────
        sources = st.session_state.get("web_sources", [])
        if sources:
            st.markdown("#### π Web Sources")
            for s in sources:
                st.markdown(f"- [{s['title']}]({s['url']})")
    else:
        st.markdown(
            "<p style='color:#aaa; margin-top:1rem;'>Enter a query above to get AI-powered recommendations.</p>",
            unsafe_allow_html=True,
        )
# ─── Sidebar: feedback log ────────────────────────────────────────────────────
with st.sidebar:
    st.header("π Feedback Log")
    if not FEEDBACK_CSV.exists() or FEEDBACK_CSV.stat().st_size == 0:
        # A zero-byte file has no header; pd.read_csv would raise EmptyDataError.
        st.info("No feedback yet β use π/π on results.")
    else:
        import pandas as pd
        try:
            df = pd.read_csv(FEEDBACK_CSV)
        except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
            # A corrupt or half-written log must not take down the whole page.
            st.warning(f"Could not read feedback log: {err}")
        else:
            st.dataframe(df.tail(20), use_container_width=True)
            st.download_button(
                "β¬οΈ Download feedback.csv",
                data=df.to_csv(index=False),
                file_name="feedback.csv",
                mime="text/csv",
            )