"""Streamlit front-end for the Groceries & Gourmet Food product search demo.

Shows three retrieval views over the same corpus:

* BM25 keyword search,
* semantic (embedding) search,
* an "AI Assistant" tab that runs a RAG pipeline on a hybrid retriever.

Per-product 👍/👎 feedback is appended to ``results/feedback.csv`` and shown
in the sidebar.
"""

import csv
import html
import json
import math
import os
import sys
import warnings
from datetime import datetime
from pathlib import Path

import markdown
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download

ROOT_FOLDER = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT_FOLDER))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))

# Project imports must come after the sys.path tweaks above.
from src.semantic import load_vector_store, enrich_search_results  # noqa: E402
from src.rag_pipeline import run_rag  # noqa: E402
from src.bm25 import load, search  # noqa: E402
from src.hybrid import HybridRetriever  # noqa: E402

load_dotenv()
warnings.filterwarnings("ignore", category=UserWarning)

# ─── Page config (must be first Streamlit call) ───────────────────────────────
st.set_page_config(
    page_title="Groceries & Gourmet Food Search",
    page_icon="🥕",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# ─── Paths & settings ─────────────────────────────────────────────────────────
ROOT = ROOT_FOLDER
FEEDBACK_CSV = ROOT / "results" / "feedback.csv"
FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)
FEEDBACK_FIELDS = ["timestamp", "query", "mode", "asin", "title", "vote"]
VECTOR_STORE_DIR = ROOT / "data" / "processed"
MINI_INDEX_PATH = VECTOR_STORE_DIR / "tokenisation" / "bm25_index_mini.pkl"
FULL_INDEX_RELPATH = Path("tokenisation") / "bm25_index.pkl"
# Anchored at ROOT rather than the working directory. This matches the old
# "./app/styles.css" when the app is launched from the repo root.
CSS_PATH = ROOT / "app" / "styles.css"

TOP_K = 5
REVIEW_PREVIEW_CHARS = 300
HF_TOKEN = os.getenv("HF_TOKEN")

# Session-state keys holding per-query results; all cleared together.
RESULT_STATE_KEYS = (
    "last_query",
    "bm25_results",
    "semantic_results",
    "llm_result",
    "llm_docs",
    "web_sources",
)

# ─── Custom CSS ───────────────────────────────────────────────────────────────
css = CSS_PATH.read_text(encoding="utf-8")
# The stylesheet must be wrapped in <style>; an empty markdown call would
# read the file and then discard it.
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


# ─── Data loading ─────────────────────────────────────────────────────────────
@st.cache_resource
def load_vector_store_cached():
    """
    Load the full vector store and BM25 index, downloading them if needed.

    The artefacts come from the ``rishadaz/amazon_retriever-storage``
    Hugging Face dataset and are saved into ``VECTOR_STORE_DIR``.
    ``st.cache_resource`` makes this run once per server process rather than
    on every Streamlit rerun.

    Returns
    -------
    tuple
        (vector_store, bm25_retriever)
    """
    # Calling login() with no token falls back to an interactive prompt,
    # which a Streamlit server cannot answer.
    if HF_TOKEN:
        login(token=HF_TOKEN, add_to_git_credential=False)
    VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)

    # The repo ships the *mini* artefacts in this same directory, so "the
    # directory is non-empty" does not mean the full index is present.
    # Check for the exact file we are about to load instead.
    if not (VECTOR_STORE_DIR / FULL_INDEX_RELPATH).exists():
        snapshot_path = Path(snapshot_download(
            repo_id="rishadaz/amazon_retriever-storage",
            repo_type="dataset",
            local_dir=str(VECTOR_STORE_DIR),
            token=HF_TOKEN,
        ))
    else:
        snapshot_path = VECTOR_STORE_DIR

    vector_store = load_vector_store(snapshot_path / "embeddings")
    bm25_retriever = load(snapshot_path / FULL_INDEX_RELPATH)
    return vector_store, bm25_retriever


