import json
import csv, sys
from datetime import datetime
from pathlib import Path
import streamlit as st
import markdown
ROOT_FOLDER = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT_FOLDER))
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from src.semantic import load_vector_store, enrich_search_results
from src.rag_pipeline import run_rag
from src.bm25 import load, search
from src.hybrid import HybridRetriever
from dotenv import load_dotenv
load_dotenv()
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
# ─── Page config (must be first Streamlit call) ───────────────────────────────
# Global page settings for the app: browser-tab title and icon, a full-width
# layout, and a sidebar that starts folded away.
st.set_page_config(
    page_title="Groceries & Gourmet Food Search",
    page_icon="🥕",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# ─── Paths ────────────────────────────────────────────────────────────────────
# Repository root: two directory levels above this script.
ROOT = Path(__file__).resolve().parent.parent

# CSV file that log_feedback() appends to; its folder is created at import
# time so later writes do not have to.
FEEDBACK_CSV = ROOT.joinpath("results", "feedback.csv")
FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)

# Result count handed to the hybrid retriever and shown in result headings.
TOP_K = 5

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get('HF_TOKEN')
from huggingface_hub import snapshot_download, login
# ─── Custom CSS ───────────────────────────────────────────────────────────────
# Load the app stylesheet. The path is anchored at ROOT (equivalent to
# './app/styles.css' when launched from the repository root) so the file is
# found regardless of the current working directory.
with open(ROOT / "app" / "styles.css", "r", encoding="utf-8") as f:
    css = f.read()

# The stylesheet has to be wrapped in a <style> element to take effect;
# unsafe_allow_html=True lets Streamlit pass the raw tag through unescaped.
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
VECTOR_STORE_DIR = ROOT / "data" / "processed"


@st.cache_resource
def load_vector_store_cached():
    """
    Load the vector store and BM25 index, downloading them when missing.

    The artifacts are expected under ``VECTOR_STORE_DIR`` as
    ``tokenisation/bm25_index.pkl`` and ``embeddings/``. When either is
    absent, the ``rishadaz/amazon_retriever-storage`` dataset snapshot is
    downloaded from Hugging Face into ``VECTOR_STORE_DIR`` first.

    The result is cached by ``st.cache_resource``.

    Returns
    -------
    tuple
        (vector_store, bm25_retriever)
    """
    # login() without a token falls back to an interactive prompt, which
    # would block a headless Streamlit server; only log in when a token is
    # actually configured. snapshot_download receives the token directly.
    if HF_TOKEN:
        login(token=HF_TOKEN, add_to_git_credential=False)

    VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)

    snapshot_path = VECTOR_STORE_DIR
    index_path = snapshot_path / "tokenisation" / "bm25_index.pkl"
    embeddings_dir = snapshot_path / "embeddings"

    # Check for the artifacts that are actually loaded below rather than for
    # "directory is non-empty": other files in the folder (e.g. the mini
    # index) must not suppress the download of the full index.
    if not (index_path.exists() and embeddings_dir.exists()):
        snapshot_path = Path(snapshot_download(
            repo_id="rishadaz/amazon_retriever-storage",
            repo_type="dataset",
            local_dir=str(VECTOR_STORE_DIR),
            token=HF_TOKEN,
        ))
        index_path = snapshot_path / "tokenisation" / "bm25_index.pkl"
        embeddings_dir = snapshot_path / "embeddings"

    vector_store = load_vector_store(embeddings_dir)
    bm25_retriever = load(index_path)
    return vector_store, bm25_retriever
# ─── Get Data ──────────────────────────────────────────────────────────────
# The DATA_SOURCE environment variable selects which corpus backs the app:
#   "local"        -> read the mini index and the embeddings from this
#                     checkout's data/processed folder.
#   anything else  -> go through load_vector_store_cached(), which uses the
#                     full corpus and embeddings.
# NOTE: per the original authors, the full remote corpus can take a long time
# to download and may make the app heavy and slow; prefer "local" for
# development.
data_source = os.getenv('DATA_SOURCE')
print(f"Running with data source {data_source}")

if data_source != 'local':
    vector_store, retriever = load_vector_store_cached()
else:
    MINI_INDEX_PATH = ROOT / "data" / "processed" / "tokenisation" / "bm25_index_mini.pkl"
    vector_store = load_vector_store(ROOT_FOLDER / 'data' / 'processed' / 'embeddings')
    retriever = load(MINI_INDEX_PATH)
