Spaces:
Sleeping
Sleeping
| """ | |
| rag_chain.py | |
| ------------ | |
| Amazon product RAG (Retrieval-Augmented Generation) pipeline using | |
| LangChain + HuggingFace Inference Endpoints. | |
| Typical usage | |
| ------------- | |
| from rag_chain import run_rag | |
| answer, retrieved_docs, web_sources = run_rag(retriever, "Moisturizing shampoo for thick curly hair") | |
| print(answer) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
| from langchain_core.documents import Document | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
| import os | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Default HuggingFace model repository used for generation.
DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
# Default cap on the number of tokens the LLM may generate per answer.
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TOP_K = 5

# System prompt used when the caller does not supply one. It fixes the
# Markdown layout of the answer (recommended products + tips sections) and
# the grounding rules the model must follow.
# NOTE(review): the prompt says URLs are not needed because sources are
# displayed separately, yet the rules ask for inline [source](url) citations;
# confirm which of the two instructions is intended.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful Amazon grocery shopping assistant.\n\n"
    "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"
    "If the context contains a section starting with 'Web search results', "
    "incorporate that pricing or availability information naturally into your answer — "
    "do not copy it verbatim or list raw numbers. Sources will be displayed separately, "
    "so you do not need to include URLs in your response.\n\n"
    "Your response must follow this exact structure:\n\n"
    "---\n\n"
    "## 🛒 Recommended Products\n"
    "For each product, write a numbered list entry, mentioning products by title "
    "followed by 1–2 sentences describing the product and why it suits the query.\n\n"
    "## 💡 Tips & Recipe Ideas\n"
    "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above "
    # Trailing space keeps this sentence from running into the next literal.
    "(do NOT write out full recipes — keep each idea to 1–2 sentences). "
    "Add food emojis if relevant.\n\n"
    "---\n\n"
    "Rules:\n"
    "- Do not invent products. Only recommend products from the provided list.\n"
    "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
    "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
    "- Format the entire response in Markdown.\n"
    "- If any information comes from a web search, cite the source inline as [source](url).\n"
    "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
)
| # --------------------------------------------------------------------------- | |
| # Helper functions | |
| # --------------------------------------------------------------------------- | |
| from langchain_core.runnables import RunnableLambda | |
# Keyword triggers that suggest the query needs external/current information
_WEB_SEARCH_TRIGGERS = {
    "price", "cost", "available", "availability", "recall", "news",
    "latest", "current", "today", "recently", "substitute", "substitution",
    "allergen", "gluten", "vegan", "organic", "nutrition", "calories",
}


def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
    """
    Run a Tavily web search when the query contains a trigger keyword.

    Returns (context_string, sources_list) where sources_list is
    [{"title": ..., "url": ...}, ...] for clean rendering.

    The search is best-effort: if no trigger word is present, the search
    fails for any reason, or it yields no usable results, ("", []) is
    returned so the caller can proceed without web context.
    """
    # Treat every non-alphanumeric character as a separator so that
    # punctuation or hyphens ("price?", "gluten-free") cannot hide a trigger.
    normalized = "".join(ch if ch.isalnum() else " " for ch in query.lower())
    if not _WEB_SEARCH_TRIGGERS.intersection(normalized.split()):
        return "", []

    try:
        # Imported lazily so the module still loads when tavily is not installed.
        from tavily import TavilyClient

        client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        response = client.search(query, max_results=3)
        # Keep only results that can be both quoted and attributed; one
        # malformed entry should not discard the whole search.
        usable = [
            r for r in response.get("results", [])
            if r.get("content") and r.get("url")
        ]
        snippets = "\n\n".join(r["content"] for r in usable)
        sources = [{"title": r.get("title") or r["url"], "url": r["url"]} for r in usable]
    except Exception as e:
        # Deliberate catch-all: web search is optional enrichment and must
        # never break the RAG answer.
        logger.warning("Web search failed: %s", e)
        return "", []

    if not usable:
        # Avoid emitting a "Web search results" header with nothing under it.
        return "", []

    context = f"\n\nWeb search results (use this to answer pricing/availability questions):\n{snippets}"
    return context, sources
def _make_verbose_tap(label: str, verbose: bool):
    """Build a pass-through Runnable that can dump the value flowing through it.

    When ``verbose`` is true, the wrapped function prints the value under a
    banner carrying ``label`` and also logs it at DEBUG level. In every case
    the value is returned unchanged.
    """
    banner = "=" * 60

    def _render(value) -> str:
        """Turn the tapped value into display text."""
        # Objects exposing ``.messages`` are shown one message per line.
        if hasattr(value, "messages"):
            lines = [f"[{m.type.upper()}]: {m.content}" for m in value.messages]
            return "\n".join(lines)
        if isinstance(value, list):
            return "\n".join(map(str, value))
        return str(value)

    def _tap(value):
        """Optionally print/log the value, then hand it back untouched."""
        if not verbose:
            return value
        text = _render(value)
        print(f"\n{banner}\n{label}\n{banner}\n{text}\n")
        logger.debug("%s\n%s", label, text)
        return value

    return RunnableLambda(_tap)
def build_context(docs: list[Document]) -> str:
    """Render retrieved Documents as one context string for the LLM.

    Each document contributes an entry with its ``parent_asin`` (falling back
    to the document's position in ``docs``), its page content and its full
    metadata; entries are separated by a blank line.

    Raises:
        TypeError: If ``docs`` is not a list, or any element is not a Document.
    """
    if not isinstance(docs, list):
        raise TypeError(
            f"'docs' must be a list of Document objects, got {type(docs).__name__}."
        )

    # Locate the first element of the wrong type, if there is one.
    offender = next(
        ((idx, item) for idx, item in enumerate(docs) if not isinstance(item, Document)),
        None,
    )
    if offender is not None:
        idx, item = offender
        raise TypeError(
            f"Element at index {idx} is not a Document; got {type(item).__name__}."
        )

    if not docs:
        logger.warning("build_context received an empty document list.")
        return ""

    entries = []
    for position, doc in enumerate(docs):
        asin = doc.metadata.get("parent_asin", position)
        entries.append(
            f"ASIN {asin} Description: {doc.page_content}\n"
            f"Metadata: {doc.metadata}"
        )
    return "\n\n".join(entries)
def _build_llm(
    repo_id: str,
    max_new_tokens: int,
    provider: str,
) -> ChatHuggingFace:
    """Create the chat model: a HuggingFaceEndpoint wrapped in ChatHuggingFace.

    ``repo_id``, ``max_new_tokens`` and ``provider`` are forwarded to the
    endpoint as-is; the endpoint task is fixed to ``"text-generation"``.
    """
    endpoint_options = {
        "repo_id": repo_id,
        "task": "text-generation",
        "max_new_tokens": max_new_tokens,
        "provider": provider,
    }
    return ChatHuggingFace(llm=HuggingFaceEndpoint(**endpoint_options))
def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
    """Assemble the two-message chat prompt used by the RAG chain.

    The system message is ``system_prompt``; the human message is a fixed
    template with ``{context}`` and ``{question}`` placeholders.
    """
    human_template = (
        "context:\n{context}\n\nquestion:\n{question}\n\n"
        "Answer based on the Amazon datasets:"
    )
    messages = [
        ("system", system_prompt),
        ("human", human_template),
    ]
    return ChatPromptTemplate.from_messages(messages)
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
def run_rag(
    retriever: Any,
    query: str,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    repo_id: str = DEFAULT_REPO_ID,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    provider: str = "auto",
    verbose: bool = False,
) -> tuple[str, list[Document], list[dict]]:
    """Runs a Retrieval-Augmented Generation (RAG) chain for a grocery query.

    Args:
        retriever: Object exposing ``invoke(query)`` that returns the
            Documents to use as context.
        query: The user's question; also used for retrieval and for the
            optional web search.
        system_prompt: System message for the chat prompt.
        repo_id: HuggingFace model repository to call.
        max_new_tokens: Generation cap forwarded to the endpoint.
        provider: Inference provider forwarded to the endpoint.
        verbose: When True, the retrieved context, the prompt inputs and the
            rendered prompt are printed and logged at DEBUG level.

    Returns:
        A 3-tuple ``(answer, retrieved_docs, web_sources)``: the LLM answer
        text, the Documents returned by the retriever, and the web-search
        sources as ``{"title": ..., "url": ...}`` dicts (empty when no web
        search was performed or it failed).
    """
    # ------------------------------------------------------------------
    # Build chain components
    # ------------------------------------------------------------------
    logger.info("Initialising LLM endpoint: %s", repo_id)
    llm = _build_llm(repo_id, max_new_tokens, provider)
    prompt_template = _build_prompt_template(system_prompt)
    web_context, web_sources = _maybe_web_search(query)

    # Filled as a side effect of the chain run so the documents can be
    # returned alongside the answer.
    retrieved_docs: list[Document] = []

    def _retrieve_and_capture(question: str) -> list[Document]:
        """Invokes the retriever and captures the retrieved documents for later use."""
        docs = retriever.invoke(question)
        retrieved_docs.extend(docs)
        return docs

    rag_chain = (
        {
            # Product context first, then any web-search context appended.
            "context": RunnableLambda(_retrieve_and_capture)
            | RunnableLambda(build_context)
            | RunnableLambda(lambda ctx: ctx + web_context)
            | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
            "question": RunnablePassthrough(),
        }
        | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
        | prompt_template
        | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose)
        | llm
        | StrOutputParser()
    )

    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------
    logger.info("Invoking RAG chain for query: %r", query)
    answer: str = rag_chain.invoke(query)
    logger.debug("RAG answer: %s", answer)
    return answer, retrieved_docs, web_sources