@st.cache_resource
def load_local_stores_cached():
    """
    Load the mini vector store and BM25 index bundled with the repo.

    The result is cached so that Streamlit reruns (triggered by every widget
    interaction) do not reload the index from disk each time.

    Returns
    -------
    tuple
        (vector_store, bm25_retriever)
    """
    vector_store = load_vector_store(VECTOR_STORE_DIR / "embeddings")
    bm25_retriever = load(MINI_INDEX_PATH)
    return vector_store, bm25_retriever


# DATA_SOURCE=local reads the mini artefacts provided in the repo. Any other
# value (including unset) uses the full corpus from Hugging Face. The full
# corpus and embeddings can take a long time to download and make the app
# heavier and slower, so use "local" for development.
data_source = os.getenv("DATA_SOURCE")
print(f"Running with data source {data_source}")

if data_source == "local":
    vector_store, retriever = load_local_stores_cached()
else:
    vector_store, retriever = load_vector_store_cached()


# ─── Retrieval wrappers ───────────────────────────────────────────────────────
def bm25_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Run BM25 keyword search.

    Parameters
    ----------
    query : str
    top_k : int

    Returns
    -------
    list[dict]
        Top-k retrieved results.
    """
    return search(retriever, query, top_k)


def semantic_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Run semantic (embedding-based) search.

    Parameters
    ----------
    query : str
    top_k : int

    Returns
    -------
    list[dict]
        Top-k retrieved results with scores.
    """
    return enrich_search_results(vector_store, query, top_k)


hybrid_retriever = HybridRetriever(
    bm25_retriever=retriever,
    semantic_store=vector_store,
    k=TOP_K,
    bm25_weight=0.5,
    semantic_weight=0.5,
)


def llm_retriever(query: str, top_k: int = 5):
    """
    Run the RAG pipeline using the hybrid retriever.

    Parameters
    ----------
    query : str
    top_k : int
        Currently unused. The number of retrieved documents is fixed by
        ``hybrid_retriever``'s ``k`` (``TOP_K``).

    Returns
    -------
    tuple
        (answer, retrieved_docs, web_sources)
    """
    answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
    return answer, docs, web_sources


# ─── Helpers ──────────────────────────────────────────────────────────────────
def _as_float(value, default: float = 0.0) -> float:
    """Coerce *value* to a finite float, returning *default* if that fails."""
    try:
        result = float(value)
    except (TypeError, ValueError):
        return default
    return result if math.isfinite(result) else default


def _parse_price(raw_price):
    """
    Parse a price such as ``"$1,299.00"`` or ``12.5`` into a float.

    Returns ``None`` when the value is missing, unparsable or not finite
    (e.g. NaN), so the caller can omit the price badge.
    """
    cleaned = str(raw_price).replace("$", "").replace(",", "").strip()
    try:
        value = float(cleaned)
    except ValueError:
        return None
    return value if math.isfinite(value) else None


def stars(rating: float) -> str:
    """
    Convert a numeric rating into a five-position star string.

    Parameters
    ----------
    rating : float
        Values outside 0-5 are clamped. Missing or non-numeric values are
        treated as 0.

    Returns
    -------
    str
        Star representation (e.g., ★★★★½).
    """
    value = min(max(_as_float(rating), 0.0), 5.0)
    full = int(value)
    half = 1 if (value - full) >= 0.5 else 0
    empty = 5 - full - half
    return "★" * full + "½" * half + "☆" * empty


def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
    """Append one row of user feedback to ``FEEDBACK_CSV``.

    The header is written only when the file does not exist yet.
    """
    file_exists = FEEDBACK_CSV.exists()
    with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FEEDBACK_FIELDS)
        if not file_exists:
            writer.writeheader()
        writer.writerow({
            "timestamp": datetime.now().isoformat(),
            "query": query,
            "mode": mode,
            "asin": asin,
            "title": title,
            "vote": vote,
        })


