github-actions[bot] committed on
Commit
dff5f2a
·
1 Parent(s): 2ef07e4

chore: sync app/ and src/ from GitHub

Browse files
Files changed (3) hide show
  1. app/app.py +13 -3
  2. src/rag_pipeline.py +47 -130
  3. src/tools.py +17 -0
app/app.py CHANGED
@@ -126,9 +126,10 @@ hybrid_retriever = HybridRetriever(
126
  semantic_weight=0.5,
127
  )
128
 
 
129
  def llm_retriever(query: str, top_k: int = 5):
130
- answer, docs = run_rag(hybrid_retriever, query=query)
131
- return answer, docs
132
 
133
 
134
  # ─── Helpers ──────────────────────────────────────────────────────────────────
@@ -272,12 +273,14 @@ if query.strip() and query != st.session_state.get("last_query"):
272
 
273
  with st.spinner("Asking AI..."):
274
  try:
275
- answer, docs = llm_retriever(query, top_k=TOP_K)
276
  st.session_state.llm_result = answer
277
  st.session_state.llm_docs = docs
 
278
  except Exception as e:
279
  st.session_state.llm_result = f"**Error:** {e}"
280
  st.session_state.llm_docs = []
 
281
 
282
  elif not query.strip():
283
  # Clear results when input is emptied
@@ -338,6 +341,13 @@ with tab_llm:
338
  else:
339
  st.markdown("<p style='color:#aaa;'>No documents retrieved.</p>", unsafe_allow_html=True)
340
 
 
 
 
 
 
 
 
341
  # ─── Sidebar: feedback log ────────────────────────────────────────────────────
342
  with st.sidebar:
343
  st.header("📋 Feedback Log")
 
126
  semantic_weight=0.5,
127
  )
128
 
129
+
130
  def llm_retriever(query: str, top_k: int = 5):
131
+ answer, docs, web_sources = run_rag(hybrid_retriever, query=query)
132
+ return answer, docs, web_sources
133
 
134
 
135
  # ─── Helpers ──────────────────────────────────────────────────────────────────
 
273
 
274
  with st.spinner("Asking AI..."):
275
  try:
276
+ answer, docs, web_sources = llm_retriever(query, top_k=TOP_K)
277
  st.session_state.llm_result = answer
278
  st.session_state.llm_docs = docs
279
+ st.session_state.web_sources = web_sources
280
  except Exception as e:
281
  st.session_state.llm_result = f"**Error:** {e}"
282
  st.session_state.llm_docs = []
283
+ st.session_state.web_sources = []
284
 
285
  elif not query.strip():
286
  # Clear results when input is emptied
 
341
  else:
342
  st.markdown("<p style='color:#aaa;'>No documents retrieved.</p>", unsafe_allow_html=True)
343
 
344
+ # ── Web sources ───────────────────────────────────────────────────────
345
+ sources = st.session_state.get("web_sources", [])
346
+ if sources:
347
+ st.markdown("#### 🌐 Web Sources")
348
+ for s in sources:
349
+ st.markdown(f"- [{s['title']}]({s['url']})")
350
+
351
  # ─── Sidebar: feedback log ────────────────────────────────────────────────────
352
  with st.sidebar:
353
  st.header("📋 Feedback Log")
src/rag_pipeline.py CHANGED
@@ -20,7 +20,9 @@ from langchain_core.documents import Document
20
  from langchain_core.output_parsers import StrOutputParser
21
  from langchain_core.prompts import ChatPromptTemplate
22
  from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 
23
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
24
  # ---------------------------------------------------------------------------
25
  # Logging
26
  # ---------------------------------------------------------------------------
