Spaces:
Running
Running
github-actions[bot] committed on
Commit ·
251d75e
1
Parent(s): 845adc6
chore: sync app/ and src/ from GitHub
Browse files
- app/app.py +58 -8
- src/bm25.py +7 -0
- src/rag_pipeline.py +8 -1
- src/semantic.py +6 -0
- src/utils.py +2 -0
app/app.py
CHANGED
|
@@ -52,6 +52,14 @@ VECTOR_STORE_DIR = ROOT / "data" / "processed"
|
|
| 52 |
|
| 53 |
@st.cache_resource
|
| 54 |
def load_vector_store_cached():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
login(token=HF_TOKEN, add_to_git_credential=False)
|
| 56 |
VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
|
| 57 |
|
|
@@ -97,10 +105,17 @@ else:
|
|
| 97 |
|
| 98 |
def bm25_search(query: str, top_k: int = 3) -> list[dict]:
|
| 99 |
"""
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
"""
|
| 105 |
|
| 106 |
results = search(retriever, query, top_k)
|
|
@@ -109,10 +124,17 @@ def bm25_search(query: str, top_k: int = 3) -> list[dict]:
|
|
| 109 |
|
| 110 |
def semantic_search(query: str, top_k: int = 3) -> list[dict]:
|
| 111 |
"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
| 117 |
|
| 118 |
results = enrich_search_results(vector_store, query, top_k)
|
|
@@ -128,12 +150,37 @@ hybrid_retriever = HybridRetriever(
|
|
| 128 |
|
| 129 |
|
| 130 |
def llm_retriever(query: str, top_k: int = 5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
|
| 132 |
return answer, docs, web_sources
|
| 133 |
|
| 134 |
|
| 135 |
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
| 136 |
def stars(rating: float) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
full = int(rating)
|
| 138 |
half = 1 if (rating - full) >= 0.5 else 0
|
| 139 |
empty = 5 - full - half
|
|
@@ -141,6 +188,7 @@ def stars(rating: float) -> str:
|
|
| 141 |
|
| 142 |
|
| 143 |
def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
|
|
|
|
| 144 |
file_exists = FEEDBACK_CSV.exists()
|
| 145 |
with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
|
| 146 |
writer = csv.DictWriter(
|
|
@@ -158,6 +206,7 @@ def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> Non
|
|
| 158 |
})
|
| 159 |
|
| 160 |
def render_product(ind, item, mode):
|
|
|
|
| 161 |
item = dict(item)
|
| 162 |
if "reviews" in item.keys():
|
| 163 |
reviews = item.get("reviews",{})
|
|
@@ -240,6 +289,7 @@ def render_product(ind, item, mode):
|
|
| 240 |
|
| 241 |
|
| 242 |
def render_results(results: list[dict], mode: str) -> None:
|
|
|
|
| 243 |
if not results:
|
| 244 |
st.info("No results returned.")
|
| 245 |
return
|
|
|
|
| 52 |
|
| 53 |
@st.cache_resource
|
| 54 |
def load_vector_store_cached():
|
| 55 |
+
"""
|
| 56 |
+
Load vector store and BM25 index from Hugging Face or local cache.
|
| 57 |
+
|
| 58 |
+
Returns
|
| 59 |
+
-------
|
| 60 |
+
tuple
|
| 61 |
+
(vector_store, bm25_retriever)
|
| 62 |
+
"""
|
| 63 |
login(token=HF_TOKEN, add_to_git_credential=False)
|
| 64 |
VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
|
| 65 |
|
|
|
|
| 105 |
|
| 106 |
def bm25_search(query: str, top_k: int = 3) -> list[dict]:
|
| 107 |
"""
|
| 108 |
+
Run BM25 keyword search.
|
| 109 |
+
|
| 110 |
+
Parameters
|
| 111 |
+
----------
|
| 112 |
+
query : str
|
| 113 |
+
top_k : int
|
| 114 |
+
|
| 115 |
+
Returns
|
| 116 |
+
-------
|
| 117 |
+
list[dict]
|
| 118 |
+
Top-k retrieved results.
|
| 119 |
"""
|
| 120 |
|
| 121 |
results = search(retriever, query, top_k)
|
|
|
|
| 124 |
|
| 125 |
def semantic_search(query: str, top_k: int = 3) -> list[dict]:
|
| 126 |
"""
|
| 127 |
+
Run semantic (embedding-based) search.
|
| 128 |
+
|
| 129 |
+
Parameters
|
| 130 |
+
----------
|
| 131 |
+
query : str
|
| 132 |
+
top_k : int
|
| 133 |
+
|
| 134 |
+
Returns
|
| 135 |
+
-------
|
| 136 |
+
list[dict]
|
| 137 |
+
Top-k retrieved results with scores.
|
| 138 |
"""
|
| 139 |
|
| 140 |
results = enrich_search_results(vector_store, query, top_k)
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
def llm_retriever(query: str, top_k: int = 5):
|
| 153 |
+
"""
|
| 154 |
+
Run RAG pipeline using hybrid retriever.
|
| 155 |
+
|
| 156 |
+
Parameters
|
| 157 |
+
----------
|
| 158 |
+
query : str
|
| 159 |
+
top_k : int
|
| 160 |
+
|
| 161 |
+
Returns
|
| 162 |
+
-------
|
| 163 |
+
tuple
|
| 164 |
+
(answer, retrieved_docs, web_sources)
|
| 165 |
+
"""
|
| 166 |
answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
|
| 167 |
return answer, docs, web_sources
|
| 168 |
|
| 169 |
|
| 170 |
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
| 171 |
def stars(rating: float) -> str:
|
| 172 |
+
"""
|
| 173 |
+
Convert numeric rating into star string.
|
| 174 |
+
|
| 175 |
+
Parameters
|
| 176 |
+
----------
|
| 177 |
+
rating : float
|
| 178 |
+
|
| 179 |
+
Returns
|
| 180 |
+
-------
|
| 181 |
+
str
|
| 182 |
+
Star representation (e.g., ★★★★½).
|
| 183 |
+
"""
|
| 184 |
full = int(rating)
|
| 185 |
half = 1 if (rating - full) >= 0.5 else 0
|
| 186 |
empty = 5 - full - half
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
|
| 191 |
+
"""Append user feedback to CSV log."""
|
| 192 |
file_exists = FEEDBACK_CSV.exists()
|
| 193 |
with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
|
| 194 |
writer = csv.DictWriter(
|
|
|
|
| 206 |
})
|
| 207 |
|
| 208 |
def render_product(ind, item, mode):
|
| 209 |
+
"""Render a single product card with reviews and feedback buttons."""
|
| 210 |
item = dict(item)
|
| 211 |
if "reviews" in item.keys():
|
| 212 |
reviews = item.get("reviews",{})
|
|
|
|
| 289 |
|
| 290 |
|
| 291 |
def render_results(results: list[dict], mode: str) -> None:
|
| 292 |
+
"""Render a list of product results."""
|
| 293 |
if not results:
|
| 294 |
st.info("No results returned.")
|
| 295 |
return
|
src/bm25.py
CHANGED
|
@@ -365,6 +365,13 @@ def search(
|
|
| 365 |
query: str,
|
| 366 |
top_k: int = 3,
|
| 367 |
) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
retriever.k = top_k
|
| 369 |
|
| 370 |
# Tokenize query the same way the index was built
|
|
|
|
| 365 |
query: str,
|
| 366 |
top_k: int = 3,
|
| 367 |
) -> list[dict]:
|
| 368 |
+
"""
|
| 369 |
+
Search the BM25Retriever for a query, returning metadata of top-k results.
|
| 370 |
+
|
| 371 |
+
Performs a BM25 keyword search on the indexed documents. Tokenizes the query
|
| 372 |
+
using the same tokenizer as the index, computes BM25 scores for all documents,
|
| 373 |
+
and returns structured metadata (including score) for the top-k matches.
|
| 374 |
+
"""
|
| 375 |
retriever.k = top_k
|
| 376 |
|
| 377 |
# Tokenize query the same way the index was built
|
src/rag_pipeline.py
CHANGED
|
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
|
| 31 |
# ---------------------------------------------------------------------------
|
| 32 |
# Constants
|
| 33 |
# ---------------------------------------------------------------------------
|
| 34 |
-
DEFAULT_REPO_ID = "
|
| 35 |
DEFAULT_MAX_NEW_TOKENS = 512
|
| 36 |
DEFAULT_TOP_K = 5
|
| 37 |
|
|
@@ -97,7 +97,9 @@ def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
|
|
| 97 |
|
| 98 |
|
| 99 |
def _make_verbose_tap(label: str, verbose: bool):
|
|
|
|
| 100 |
def _tap(value):
|
|
|
|
| 101 |
if verbose:
|
| 102 |
if hasattr(value, "messages"):
|
| 103 |
rendered = "\n".join(
|
|
@@ -115,6 +117,7 @@ def _make_verbose_tap(label: str, verbose: bool):
|
|
| 115 |
|
| 116 |
|
| 117 |
def build_context(docs: list[Document]) -> str:
|
|
|
|
| 118 |
if not isinstance(docs, list):
|
| 119 |
raise TypeError(
|
| 120 |
f"'docs' must be a list of Document objects, got {type(docs).__name__}."
|
|
@@ -139,6 +142,7 @@ def _build_llm(
|
|
| 139 |
max_new_tokens: int,
|
| 140 |
provider: str,
|
| 141 |
) -> ChatHuggingFace:
|
|
|
|
| 142 |
endpoint = HuggingFaceEndpoint(
|
| 143 |
repo_id=repo_id,
|
| 144 |
task="text-generation",
|
|
@@ -149,6 +153,7 @@ def _build_llm(
|
|
| 149 |
|
| 150 |
|
| 151 |
def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
|
|
|
|
| 152 |
return ChatPromptTemplate.from_messages([
|
| 153 |
("system", system_prompt),
|
| 154 |
(
|
|
@@ -172,6 +177,7 @@ def run_rag(
|
|
| 172 |
provider: str = "auto",
|
| 173 |
verbose: bool = False,
|
| 174 |
) -> tuple[str, list[Document]]:
|
|
|
|
| 175 |
# ------------------------------------------------------------------
|
| 176 |
# Build chain components
|
| 177 |
# ------------------------------------------------------------------
|
|
@@ -184,6 +190,7 @@ def run_rag(
|
|
| 184 |
retrieved_docs: list[Document] = []
|
| 185 |
|
| 186 |
def _retrieve_and_capture(query: str) -> list[Document]:
|
|
|
|
| 187 |
docs = retriever.invoke(query)
|
| 188 |
retrieved_docs.extend(docs)
|
| 189 |
return docs
|
|
|
|
| 31 |
# ---------------------------------------------------------------------------
|
| 32 |
# Constants
|
| 33 |
# ---------------------------------------------------------------------------
|
| 34 |
+
DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
|
| 35 |
DEFAULT_MAX_NEW_TOKENS = 512
|
| 36 |
DEFAULT_TOP_K = 5
|
| 37 |
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
def _make_verbose_tap(label: str, verbose: bool):
|
| 100 |
+
"""Returns a Runnable that prints the value with a label if verbose=True, then passes it through unchanged."""
|
| 101 |
def _tap(value):
|
| 102 |
+
"""Prints the value with a label if verbose=True, then returns it unchanged."""
|
| 103 |
if verbose:
|
| 104 |
if hasattr(value, "messages"):
|
| 105 |
rendered = "\n".join(
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
def build_context(docs: list[Document]) -> str:
|
| 120 |
+
"""Converts a list of Documents into a single string context for the LLM."""
|
| 121 |
if not isinstance(docs, list):
|
| 122 |
raise TypeError(
|
| 123 |
f"'docs' must be a list of Document objects, got {type(docs).__name__}."
|
|
|
|
| 142 |
max_new_tokens: int,
|
| 143 |
provider: str,
|
| 144 |
) -> ChatHuggingFace:
|
| 145 |
+
"""Initializes a HuggingFaceEndpoint and wraps it in a ChatHuggingFace LLM."""
|
| 146 |
endpoint = HuggingFaceEndpoint(
|
| 147 |
repo_id=repo_id,
|
| 148 |
task="text-generation",
|
|
|
|
| 153 |
|
| 154 |
|
| 155 |
def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
|
| 156 |
+
"""Constructs a ChatPromptTemplate with the given system prompt and a fixed human prompt."""
|
| 157 |
return ChatPromptTemplate.from_messages([
|
| 158 |
("system", system_prompt),
|
| 159 |
(
|
|
|
|
| 177 |
provider: str = "auto",
|
| 178 |
verbose: bool = False,
|
| 179 |
) -> tuple[str, list[Document]]:
|
| 180 |
+
"""Runs a Retrieval-Augmented Generation (RAG) chain for a grocery query."""
|
| 181 |
# ------------------------------------------------------------------
|
| 182 |
# Build chain components
|
| 183 |
# ------------------------------------------------------------------
|
|
|
|
| 190 |
retrieved_docs: list[Document] = []
|
| 191 |
|
| 192 |
def _retrieve_and_capture(query: str) -> list[Document]:
|
| 193 |
+
"""Invokes the retriever and captures the retrieved documents for later use."""
|
| 194 |
docs = retriever.invoke(query)
|
| 195 |
retrieved_docs.extend(docs)
|
| 196 |
return docs
|
src/semantic.py
CHANGED
|
@@ -48,6 +48,7 @@ DEFAULT_TOP_K = 5
|
|
| 48 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 49 |
@st.cache_resource(show_spinner=False)
|
| 50 |
def get_embeddings():
|
|
|
|
| 51 |
return HuggingFaceEmbeddings(
|
| 52 |
model_name=DEFAULT_EMBEDDING_MODEL,
|
| 53 |
model_kwargs={
|
|
@@ -207,6 +208,10 @@ def build_and_save_vector_store(
|
|
| 207 |
save_path: str,
|
| 208 |
batch_size: int = 500,
|
| 209 |
) -> FAISS:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
# --- Resume / initialize ---
|
| 212 |
if os.path.exists(os.path.join(save_path, "index.faiss")):
|
|
@@ -297,6 +302,7 @@ def enrich_search_results(vector_store, query: str, k: int, filter=None):
|
|
| 297 |
def load_vector_store(
|
| 298 |
load_path: str,
|
| 299 |
) -> FAISS:
|
|
|
|
| 300 |
|
| 301 |
return FAISS.load_local(
|
| 302 |
load_path,
|
|
|
|
| 48 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 49 |
@st.cache_resource(show_spinner=False)
|
| 50 |
def get_embeddings():
|
| 51 |
+
"""Initializes and returns a HuggingFaceEmbeddings instance with the specified model and device settings."""
|
| 52 |
return HuggingFaceEmbeddings(
|
| 53 |
model_name=DEFAULT_EMBEDDING_MODEL,
|
| 54 |
model_kwargs={
|
|
|
|
| 208 |
save_path: str,
|
| 209 |
batch_size: int = 500,
|
| 210 |
) -> FAISS:
|
| 211 |
+
"""
|
| 212 |
+
Build a FAISS vector store from a metadata Dataset, processing in batches and saving progress.
|
| 213 |
+
This function processes the metadata dataset in batches, creating Documents and embedding them into a FAISS vector store.
|
| 214 |
+
"""
|
| 215 |
|
| 216 |
# --- Resume / initialize ---
|
| 217 |
if os.path.exists(os.path.join(save_path, "index.faiss")):
|
|
|
|
| 302 |
def load_vector_store(
|
| 303 |
load_path: str,
|
| 304 |
) -> FAISS:
|
| 305 |
+
"""Load a FAISS vector store from disk."""
|
| 306 |
|
| 307 |
return FAISS.load_local(
|
| 308 |
load_path,
|
src/utils.py
CHANGED
|
@@ -11,6 +11,7 @@ STOPWORDS = set(stopwords.words('english'))
|
|
| 11 |
|
| 12 |
# Tokenizer
|
| 13 |
def simple_tokenize(text):
|
|
|
|
| 14 |
if not text:
|
| 15 |
return []
|
| 16 |
text = text.lower()
|
|
@@ -49,6 +50,7 @@ def extract_image(row):
|
|
| 49 |
return None
|
| 50 |
|
| 51 |
def decode_ratings(page_content):
|
|
|
|
| 52 |
block_pattern = r'\[\d\.0★\].*'
|
| 53 |
matches = re.findall(block_pattern, page_content)
|
| 54 |
if matches:
|
|
|
|
| 11 |
|
| 12 |
# Tokenizer
|
| 13 |
def simple_tokenize(text):
|
| 14 |
+
"""A simple tokenizer that lowercases text, removes punctuation, and filters out stopwords."""
|
| 15 |
if not text:
|
| 16 |
return []
|
| 17 |
text = text.lower()
|
|
|
|
| 50 |
return None
|
| 51 |
|
| 52 |
def decode_ratings(page_content):
|
| 53 |
+
"""Extracts up to 3 ratings from the page content string, returning a list of dicts with rating, title, and text."""
|
| 54 |
block_pattern = r'\[\d\.0★\].*'
|
| 55 |
matches = re.findall(block_pattern, page_content)
|
| 56 |
if matches:
|