def render_product(ind, item, mode):
    """
    Render one product card with its reviews and 👍/👎 feedback buttons.

    Parameters
    ----------
    ind : int
        Zero-based rank of the product in the result list. It is shown as
        ``#ind+1`` and used to keep widget keys unique.
    item : Mapping
        Product record. Reads ``title``, ``average_rating``,
        ``rating_number``, ``parent_asin``, ``image``, ``price``,
        ``score``/``hybrid_score``, ``retrieval_source``, and either
        ``reviews`` or ``top_reviews`` (a list of dicts with ``title``,
        ``rating`` and ``text``).
    mode : str
        Retrieval mode label, used in the score badge, widget keys and the
        feedback log.
    """
    item = dict(item)
    if "reviews" in item:
        reviews = item.get("reviews") or []
    elif "top_reviews" in item:
        reviews = item.get("top_reviews") or []
    else:
        reviews = []

    title = item.get("title", "")
    avg_rating = _as_float(item.get("average_rating"))
    rating_number = int(_as_float(item.get("rating_number")))
    asin = item.get("parent_asin", "")
    n_reviews = len(reviews)
    review_word = "review" if n_reviews == 1 else "reviews"

    large_image = item.get("image", "")
    # Escape the URL so that a quote in the data cannot break out of the attribute.
    image_html = (
        f'<img src="{html.escape(str(large_image), quote=True)}" alt="" width="150">'
        if large_image else ""
    )

    price_val = _parse_price(item.get("price"))
    price_html = f"${price_val:.2f}" if price_val is not None else ""

    score = item["score"] if "score" in item else item.get("hybrid_score")
    # Test against None so that a legitimate 0.0 score still gets a badge.
    score_badge = f"{mode} score: {float(score):.2f}" if score is not None else ""
    source_badge = (
        f"Source: {item['retrieval_source']}" if "retrieval_source" in item else ""
    )
    price_suffix = "  " + price_html if price_html else ""

    # ── Product card header ───────────────────────────────────────────
    # NOTE(review): title and review text come from the dataset and are
    # rendered with unsafe_allow_html=True without escaping.
    st.markdown(
        f"\n{image_html}\n\n"
        f"#{ind + 1}   {title}\n\n"
        f"{stars(avg_rating)}  {avg_rating:.1f}/5 avg ({rating_number:,} ratings)    "
        f"{score_badge} {source_badge} {price_suffix}\n",
        unsafe_allow_html=True,
    )

    # ── Reviews in collapsible expander ───────────────────────────────
    expander_label = f"📖 Viewing top {n_reviews} {review_word} "
    with st.expander(expander_label, expanded=(n_reviews == 1)):
        for rev in reviews:
            text = rev["text"]
            ellipsis = "…" if len(text) > REVIEW_PREVIEW_CHARS else ""
            st.markdown(
                f"\n{rev['title']}  ·  {stars(rev['rating'])} {rev['rating']}/5  · \n\n"
                f"{text[:REVIEW_PREVIEW_CHARS]}{ellipsis}\n",
                unsafe_allow_html=True,
            )

    # ── Feedback buttons (per product) ────────────────────────────────
    # Results are only rendered while "last_query" is set, so this is the
    # query that produced them.
    current_query = st.session_state.get("last_query", "")
    col_up, col_dn, _ = st.columns([1, 1, 10])
    with col_up:
        if st.button("👍", key=f"up_{mode}_{asin}_{ind}"):
            log_feedback(current_query, mode, asin, title, "up")
            st.toast("Thanks! 👍")
    with col_dn:
        if st.button("👎", key=f"dn_{mode}_{asin}_{ind}"):
            log_feedback(current_query, mode, asin, title, "down")
            st.toast("Noted! 👎")
    # NOTE(review): this emits an effectively empty element; presumably meant
    # as a separator between cards — confirm.
    st.markdown("\n", unsafe_allow_html=True)


def render_results(results: list[dict], mode: str) -> None:
    """Render a list of product results, or an info box if it is empty."""
    if not results:
        st.info("No results returned.")
        return
    for ind, item in enumerate(results):
        render_product(ind, item, mode)


# ─── App layout ───────────────────────────────────────────────────────────────
# NOTE(review): this markdown payload is blank — confirm the intended header.
st.markdown(" ", unsafe_allow_html=True)

# ─── Search bar ───────────────────────────────────────────────────────────────
query = st.text_input(
    "Search for a product or describe what you're looking for",
    placeholder="e.g. something sweet for a cheese board...",
)

# ─── Run searches only when query changes ─────────────────────────────────────
if query.strip() and query != st.session_state.get("last_query"):
    st.session_state.last_query = query
    with st.spinner("Searching..."):
        st.session_state.bm25_results = bm25_search(query, top_k=TOP_K)
        st.session_state.semantic_results = semantic_search(query, top_k=TOP_K)
    with st.spinner("Asking AI..."):
        try:
            answer, docs, web_sources = llm_retriever(query, top_k=TOP_K)
            st.session_state.llm_result = answer
            st.session_state.llm_docs = docs
            st.session_state.web_sources = web_sources
        except Exception as e:  # UI boundary: surface any RAG failure to the user
            st.session_state.llm_result = f"**Error:** {e}"
            st.session_state.llm_docs = []
            st.session_state.web_sources = []
elif not query.strip():
    # Clear every per-query result (including LLM docs/sources) when the input is emptied.
    for key in RESULT_STATE_KEYS:
        st.session_state.pop(key, None)

# ─── Tabs ─────────────────────────────────────────────────────────────────────
tab_search, tab_llm = st.tabs(["🔍 Search", "🤖 AI Assistant"])

# ─── Search Tab ───────────────────────────────────────────────────────────────
with tab_search:
    mode = st.radio(
        "Search mode",
        options=["BM25", "Semantic"],
        index=0,
        horizontal=True,
        help="BM25 = keyword matching · Semantic = embedding similarity (all-MiniLM-L6-v2 + FAISS)",
    )

    if "last_query" not in st.session_state:
        st.markdown(
            "\n\nEnter a query above to see results.\n\n",
            unsafe_allow_html=True,
        )
    else:
        st.markdown(f"#### Top {TOP_K} results — {mode}")
        results = (
            st.session_state.bm25_results
            if mode == "BM25"
            else st.session_state.semantic_results
        )
        render_results(results, mode=mode.lower())

# ─── LLM Tab ──────────────────────────────────────────────────────────────────
with tab_llm:
    if "llm_result" not in st.session_state:
        st.markdown(
            "\n\nEnter a query above to get AI-powered recommendations.\n\n",
            unsafe_allow_html=True,
        )
    else:
        st.markdown(f"#### 🤖 AI Answer — *\"{st.session_state.last_query}\"*")
        st.caption("⚠️ AI responses may contain errors - please verify before relying on them.")
        html_response = markdown.markdown(
            st.session_state.llm_result,
            extensions=["tables", "fenced_code", "nl2br"],
        )
        st.markdown(f"\n{html_response}\n", unsafe_allow_html=True)

        st.markdown("#### 📦 Retrieved Products")
        docs = st.session_state.get("llm_docs", [])
        if docs:
            # JSON round-trip turns any non-JSON metadata values into strings.
            docs = [json.loads(json.dumps(obj.metadata, default=str)) for obj in docs]
            render_results(docs, mode="hybrid")
        else:
            st.markdown("\n\nNo documents retrieved.\n\n", unsafe_allow_html=True)

        # ── Web sources ───────────────────────────────────────────────
        sources = st.session_state.get("web_sources", [])
        if sources:
            st.markdown("#### 🌐 Web Sources")
            for s in sources:
                st.markdown(f"- [{s['title']}]({s['url']})")

# ─── Sidebar: feedback log ────────────────────────────────────────────────────
with st.sidebar:
    st.header("📋 Feedback Log")
    if FEEDBACK_CSV.exists():
        import pandas as pd

        df = pd.read_csv(FEEDBACK_CSV)
        st.dataframe(df.tail(20), use_container_width=True)
        st.download_button(
            "⬇️ Download feedback.csv",
            data=df.to_csv(index=False),
            file_name="feedback.csv",
            mime="text/csv",
        )
    else:
        st.info("No feedback yet — use 👍/👎 on results.")