@@ -36,6 +38,12 @@ DEFAULT_TOP_K = 5
36
  DEFAULT_SYSTEM_PROMPT = (
37
  "You are a helpful Amazon grocery shopping assistant.\n\n"
38
  "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
 
 
 
 
 
 
39
  "Your response must follow this exact structure:\n\n"
40
  "---\n\n"
41
  "## 🛒 Recommended Products\n"
@@ -51,65 +59,62 @@ DEFAULT_SYSTEM_PROMPT = (
51
  "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
52
  "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
53
  "- Format the entire response in Markdown.\n"
 
54
  "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
55
  )
56
 
57
  # ---------------------------------------------------------------------------
58
  # Helper functions
59
  # ---------------------------------------------------------------------------
60
-
61
- import logging
62
  from langchain_core.runnables import RunnableLambda
63
 
64
- logger = logging.getLogger(__name__)
 
 
 
 
 
65
 
66
- def _make_verbose_tap(label: str, verbose: bool):
67
  """
68
- Returns a passthrough RunnableLambda that logs *value* when verbose=True.
69
- Works for any chain step — docs, prompt messages, or raw strings.
70
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def _tap(value):
72
  if verbose:
73
- if hasattr(value, "messages"): # ChatPromptValue
74
  rendered = "\n".join(
75
  f"[{m.type.upper()}]: {m.content}"
76
  for m in value.messages
77
  )
78
- elif isinstance(value, list): # list of Documents
79
  rendered = "\n".join(str(d) for d in value)
80
  else:
81
  rendered = str(value)
82
-
83
  print(f"\n{'='*60}\n{label}\n{'='*60}\n{rendered}\n")
84
  logger.debug("%s\n%s", label, rendered)
85
  return value
86
  return RunnableLambda(_tap)
87
 
88
- def build_context(docs: list[Document]) -> str:
89
- """
90
- Concatenate a list of retrieved LangChain Documents into a single
91
- context string that the LLM can reason over.
92
-
93
- Each entry includes the product's ``parent_asin`` (falling back to its
94
- position index), its page content, and its full metadata dict.
95
-
96
- Parameters
97
- ----------
98
- docs:
99
- List of ``langchain_core.documents.Document`` objects returned by
100
- the retriever.
101
 
