github-actions[bot] commited on
Commit
251d75e
·
1 Parent(s): 845adc6

chore: sync app/ and src/ from GitHub

Browse files
Files changed (5) hide show
  1. app/app.py +58 -8
  2. src/bm25.py +7 -0
  3. src/rag_pipeline.py +8 -1
  4. src/semantic.py +6 -0
  5. src/utils.py +2 -0
app/app.py CHANGED
@@ -52,6 +52,14 @@ VECTOR_STORE_DIR = ROOT / "data" / "processed"
52
 
53
  @st.cache_resource
54
  def load_vector_store_cached():
 
 
 
 
 
 
 
 
55
  login(token=HF_TOKEN, add_to_git_credential=False)
56
  VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
57
 
@@ -97,10 +105,17 @@ else:
97
 
98
  def bm25_search(query: str, top_k: int = 3) -> list[dict]:
99
  """
100
- PLACEHOLDER swap with real BM25Retriever call, e.g.:
101
- retriever = BM25Retriever.load('data/processed/bm25_index.pkl')
102
- return retriever.search(query, top_k=top_k)
103
- Returns top_k review-level results (may include multiple reviews per ASIN).
 
 
 
 
 
 
 
104
  """
105
 
106
  results = search(retriever, query, top_k)
@@ -109,10 +124,17 @@ def bm25_search(query: str, top_k: int = 3) -> list[dict]:
109
 
110
  def semantic_search(query: str, top_k: int = 3) -> list[dict]:
111
  """
112
- PLACEHOLDER swap with real SemanticRetriever call, e.g.:
113
- retriever = SemanticRetriever.load('data/processed/faiss_index')
114
- return retriever.search(query, top_k=top_k)
115
- Returns top_k review-level results (scores are cosine similarities, 0–1).
 
 
 
 
 
 
 
116
  """
117
 
118
  results = enrich_search_results(vector_store, query, top_k)
@@ -128,12 +150,37 @@ hybrid_retriever = HybridRetriever(
128
 
129
 
130
  def llm_retriever(query: str, top_k: int = 5):
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
132
  return answer, docs, web_sources
133
 
134
 
135
  # ─── Helpers ──────────────────────────────────────────────────────────────────
136
  def stars(rating: float) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
137
  full = int(rating)
138
  half = 1 if (rating - full) >= 0.5 else 0
139
  empty = 5 - full - half
@@ -141,6 +188,7 @@ def stars(rating: float) -> str:
141
 
142
 
143
  def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
 
144
  file_exists = FEEDBACK_CSV.exists()
145
  with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
146
  writer = csv.DictWriter(
@@ -158,6 +206,7 @@ def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> Non
158
  })
159
 
160
  def render_product(ind, item, mode):
 
161
  item = dict(item)
162
  if "reviews" in item.keys():
163
  reviews = item.get("reviews",{})
@@ -240,6 +289,7 @@ def render_product(ind, item, mode):
240
 
241
 
242
  def render_results(results: list[dict], mode: str) -> None:
 
243
  if not results:
244
  st.info("No results returned.")
245
  return
 
52
 
53
  @st.cache_resource
54
  def load_vector_store_cached():
55
+ """
56
+ Load vector store and BM25 index from Hugging Face or local cache.
57
+
58
+ Returns
59
+ -------
60
+ tuple
61
+ (vector_store, bm25_retriever)
62
+ """
63
  login(token=HF_TOKEN, add_to_git_credential=False)
64
  VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
65
 
 
105
 
106
  def bm25_search(query: str, top_k: int = 3) -> list[dict]:
107
  """
108
+ Run BM25 keyword search.
109
+
110
+ Parameters
111
+ ----------
112
+ query : str
113
+ top_k : int
114
+
115
+ Returns
116
+ -------
117
+ list[dict]
118
+ Top-k retrieved results.
119
  """
120
 
121
  results = search(retriever, query, top_k)
 
124
 
125
  def semantic_search(query: str, top_k: int = 3) -> list[dict]:
126
  """
127
+ Run semantic (embedding-based) search.
128
+
129
+ Parameters
130
+ ----------
131
+ query : str
132
+ top_k : int
133
+
134
+ Returns
135
+ -------
136
+ list[dict]
137
+ Top-k retrieved results with scores.
138
  """
139
 
140
  results = enrich_search_results(vector_store, query, top_k)
 
150
 
151
 
152
  def llm_retriever(query: str, top_k: int = 5):
153
+ """
154
+ Run RAG pipeline using hybrid retriever.
155
+
156
+ Parameters
157
+ ----------
158
+ query : str
159
+ top_k : int
160
+
161
+ Returns
162
+ -------
163
+ tuple
164
+ (answer, retrieved_docs, web_sources)
165
+ """
166
  answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
