amazon_retriever / src /rag_pipeline.py
github-actions[bot]
chore: sync app/ and src/ from GitHub
251d75e
"""
rag_pipeline.py
---------------
Amazon product RAG (Retrieval-Augmented Generation) pipeline using
LangChain + HuggingFace Inference Endpoints.

Typical usage
-------------
>>> from rag_pipeline import run_rag
>>> answer, docs, web_sources = run_rag(retriever, "Moisturizing shampoo for thick curly hair")
>>> print(answer)
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
import os
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Default Hugging Face model repository used by run_rag().
DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
# Default generation budget passed to the endpoint by run_rag().
DEFAULT_MAX_NEW_TOKENS = 512
# NOTE(review): not referenced in the visible part of this module — presumably
# consumed by whoever builds the retriever; confirm before removing.
DEFAULT_TOP_K = 5

# System prompt for the chat model. The string is assembled from adjacent
# literals, so every fragment must end with its own whitespace/newline.
# NOTE(review): the prompt tells the model both that URLs are not needed
# ("Sources will be displayed separately") and to cite web sources inline as
# [source](url) in the rules below — confirm which behaviour is intended.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful Amazon grocery shopping assistant.\n\n"
    "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
    "If the context contains a section starting with 'Web search results', "
    "incorporate that pricing or availability information naturally into your answer — "
    "do not copy it verbatim or list raw numbers. Sources will be displayed separately, "
    "so you do not need to include URLs in your response.\n\n"
    "Your response must follow this exact structure:\n\n"
    "---\n\n"
    "## 🛒 Recommended Products\n"
    "For each product, write a numbered list entry, mentioning products by title "
    "followed by 1–2 sentences describing the product and why it suits the query.\n\n"
    "## 💡 Tips & Recipe Ideas\n"
    "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above "
    # Trailing space added: without it the next fragment was glued on as
    # "...sentences).Add food emojis...".
    "(do NOT write out full recipes — keep each idea to 1–2 sentences). "
    "Add food emojis if relevant.\n\n"
    "---\n\n"
    "Rules:\n"
    "- Do not invent products. Only recommend products from the provided list.\n"
    "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
    "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
    "- Format the entire response in Markdown.\n"
    "- If any information comes from a web search, cite the source inline as [source](url).\n"
    "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
)
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
from langchain_core.runnables import RunnableLambda
# Keyword triggers that suggest the query needs external/current information
_WEB_SEARCH_TRIGGERS = {
    "price", "cost", "available", "availability", "recall", "news",
    "latest", "current", "today", "recently", "substitute", "substitution",
    "allergen", "gluten", "vegan", "organic", "nutrition", "calories",
}

# Punctuation stripped from both ends of every query token before matching,
# so that "price?" or "(organic)" still hit a trigger keyword.
_TOKEN_EDGE_PUNCTUATION = ".,;:!?\"'()[]{}"


def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
    """Run a Tavily web search when the query contains a trigger keyword.

    The query is lower-cased, split on whitespace and each token is stripped
    of surrounding punctuation; a search is attempted only if at least one
    token is in ``_WEB_SEARCH_TRIGGERS``.

    The search is best-effort: any failure (``tavily`` not installed, missing
    API key, network error, unexpected response shape) is logged as a warning
    and treated as "no web results".

    Args:
        query: The user's question.

    Returns:
        ``(context_string, sources_list)``. ``context_string`` is a block of
        text beginning with "Web search results" that can be appended to the
        LLM context, and ``sources_list`` is
        ``[{"title": ..., "url": ...}, ...]`` for clean rendering. Both are
        empty (``""`` and ``[]``) when no search was triggered, the search
        failed, or it produced no usable results.
    """
    tokens = {word.strip(_TOKEN_EDGE_PUNCTUATION) for word in query.lower().split()}
    if not tokens & _WEB_SEARCH_TRIGGERS:
        return "", []

    try:
        # Imported lazily so the module still loads when tavily is absent.
        from tavily import TavilyClient

        client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        response = client.search(query, max_results=3)

        # Keep only results that carry both a snippet and a URL; a single
        # malformed entry must not discard the remaining ones.
        results = [
            r
            for r in (response.get("results") or [])
            if r.get("content") and r.get("url")
        ]
        if not results:
            # Avoid emitting a "Web search results" header with nothing under it.
            return "", []

        snippets = "\n\n".join(r["content"] for r in results)
        # Fall back to the URL when the title is missing or empty.
        sources = [{"title": r.get("title") or r["url"], "url": r["url"]} for r in results]
        context = f"\n\nWeb search results (use this to answer pricing/availability questions):\n{snippets}"
        return context, sources
    except Exception as e:  # deliberate best-effort boundary
        logger.warning("Web search failed: %s", e)
    return "", []
def _make_verbose_tap(label: str, verbose: bool):
    """Create a pass-through chain step that can dump the value flowing by.

    When ``verbose`` is true the returned Runnable prints the value under a
    banner carrying ``label`` and also logs it at DEBUG level; when false it
    produces no output. In both cases the value is returned unchanged, so the
    step can be inserted anywhere in a chain.
    """
    banner = "=" * 60

    def _render(payload) -> str:
        """Turn a chain value into printable text."""
        # Prompt values expose ``.messages``: show one "[ROLE]: content" line each.
        if hasattr(payload, "messages"):
            lines = [f"[{m.type.upper()}]: {m.content}" for m in payload.messages]
            return "\n".join(lines)
        # Lists (e.g. retrieved documents) are shown one element per line.
        if isinstance(payload, list):
            return "\n".join(map(str, payload))
        return str(payload)

    def _tap(value):
        """Emit ``value`` under the label when verbose, then hand it back untouched."""
        if not verbose:
            return value
        text = _render(value)
        print(f"\n{banner}\n{label}\n{banner}\n{text}\n")
        logger.debug("%s\n%s", label, text)
        return value

    return RunnableLambda(_tap)
def build_context(docs: list[Document]) -> str:
    """Render retrieved Documents as one text block for the LLM prompt.

    Each document becomes an entry of the form
    ``ASIN <parent_asin> Description: <page_content>`` followed by a
    ``Metadata: <metadata dict>`` line; entries are separated by a blank line.
    When a document has no ``parent_asin`` in its metadata, its position in
    the list is used instead.

    Args:
        docs: The retrieved documents.

    Returns:
        The concatenated context, or ``""`` (with a warning logged) when
        ``docs`` is empty.

    Raises:
        TypeError: If ``docs`` is not a list, or any element is not a Document.
    """
    if not isinstance(docs, list):
        raise TypeError(
            f"'docs' must be a list of Document objects, got {type(docs).__name__}."
        )

    # Locate the first offending element, if any, before formatting anything.
    offender = next(
        ((pos, item) for pos, item in enumerate(docs) if not isinstance(item, Document)),
        None,
    )
    if offender is not None:
        pos, item = offender
        raise TypeError(
            f"Element at index {pos} is not a Document; got {type(item).__name__}."
        )

    if not docs:
        logger.warning("build_context received an empty document list.")
        return ""

    entries = []
    for position, document in enumerate(docs):
        asin = document.metadata.get("parent_asin", position)
        entries.append(
            f"ASIN {asin} Description: {document.page_content}\n"
            f"Metadata: {document.metadata}"
        )
    return "\n\n".join(entries)
def _build_llm(
    repo_id: str,
    max_new_tokens: int,
    provider: str,
) -> ChatHuggingFace:
    """Create the chat model used by the RAG chain.

    A ``HuggingFaceEndpoint`` is configured for the ``text-generation`` task
    with the given model repository, token budget and inference provider, and
    is then wrapped in ``ChatHuggingFace`` so it accepts chat-style prompts.
    """
    endpoint_options = {
        "repo_id": repo_id,
        "task": "text-generation",
        "max_new_tokens": max_new_tokens,
        "provider": provider,
    }
    return ChatHuggingFace(llm=HuggingFaceEndpoint(**endpoint_options))
def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
    """Assemble the two-message chat prompt used by the RAG chain.

    The system message is ``system_prompt``; the human message is fixed and
    exposes the ``{context}`` and ``{question}`` template variables.
    """
    human_message = (
        "context:\n{context}\n\nquestion:\n{question}\n\n"
        "Answer based on the Amazon datasets:"
    )
    messages = [
        ("system", system_prompt),
        ("human", human_message),
    ]
    return ChatPromptTemplate.from_messages(messages)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def run_rag(
    retriever: Any,
    query: str,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    repo_id: str = DEFAULT_REPO_ID,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    provider: str = "auto",
    verbose: bool = False,
) -> tuple[str, list[Document], list[dict]]:
    """Run a Retrieval-Augmented Generation (RAG) chain for a grocery query.

    The chain retrieves product documents for ``query``, formats them with
    ``build_context``, appends any web-search context returned by
    ``_maybe_web_search``, renders the chat prompt and sends it to the LLM.

    Args:
        retriever: Object exposing ``invoke(query)`` that returns a list of
            ``Document`` objects.
        query: The user's question.
        system_prompt: System message for the chat model.
        repo_id: Hugging Face model repository to call.
        max_new_tokens: Generation budget passed to the endpoint.
        provider: Inference provider passed to ``HuggingFaceEndpoint``.
        verbose: When true, print (and debug-log) the retrieved context, the
            prompt inputs and the rendered prompt.

    Returns:
        A 3-tuple ``(answer, retrieved_docs, web_sources)``: the model's text
        answer, the documents returned by the retriever, and the web sources
        as ``[{"title": ..., "url": ...}, ...]`` (empty when no web search
        was performed or it failed).

    Raises:
        TypeError: If the retriever does not return a list of ``Document``
            objects (raised by ``build_context``).
    """
    # ------------------------------------------------------------------
    # Build chain components
    # ------------------------------------------------------------------
    logger.info("Initialising LLM endpoint: %s", repo_id)
    llm = _build_llm(repo_id, max_new_tokens, provider)
    prompt_template = _build_prompt_template(system_prompt)

    # The web search runs once, up front; its text is appended to the
    # retrieved context inside the chain below.
    web_context, web_sources = _maybe_web_search(query)

    # Filled in as a side effect of the retrieval step so the documents can be
    # returned to the caller alongside the answer.
    retrieved_docs: list[Document] = []

    def _retrieve_and_capture(question: str) -> list[Document]:
        """Invoke the retriever and record the documents it returned."""
        docs = retriever.invoke(question)
        retrieved_docs.extend(docs)
        return docs

    rag_chain = (
        {
            "context": RunnableLambda(_retrieve_and_capture)
            | RunnableLambda(build_context)
            | RunnableLambda(lambda ctx: ctx + web_context)
            | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
            "question": RunnablePassthrough(),
        }
        | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
        | prompt_template
        | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose)
        | llm
        | StrOutputParser()
    )

    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------
    logger.info("Invoking RAG chain for query: %r", query)
    answer: str = rag_chain.invoke(query)
    logger.debug("RAG answer: %s", answer)
    return answer, retrieved_docs, web_sources