102
- Returns
103
- -------
104
- str
105
- A newline-separated block of product descriptions ready for prompt
106
- injection. Returns an empty string when *docs* is empty.
107
-
108
- Raises
109
- ------
110
- TypeError
111
- If *docs* is not a list, or any element is not a ``Document``.
112
- """
113
  if not isinstance(docs, list):
114
  raise TypeError(
115
  f"'docs' must be a list of Document objects, got {type(docs).__name__}."
@@ -119,11 +124,9 @@ def build_context(docs: list[Document]) -> str:
119
  raise TypeError(
120
  f"Element at index {i} is not a Document; got {type(doc).__name__}."
121
  )
122
-
123
  if not docs:
124
  logger.warning("build_context received an empty document list.")
125
  return ""
126
-
127
  return "\n\n".join(
128
  f"ASIN {doc.metadata.get('parent_asin', n)} Description: {doc.page_content}\n"
129
  f"Metadata: {doc.metadata}"
@@ -136,26 +139,6 @@ def _build_llm(
136
  max_new_tokens: int,
137
  provider: str,
138
  ) -> ChatHuggingFace:
139
- """
140
- Instantiate and return a ``ChatHuggingFace`` model backed by a
141
- HuggingFace Inference Endpoint.
142
-
143
- Parameters
144
- ----------
145
- repo_id:
146
- HuggingFace Hub model identifier (e.g.
147
- ``"meta-llama/Meta-Llama-3-8B-Instruct"``).
148
- max_new_tokens:
149
- Maximum number of tokens the model may generate per call.
150
- provider:
151
- Inference provider passed to ``HuggingFaceEndpoint``
152
- (``"auto"``, ``"novita"``, etc.).
153
-
154
- Returns
155
- -------
156
- ChatHuggingFace
157
- A chat-compatible wrapper around the endpoint.
158
- """
159
  endpoint = HuggingFaceEndpoint(
160
  repo_id=repo_id,
161
  task="text-generation",
@@ -166,19 +149,6 @@ def _build_llm(
166
 
167
 
168
  def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
169
- """
170
- Create a ``ChatPromptTemplate`` with a system message and a human
171
- turn that injects ``{context}`` and ``{question}`` placeholders.
172
-
173
- Parameters
174
- ----------
175
- system_prompt:
176
- The system-level instruction string.
177
-
178
- Returns
179
- -------
180
- ChatPromptTemplate
181
- """
182
  return ChatPromptTemplate.from_messages([
183
  ("system", system_prompt),
184
  (
@@ -201,87 +171,34 @@ def run_rag(
201
  max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
202
  provider: str = "auto",
203
  verbose: bool = False,
204
- ) -> str:
205
- """
206
- Execute a full RAG pipeline and return the model's answer.
207
-
208
- The pipeline follows the steps below:
209
-
210
- 1. **Retrieve** - *retriever* fetches the *k* most relevant documents
211
- for *query*.
212
- 2. **Format context** - :func:`build_context` serialises the documents
213
- into a single string.
214
- 3. **Prompt** - the context and query are injected into the chat prompt
215
- template.
216
- 4. **Generate** - the LLM produces an answer grounded in the context.
217
- 5. **Parse** - the raw chat message is unwrapped to a plain string.
218
-
219
- Parameters
220
- ----------
221
- retriever:
222
- A LangChain-compatible retriever (must expose ``.invoke()`` and be
223
- pipeable with ``|``). Typically created via
224
- ``vectorstore.as_retriever(...)``.
225
- query:
226
- Natural-language question to answer (non-empty string).
227
- system_prompt:
228
- System-level instruction for the assistant. Defaults to
229
- :data:`DEFAULT_SYSTEM_PROMPT`.
230
- repo_id:
231
- HuggingFace Hub model identifier. Defaults to
232
- ``"meta-llama/Meta-Llama-3-8B-Instruct"``.
233
- max_new_tokens:
234
- Upper bound on generated tokens. Must be a positive integer.
235
- Defaults to ``100``.
236
- provider:
237
- HuggingFace inference provider (e.g. ``"auto"``, ``"novita"``).
238
- Defaults to ``"auto"``.
239
-
240
- Returns
241
- -------
242
- str
243
- The model's answer as a plain string.
244
-
245
- Raises
246
- ------
247
- TypeError
248
- If *retriever* is ``None``, *query* is not a string, or
249
- *system_prompt* is not a string.
250
- ValueError
251
- If *query* is blank, *max_new_tokens* is not a positive integer,
252
- or *repo_id* / *provider* are blank strings.
253
-
254
- Examples
255
- --------
256
- >>> answer = run_rag(retriever, "Best waterproof mascara under $20")
257
- >>> print(answer)
258
- """
259
  # ------------------------------------------------------------------
260
  # Build chain components
261
  # ------------------------------------------------------------------
262
-
263
  logger.info("Initialising LLM endpoint: %s", repo_id)
264
  llm = _build_llm(repo_id, max_new_tokens, provider)
265
  prompt_template = _build_prompt_template(system_prompt)
266
 
267
- retrieved_docs: list[Document] = [] # ← capture target
 
 
268
 
269
  def _retrieve_and_capture(query: str) -> list[Document]:
270
- """Invoke the retriever and snapshot the results for the caller."""
271
  docs = retriever.invoke(query)
272
- retrieved_docs.extend(docs) # ← populate closure variable
273
- return docs # ← pass through to build_context
274
 
275
  rag_chain = (
276
  {
277
  "context": RunnableLambda(_retrieve_and_capture)
278
  | RunnableLambda(build_context)
 
279
  | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
280
  "question": RunnablePassthrough(),
281
  }
282
  | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
283
  | prompt_template
284
- | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose) # ← shows exact prompt
285
  | llm
286
  | StrOutputParser()
287
  )
@@ -293,4 +210,4 @@ def run_rag(
293
  answer: str = rag_chain.invoke(query)
294
  logger.debug("RAG answer: %s", answer)
295
 
296
- return answer, retrieved_docs
 
20
  from langchain_core.output_parsers import StrOutputParser
21
  from langchain_core.prompts import ChatPromptTemplate
22
  from langchain_core.runnables import RunnableLambda, RunnablePassthrough
23
+ import os
24
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
25
+
26
  # ---------------------------------------------------------------------------
27
  # Logging
28
  # ---------------------------------------------------------------------------
 
38
  DEFAULT_SYSTEM_PROMPT = (
39
  "You are a helpful Amazon grocery shopping assistant.\n\n"
40
  "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
41
+
42
+ "If the context contains a section starting with 'Web search results', "
43
+ "incorporate that pricing or availability information naturally into your answer — "
44
+ "do not copy it verbatim or list raw numbers. Sources will be displayed separately, "
45
+ "so you do not need to include URLs in your response.\n\n"
46
+
47
  "Your response must follow this exact structure:\n\n"
48
  "---\n\n"
49
  "## 🛒 Recommended Products\n"
 
59
  "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
60
  "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
61
  "- Format the entire response in Markdown.\n"
62
+ "- If any information comes from a web search, cite the source inline as [source](url).\n"
63
  "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
64
  )
65
 
66
  # ---------------------------------------------------------------------------
67
  # Helper functions
68
  # ---------------------------------------------------------------------------
 
 
69
  from langchain_core.runnables import RunnableLambda
70
 
71
# Keyword triggers that suggest the query needs external/current information
_WEB_SEARCH_TRIGGERS = {
    "price", "cost", "available", "availability", "recall", "news",
    "latest", "current", "today", "recently", "substitute", "substitution",
    "allergen", "gluten", "vegan", "organic", "nutrition", "calories",
}


def _needs_web_search(query: str) -> bool:
    """Return True when *query* contains at least one web-search trigger word.

    Every non-alphanumeric character is treated as a separator, so trailing
    punctuation ("price?") or hyphenation ("gluten-free") does not hide a
    trigger word the way plain whitespace splitting would.
    """
    normalised = "".join(ch if ch.isalnum() else " " for ch in query.lower())
    return not _WEB_SEARCH_TRIGGERS.isdisjoint(normalised.split())


def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
    """Optionally enrich a query with web search results.

    A search is attempted only when *query* contains one of the
    ``_WEB_SEARCH_TRIGGERS`` keywords. The search is best-effort: any
    failure (missing package, missing API key, network error, unexpected
    response shape) is logged as a warning and treated as "no results".

    Returns
    -------
    tuple[str, list[dict]]
        ``(context_string, sources_list)``. ``context_string`` is a block
        starting with ``"Web search results"`` ready to be appended to the
        prompt context, and ``sources_list`` is
        ``[{"title": ..., "url": ...}, ...]`` for clean rendering.
        Both are empty when no search was run, the search failed, or it
        produced no usable results.
    """
    if not _needs_web_search(query):
        return "", []

    try:
        # Imported lazily so the pipeline still works without tavily installed.
        from tavily import TavilyClient

        client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        response = client.search(query, max_results=3)

        # Keep only hits that carry both a snippet and a URL, so a single
        # malformed entry cannot discard the whole result set.
        usable = [
            r for r in response.get("results", [])
            if r.get("content") and r.get("url")
        ]
    except Exception as e:
        logger.warning("Web search failed: %s", e)
        return "", []

    if not usable:
        # Avoid injecting an empty "Web search results" header into the prompt.
        return "", []

    snippets = "\n\n".join(r["content"] for r in usable)
    # Fall back to the URL when the title is missing or empty.
    sources = [{"title": r.get("title") or r["url"], "url": r["url"]} for r in usable]
    context = f"\n\nWeb search results (use this to answer pricing/availability questions):\n{snippets}"
    return context, sources
97
+
98
+
99
def _make_verbose_tap(label: str, verbose: bool):
    """Build a pass-through chain step that dumps its input when *verbose*.

    The returned ``RunnableLambda`` always hands its input back unchanged.
    When *verbose* is true it first renders the value — one
    ``[TYPE]: content`` line per message for objects exposing ``.messages``,
    one line per element for lists, ``str(value)`` otherwise — then prints
    it under a banner containing *label* and logs it at DEBUG level.
    """

    def _tap(value):
        if not verbose:
            return value

        if hasattr(value, "messages"):
            message_lines = [
                f"[{msg.type.upper()}]: {msg.content}" for msg in value.messages
            ]
            rendered = "\n".join(message_lines)
        elif isinstance(value, list):
            rendered = "\n".join(map(str, value))
        else:
            rendered = str(value)

        print(f"\n{'='*60}\n{label}\n{'='*60}\n{rendered}\n")
        logger.debug("%s\n%s", label, rendered)
        return value

    return RunnableLambda(_tap)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ def build_context(docs: list[Document]) -> str:
 
 
 
 
 
 
 
 
 
 
118
  if not isinstance(docs, list):
119
  raise TypeError(
120
  f"'docs' must be a list of Document objects, got {type(docs).__name__}."
 
124
  raise TypeError(
125
  f"Element at index {i} is not a Document; got {type(doc).__name__}."
126
  )
 
127
  if not docs:
128
  logger.warning("build_context received an empty document list.")
129
  return ""
 
130
  return "\n\n".join(
131
  f"ASIN {doc.metadata.get('parent_asin', n)} Description: {doc.page_content}\n"
132
  f"Metadata: {doc.metadata}"
 
139
  max_new_tokens: int,
140
  provider: str,
141
  ) -> ChatHuggingFace:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  endpoint = HuggingFaceEndpoint(
143
  repo_id=repo_id,
144
  task="text-generation",
 
149
 
150
 
151
  def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  return ChatPromptTemplate.from_messages([
153
  ("system", system_prompt),
154
  (
 
171
  max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
172
  provider: str = "auto",
173
  verbose: bool = False,
174
+ ) -> tuple[str, list[Document]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  # ------------------------------------------------------------------
176
  # Build chain components
177
  # ------------------------------------------------------------------
 
178
  logger.info("Initialising LLM endpoint: %s", repo_id)
179
  llm = _build_llm(repo_id, max_new_tokens, provider)
180
  prompt_template = _build_prompt_template(system_prompt)
181
 
182
+ web_context, web_sources = _maybe_web_search(query)
183
+
184
+ retrieved_docs: list[Document] = []
185
 
186
  def _retrieve_and_capture(query: str) -> list[Document]:
 
187
  docs = retriever.invoke(query)
188
+ retrieved_docs.extend(docs)
189
+ return docs
190
 
191
  rag_chain = (
192
  {
193
  "context": RunnableLambda(_retrieve_and_capture)
194
  | RunnableLambda(build_context)
195
+ | RunnableLambda(lambda ctx: ctx + web_context)
196
  | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
197
  "question": RunnablePassthrough(),
198
  }
199
  | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
200
  | prompt_template
201
+ | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose)
202
  | llm
203
  | StrOutputParser()
204
  )
 
210
  answer: str = rag_chain.invoke(query)
211
  logger.debug("RAG answer: %s", answer)
212
 
213
+ return answer, retrieved_docs, web_sources
src/tools.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain.tools import tool
3
+
4
@tool
def web_search(query: str, max_results: int = 3) -> str:
    """
    Search the web for current information about a grocery or gourmet food product.
    Use this when the user asks about recent news, current pricing, availability,
    updated nutritional info, or anything unlikely to be in the product review corpus.
    Input should be a specific product name or question.
    """
    # NOTE: the docstring above is kept verbatim; presumably the @tool
    # decorator uses it as the tool description — TODO confirm.

    # Imported inside the function, as in the original, so importing this
    # module does not itself import tavily.
    from tavily import TavilyClient

    tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
    response = tavily_client.search(query, max_results=max_results)

    # Collect the text snippet of every hit in the response.
    snippets = []
    for hit in response.get("results", []):
        snippets.append(hit["content"])

    if not snippets:
        return "No results found."
    return "\n\n".join(snippets)