167
  return answer, docs, web_sources
168
 
169
 
170
  # ─── Helpers ──────────────────────────────────────────────────────────────────
171
  def stars(rating: float) -> str:
172
+ """
173
+ Convert numeric rating into star string.
174
+
175
+ Parameters
176
+ ----------
177
+ rating : float
178
+
179
+ Returns
180
+ -------
181
+ str
182
+ Star representation (e.g., ★★★★½).
183
+ """
184
  full = int(rating)
185
  half = 1 if (rating - full) >= 0.5 else 0
186
  empty = 5 - full - half
 
188
 
189
 
190
  def log_feedback(query: str, mode: str, asin: str, title: str, vote: str) -> None:
191
+ """Append user feedback to CSV log."""
192
  file_exists = FEEDBACK_CSV.exists()
193
  with open(FEEDBACK_CSV, "a", newline="", encoding="utf-8") as f:
194
  writer = csv.DictWriter(
 
206
  })
207
 
208
  def render_product(ind, item, mode):
209
+ """Render a single product card with reviews and feedback buttons."""
210
  item = dict(item)
211
  if "reviews" in item.keys():
212
  reviews = item.get("reviews",{})
 
289
 
290
 
291
  def render_results(results: list[dict], mode: str) -> None:
292
+ """Render a list of product results."""
293
  if not results:
294
  st.info("No results returned.")
295
  return
src/bm25.py CHANGED
@@ -365,6 +365,13 @@ def search(
365
  query: str,
366
  top_k: int = 3,
367
  ) -> list[dict]:
 
 
 
 
 
 
 
368
  retriever.k = top_k
369
 
370
  # Tokenize query the same way the index was built
 
365
  query: str,
366
  top_k: int = 3,
367
  ) -> list[dict]:
368
+ """
369
+ Search the BM25Retriever for a query, returning metadata of top-k results.
370
+
371
+ Performs a BM25 keyword search on the indexed documents. Tokenizes the query
372
+ using the same tokenizer as the index, computes BM25 scores for all documents,
373
+ and returns structured metadata (including score) for the top-k matches.
374
+ """
375
  retriever.k = top_k
376
 
377
  # Tokenize query the same way the index was built
src/rag_pipeline.py CHANGED
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
31
  # ---------------------------------------------------------------------------
32
  # Constants
33
  # ---------------------------------------------------------------------------
34
- DEFAULT_REPO_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
35
  DEFAULT_MAX_NEW_TOKENS = 512
36
  DEFAULT_TOP_K = 5
37
 
@@ -97,7 +97,9 @@ def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
97
 
98
 
99
  def _make_verbose_tap(label: str, verbose: bool):
 
100
  def _tap(value):
 
101
  if verbose:
102
  if hasattr(value, "messages"):