def bm25_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Look up ``query`` in the module-level BM25 index.

    Thin wrapper that forwards to ``search`` with the global ``retriever``.

    Parameters
    ----------
    query : str
        Search text entered by the user.
    top_k : int
        Number of hits requested from ``search``.

    Returns
    -------
    list[dict]
        Whatever ``search`` returns for the top-k hits.
    """
    return search(retriever, query, top_k)
def semantic_search(query: str, top_k: int = 3) -> list[dict]:
    """
    Look up ``query`` in the module-level vector store.

    Thin wrapper that forwards to ``enrich_search_results`` with the global
    ``vector_store``.

    Parameters
    ----------
    query : str
        Search text entered by the user.
    top_k : int
        Number of hits requested from ``enrich_search_results``.

    Returns
    -------
    list[dict]
        Whatever ``enrich_search_results`` returns for the top-k hits.
    """
    return enrich_search_results(vector_store, query, top_k)
# Shared retriever that combines the BM25 index with the vector store.
# Both sources get the same weight (0.5 each) and TOP_K is passed as ``k``.
hybrid_retriever = HybridRetriever(
    bm25_retriever=retriever,
    semantic_store=vector_store,
    k=TOP_K,
    bm25_weight=0.5,
    semantic_weight=0.5,
)
def llm_retriever(query: str, top_k: int = 5):
    """
    Answer ``query`` through the RAG pipeline backed by ``hybrid_retriever``.

    Parameters
    ----------
    query : str
        Question or search text entered by the user.
    top_k : int
        Accepted for interface compatibility only: it is not forwarded to
        ``run_rag``. The hybrid retriever is built with ``k=TOP_K``.

    Returns
    -------
    tuple
        (answer, retrieved_docs, web_sources)
    """
    rag_output = run_rag(hybrid_retriever, query=query)
    # Unpack explicitly so exactly three values come back, as a plain tuple.
    answer, docs, web_sources = rag_output
    return (answer, docs, web_sources)
# ─── Helpers ──────────────────────────────────────────────────────────────────
def stars(rating: float) -> str:
    """
    Convert a numeric rating into a star string.

    The rating is clamped to the 0-5 range so the result never shows more
    than five filled stars. A fractional part of 0.5 or more is shown as
    ``½``; the rest is padded with ``☆``. In-range ratings render exactly as
    before.

    Parameters
    ----------
    rating : float
        Rating on a 0-5 scale. Missing or non-numeric values (``None``, NaN,
        unparsable strings) are treated as 0 instead of raising.

    Returns
    -------
    str
        Star representation (e.g., ★★★★½).
    """
    import math

    try:
        value = float(rating)
    except (TypeError, ValueError):
        value = 0.0
    if math.isnan(value):
        value = 0.0

    # Out-of-range input used to yield too many glyphs (5.7 -> six symbols,
    # -1.5 -> six empty stars); clamp before splitting into parts.
    value = min(max(value, 0.0), 5.0)

    full = int(value)
    half = 1 if (value - full) >= 0.5 else 0
    empty = 5 - full - half
    return "★" * full + "½" * half + "☆" * empty
def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
    """
    Append one user-feedback record to the CSV log at ``FEEDBACK_CSV``.

    A header row is written first whenever the log is missing or empty, so
    the file always starts with the column names.

    Parameters
    ----------
    query : str
        Query the feedback refers to.
    mode : str
        Search mode the result came from.
    asin : str
        Identifier of the product that was voted on.
    title : str
        Title of that product.
    vote : str
        The vote value to record.
    """
    fieldnames = ["timestamp", "query", "mode", "asin", "title", "vote"]

    # The folder is created at import time, but recreate it if it was removed
    # while the app was running.
    FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)

    # Decide on the header by file size rather than mere existence: an
    # existing but empty file still needs its header row.
    try:
        needs_header = FEEDBACK_CSV.stat().st_size == 0
    except FileNotFoundError:
        needs_header = True

    with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow({
            "timestamp": datetime.now().isoformat(),
            "query": query,
            "mode": mode,
            "asin": asin,
            "title": title,
            "vote": vote,
        })
def render_product(ind, item, mode):
"""Render a single product card with reviews and feedback buttons."""
item = dict(item)
if "reviews" in item.keys():
reviews = item.get("reviews",{})
elif "top_reviews" in item.keys():
reviews = item.get("top_reviews",{})
else:
reviews = []
title = item.get("title","")
avg_rating = item["average_rating"]
n_reviews = len(reviews)
# total_reviews = item.get('total_reviews', n_reviews)
rating_number = item.get('rating_number', 0)
asin = item['parent_asin']
review_word = "review" if n_reviews == 1 else "reviews"
large_image = item.get('image', "")
image_html = f'' if large_image else f'
Enter a query above to see results.
", unsafe_allow_html=True, ) else: st.markdown(f"#### Top {TOP_K} results — {mode}") results = ( st.session_state.bm25_results if mode == "BM25" else st.session_state.semantic_results ) render_results(results, mode=mode.lower()) # ─── LLM Tab ────────────────────────────────────────────────────────────────── with tab_llm: if "llm_result" not in st.session_state: st.markdown( "Enter a query above to get AI-powered recommendations.
", unsafe_allow_html=True, ) else: st.markdown(f"#### 🤖 AI Answer — *\"{st.session_state.last_query}\"*") st.caption("⚠️ AI responses may contain errors - please verify before relying on them.") html_response = markdown.markdown( st.session_state.llm_result, extensions=["tables", "fenced_code", "nl2br"], ) st.markdown( f"No documents retrieved.
", unsafe_allow_html=True) # ── Web sources ─────────────────────────────────────────────────────── sources = st.session_state.get("web_sources", []) if sources: st.markdown("#### 🌐 Web Sources") for s in sources: st.markdown(f"- [{s['title']}]({s['url']})") # ─── Sidebar: feedback log ──────────────────────────────────────────────────── with st.sidebar: st.header("📋 Feedback Log") if FEEDBACK_CSV.exists(): import pandas as pd df = pd.read_csv(FEEDBACK_CSV) st.dataframe(df.tail(20), use_container_width=True) st.download_button( "⬇️ Download feedback.csv", data=df.to_csv(index=False), file_name="feedback.csv", mime="text/csv", ) else: st.info("No feedback yet — use 👍/👎 on results.")