""" rag_chain.py ------------ Amazon product RAG (Retrieval-Augmented Generation) pipeline using LangChain + HuggingFace Inference Endpoints. Typical usage ------------- >>> from rag_chain import run_rag >>> answer = run_rag(retriever, "Moisturizing shampoo for thick curly hair") >>> print(answer) """ from __future__ import annotations import logging from typing import Any from langchain_core.documents import Document from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda, RunnablePassthrough import os from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct" DEFAULT_MAX_NEW_TOKENS = 512 DEFAULT_TOP_K = 5 DEFAULT_SYSTEM_PROMPT = ( "You are a helpful Amazon grocery shopping assistant.\n\n" "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n" "If the context contains a section starting with 'Web search results', " "incorporate that pricing or availability information naturally into your answer — " "do not copy it verbatim or list raw numbers. Sources will be displayed separately, " "so you do not need to include URLs in your response.\n\n" "Your response must follow this exact structure:\n\n" "---\n\n" "## 🛒 Recommended Products\n" "For each product, write a numbered list entry, mentioning products by title " "followed by 1–2 sentences describing the product and why it suits the query.\n\n" "## 💡 Tips & Recipe Ideas\n" "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above " "(do NOT write out full recipes — keep each idea to 1–2 sentences)." "Add food emojis if relevant.\n\n" "---\n\n" "Rules:\n" "- Do not invent products. Only recommend products from the provided list.\n" "- Keep descriptions factual and grounded in the provided reviews and metadata.\n" "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n" "- Format the entire response in Markdown.\n" "- If any information comes from a web search, cite the source inline as [source](url).\n" "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)" ) # --------------------------------------------------------------------------- # Helper functions # --------------------------------------------------------------------------- from langchain_core.runnables import RunnableLambda # Keyword triggers that suggest the query needs external/current information _WEB_SEARCH_TRIGGERS = { "price", "cost", "available", "availability", "recall", "news", "latest", "current", "today", "recently", "substitute", "substitution", "allergen", "gluten", "vegan", "organic", "nutrition", "calories", } def _maybe_web_search(query: str) -> tuple[str, list[dict]]: """ Returns (context_string, sources_list) where sources_list is [{"title": ..., "url": ...}, ...] for clean rendering. """ tokens = set(query.lower().split()) if tokens & _WEB_SEARCH_TRIGGERS: try: from tavily import TavilyClient client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) response = client.search(query, max_results=3) results = response.get("results", []) snippets = "\n\n".join(r["content"] for r in results) sources = [{"title": r.get("title", r["url"]), "url": r["url"]} for r in results] context = f"\n\nWeb search results (use this to answer pricing/availability questions):\n{snippets}" return context, sources except Exception as e: logger.warning("Web search failed: %s", e) return "", [] def _make_verbose_tap(label: str, verbose: bool): """Returns a Runnable that prints the value with a label if verbose=True, then passes it through unchanged.""" def _tap(value): """Prints the value with a label if verbose=True, then returns it unchanged.""" if verbose: if hasattr(value, "messages"): rendered = "\n".join( f"[{m.type.upper()}]: {m.content}" for m in value.messages ) elif isinstance(value, list): rendered = "\n".join(str(d) for d in value) else: rendered = str(value) print(f"\n{'='*60}\n{label}\n{'='*60}\n{rendered}\n") logger.debug("%s\n%s", label, rendered) return value return RunnableLambda(_tap) def build_context(docs: list[Document]) -> str: """Converts a list of Documents into a single string context for the LLM.""" if not isinstance(docs, list): raise TypeError( f"'docs' must be a list of Document objects, got {type(docs).__name__}." ) for i, doc in enumerate(docs): if not isinstance(doc, Document): raise TypeError( f"Element at index {i} is not a Document; got {type(doc).__name__}." ) if not docs: logger.warning("build_context received an empty document list.") return "" return "\n\n".join( f"ASIN {doc.metadata.get('parent_asin', n)} Description: {doc.page_content}\n" f"Metadata: {doc.metadata}" for n, doc in enumerate(docs) ) def _build_llm( repo_id: str, max_new_tokens: int, provider: str, ) -> ChatHuggingFace: """Initializes a HuggingFaceEndpoint and wraps it in a ChatHuggingFace LLM.""" endpoint = HuggingFaceEndpoint( repo_id=repo_id, task="text-generation", max_new_tokens=max_new_tokens, provider=provider, ) return ChatHuggingFace(llm=endpoint) def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate: """Constructs a ChatPromptTemplate with the given system prompt and a fixed human prompt.""" return ChatPromptTemplate.from_messages([ ("system", system_prompt), ( "human", "context:\n{context}\n\nquestion:\n{question}\n\n" "Answer based on the Amazon datasets:", ), ]) # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def run_rag( retriever: Any, query: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT, repo_id: str = DEFAULT_REPO_ID, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, provider: str = "auto", verbose: bool = False, ) -> tuple[str, list[Document]]: """Runs a Retrieval-Augmented Generation (RAG) chain for a grocery query.""" # ------------------------------------------------------------------ # Build chain components # ------------------------------------------------------------------ logger.info("Initialising LLM endpoint: %s", repo_id) llm = _build_llm(repo_id, max_new_tokens, provider) prompt_template = _build_prompt_template(system_prompt) web_context, web_sources = _maybe_web_search(query) retrieved_docs: list[Document] = [] def _retrieve_and_capture(query: str) -> list[Document]: """Invokes the retriever and captures the retrieved documents for later use.""" docs = retriever.invoke(query) retrieved_docs.extend(docs) return docs rag_chain = ( { "context": RunnableLambda(_retrieve_and_capture) | RunnableLambda(build_context) | RunnableLambda(lambda ctx: ctx + web_context) | _make_verbose_tap("RETRIEVED CONTEXT", verbose), "question": RunnablePassthrough(), } | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose) | prompt_template | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose) | llm | StrOutputParser() ) # ------------------------------------------------------------------ # Run # ------------------------------------------------------------------ logger.info("Invoking RAG chain for query: %r", query) answer: str = rag_chain.invoke(query) logger.debug("RAG answer: %s", answer) return answer, retrieved_docs, web_sources