103
  rendered = "\n".join(
@@ -115,6 +117,7 @@ def _make_verbose_tap(label: str, verbose: bool):
115
 
116
 
117
  def build_context(docs: list[Document]) -> str:
 
118
  if not isinstance(docs, list):
119
  raise TypeError(
120
  f"'docs' must be a list of Document objects, got {type(docs).__name__}."
@@ -139,6 +142,7 @@ def _build_llm(
139
  max_new_tokens: int,
140
  provider: str,
141
  ) -> ChatHuggingFace:
 
142
  endpoint = HuggingFaceEndpoint(
143
  repo_id=repo_id,
144
  task="text-generation",
@@ -149,6 +153,7 @@ def _build_llm(
149
 
150
 
151
  def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
 
152
  return ChatPromptTemplate.from_messages([
153
  ("system", system_prompt),
154
  (
@@ -172,6 +177,7 @@ def run_rag(
172
  provider: str = "auto",
173
  verbose: bool = False,
174
  ) -> tuple[str, list[Document]]:
 
175
  # ------------------------------------------------------------------
176
  # Build chain components
177
  # ------------------------------------------------------------------
@@ -184,6 +190,7 @@ def run_rag(
184
  retrieved_docs: list[Document] = []
185
 
186
  def _retrieve_and_capture(query: str) -> list[Document]:
 
187
  docs = retriever.invoke(query)
188
  retrieved_docs.extend(docs)
189
  return docs
 
31
  # ---------------------------------------------------------------------------
32
  # Constants
33
  # ---------------------------------------------------------------------------
34
+ DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
35
  DEFAULT_MAX_NEW_TOKENS = 512
36
  DEFAULT_TOP_K = 5
37
 
 
97
 
98
 
99
  def _make_verbose_tap(label: str, verbose: bool):
100
+ """Returns a Runnable that prints the value with a label if verbose=True, then passes it through unchanged."""
101
  def _tap(value):
102
+ """Prints the value with a label if verbose=True, then returns it unchanged."""
103
  if verbose:
104
  if hasattr(value, "messages"):
105
  rendered = "\n".join(
 
117
 
118
 
119
  def build_context(docs: list[Document]) -> str:
120
+ """Converts a list of Documents into a single string context for the LLM."""
121
  if not isinstance(docs, list):
122
  raise TypeError(
123
  f"'docs' must be a list of Document objects, got {type(docs).__name__}."
 
142
  max_new_tokens: int,
143
  provider: str,
144
  ) -> ChatHuggingFace:
145
+ """Initializes a HuggingFaceEndpoint and wraps it in a ChatHuggingFace LLM."""
146
  endpoint = HuggingFaceEndpoint(
147
  repo_id=repo_id,
148
  task="text-generation",
 
153
 
154
 
155
  def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
156
+ """Constructs a ChatPromptTemplate with the given system prompt and a fixed human prompt."""
157
  return ChatPromptTemplate.from_messages([
158
  ("system", system_prompt),
159
  (
 
177
  provider: str = "auto",
178
  verbose: bool = False,
179
  ) -> tuple[str, list[Document]]:
180
+ """Runs a Retrieval-Augmented Generation (RAG) chain for a grocery query."""
181
  # ------------------------------------------------------------------
182
  # Build chain components
183
  # ------------------------------------------------------------------
 
190
  retrieved_docs: list[Document] = []
191
 
192
  def _retrieve_and_capture(query: str) -> list[Document]:
193
+ """Invokes the retriever and captures the retrieved documents for later use."""
194
  docs = retriever.invoke(query)
195
  retrieved_docs.extend(docs)
196
  return docs
src/semantic.py CHANGED
@@ -48,6 +48,7 @@ DEFAULT_TOP_K = 5
48
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
  @st.cache_resource(show_spinner=False)
50
  def get_embeddings():
 
51
  return HuggingFaceEmbeddings(
52
  model_name=DEFAULT_EMBEDDING_MODEL,
53
  model_kwargs={
@@ -207,6 +208,10 @@ def build_and_save_vector_store(
207
  save_path: str,
208
  batch_size: int = 500,
209
  ) -> FAISS:
 
 
 
 
210
 
211
  # --- Resume / initialize ---
212
  if os.path.exists(os.path.join(save_path, "index.faiss")):
@@ -297,6 +302,7 @@ def enrich_search_results(vector_store, query: str, k: int, filter=None):
297
  def load_vector_store(
298
  load_path: str,
299
  ) -> FAISS:
 
300
 
301
  return FAISS.load_local(
302
  load_path,
 
48
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
  @st.cache_resource(show_spinner=False)
50
  def get_embeddings():
51
+ """Initializes and returns a HuggingFaceEmbeddings instance with the specified model and device settings."""
52
  return HuggingFaceEmbeddings(
53
  model_name=DEFAULT_EMBEDDING_MODEL,
54
  model_kwargs={
 
208
  save_path: str,
209
  batch_size: int = 500,
210
  ) -> FAISS:
211
+ """
212
+ Build a FAISS vector store from a metadata Dataset, processing in batches and saving progress.
213
+ This function processes the metadata dataset in batches, creating Documents and embedding them into a FAISS vector store.
214
+ """
215
 
216
  # --- Resume / initialize ---
217
  if os.path.exists(os.path.join(save_path, "index.faiss")):
 
302
  def load_vector_store(
303
  load_path: str,
304
  ) -> FAISS:
305
+ """Load a FAISS vector store from disk."""
306
 
307
  return FAISS.load_local(
308
  load_path,
src/utils.py CHANGED
@@ -11,6 +11,7 @@ STOPWORDS = set(stopwords.words('english'))
11
 
12
  # Tokenizer
13
  def simple_tokenize(text):
 
14
  if not text:
15
  return []
16
  text = text.lower()
@@ -49,6 +50,7 @@ def extract_image(row):
49
  return None
50
 
51
  def decode_ratings(page_content):
 
52
  block_pattern = r'\[\d\.0★\].*'
53
  matches = re.findall(block_pattern, page_content)
54
  if matches:
 
11
 
12
  # Tokenizer
13
  def simple_tokenize(text):
14
+ """A simple tokenizer that lowercases text, removes punctuation, and filters out stopwords."""
15
  if not text:
16
  return []
17
  text = text.lower()
 
50
  return None
51
 
52
  def decode_ratings(page_content):
53
+ """Extracts up to 3 ratings from the page content string, returning a list of dicts with rating, title, and text."""
54
  block_pattern = r'\[\d\.0★\].*'
55
  matches = re.findall(block_pattern, page_content)
56
  if matches: