File size: 8,813 Bytes
681ec3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff5f2a
681ec3c
dff5f2a
681ec3c
 
 
 
 
 
 
 
251d75e
681ec3c
 
 
 
 
 
dff5f2a
 
 
 
 
 
681ec3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff5f2a
681ec3c
 
 
 
 
 
 
 
dff5f2a
 
 
 
 
 
681ec3c
dff5f2a
681ec3c
dff5f2a
 
681ec3c
dff5f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251d75e
681ec3c
251d75e
681ec3c
dff5f2a
681ec3c
 
 
 
dff5f2a
681ec3c
 
 
 
 
 
 
 
 
dff5f2a
251d75e
681ec3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251d75e
681ec3c
 
 
 
 
 
 
 
 
 
251d75e
681ec3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff5f2a
251d75e
681ec3c
 
 
 
 
 
 
dff5f2a
 
 
681ec3c
 
251d75e
681ec3c
dff5f2a
 
681ec3c
 
 
 
 
dff5f2a
681ec3c
 
 
 
 
dff5f2a
681ec3c
 
 
 
 
 
 
 
 
 
 
dff5f2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
rag_chain.py
------------
Amazon product RAG (Retrieval-Augmented Generation) pipeline using
LangChain + HuggingFace Inference Endpoints.

Typical usage
-------------
>>> from rag_chain import run_rag
>>> answer, docs, web_sources = run_rag(retriever, "Moisturizing shampoo for thick curly hair")
>>> print(answer)
"""

from __future__ import annotations

import logging
from typing import Any

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
import os
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
DEFAULT_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"  # HF Hub model used by run_rag by default
DEFAULT_MAX_NEW_TOKENS = 512  # default generation budget for run_rag
# NOTE(review): not referenced in this module's visible code; presumably used by
# callers when configuring the retriever — confirm.
DEFAULT_TOP_K = 5

# System message for the shopping assistant. The literals below were previously
# mis-encoded (UTF-8 read as cp1253, e.g. "β€”" instead of "—"), which sent
# garbage characters to the model; they are now the intended dash/emoji chars.
# NOTE(review): the web-search paragraph says URLs need not be included, while
# the rules ask for inline [source](url) citations — confirm which is intended.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful Amazon grocery shopping assistant.\n\n"
    "You will receive a grocery query and a list of related Amazon products (including reviews and metadata).\n\n"

    "If the context contains a section starting with 'Web search results', "
    "incorporate that pricing or availability information naturally into your answer — "
    "do not copy it verbatim or list raw numbers. Sources will be displayed separately, "
    "so you do not need to include URLs in your response.\n\n"

    "Your response must follow this exact structure:\n\n"
    "---\n\n"
    "## 🛒 Recommended Products\n"
    "For each product, write a numbered list entry, mentioning products by title "
    "followed by 1–2 sentences describing the product and why it suits the query.\n\n"
    "## 💡 Tips & Recipe Ideas\n"
    "A bullet-point list of practical tips, storage advice, and brief recipe ideas related to the products above "
    # Trailing space added: without it the two sentences were glued together.
    "(do NOT write out full recipes — keep each idea to 1–2 sentences). "
    "Add food emojis if relevant.\n\n"
    "---\n\n"
    "Rules:\n"
    "- Do not invent products. Only recommend products from the provided list.\n"
    "- Keep descriptions factual and grounded in the provided reviews and metadata.\n"
    "- Recipe ideas should be suggestions or ideas only, not step-by-step instructions.\n"
    "- Format the entire response in Markdown.\n"
    "- If any information comes from a web search, cite the source inline as [source](url).\n"
    "- IMPORTANT: Whenever citing the product title: add the parent_asin in the following format [title](#parent_asin)"
)

# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
from langchain_core.runnables import RunnableLambda

# Keyword triggers that suggest the query needs external/current information
_WEB_SEARCH_TRIGGERS = {
    "price", "cost", "available", "availability", "recall", "news",
    "latest", "current", "today", "recently", "substitute", "substitution",
    "allergen", "gluten", "vegan", "organic", "nutrition", "calories",
}

def _maybe_web_search(query: str) -> tuple[str, list[dict]]:
    """
    Returns (context_string, sources_list) where sources_list is
    [{"title": ..., "url": ...}, ...] for clean rendering.
    """
    tokens = set(query.lower().split())
    if tokens & _WEB_SEARCH_TRIGGERS:
        try:
            from tavily import TavilyClient
            client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
            response = client.search(query, max_results=3)
            results = response.get("results", [])
            snippets = "\n\n".join(r["content"] for r in results)
            sources = [{"title": r.get("title", r["url"]), "url": r["url"]} for r in results]
            context = f"\n\nWeb search results (use this to answer pricing/availability questions):\n{snippets}"
            return context, sources
        except Exception as e:
            logger.warning("Web search failed: %s", e)
    return "", []


def _make_verbose_tap(label: str, verbose: bool):
    """Build a pass-through chain step that can dump its input for debugging.

    Args:
        label: Heading printed above the dumped value.
        verbose: When True, the value is printed to stdout between ``=`` rules
            and also logged at DEBUG level; when False nothing is emitted.

    Returns:
        A ``RunnableLambda`` whose output is always its input, unchanged.
    """
    banner = "=" * 60

    def _tap(value):
        """Emit *value* under ``label`` when verbose, then return it untouched."""
        if not verbose:
            return value

        if hasattr(value, "messages"):
            # Rendered prompts: one "[ROLE]: content" line per chat message.
            lines = [f"[{msg.type.upper()}]: {msg.content}" for msg in value.messages]
        elif isinstance(value, list):
            lines = [str(item) for item in value]
        else:
            lines = [str(value)]
        rendered = "\n".join(lines)

        print(f"\n{banner}\n{label}\n{banner}\n{rendered}\n")
        logger.debug("%s\n%s", label, rendered)
        return value

    return RunnableLambda(_tap)


def build_context(docs: list[Document]) -> str:
    """Render retrieved Documents as a single context string for the LLM.

    Every document contributes an entry of the form::

        ASIN <parent_asin> Description: <page_content>
        Metadata: <metadata dict>

    and entries are separated by a blank line. If a document's metadata has no
    ``parent_asin`` key, its index in *docs* is written in that slot instead.

    Args:
        docs: The documents returned by the retriever.

    Returns:
        The joined context, or ``""`` (after logging a warning) when *docs* is
        empty.

    Raises:
        TypeError: If *docs* is not a ``list``, or any element is not a
            ``Document``.
    """
    if not isinstance(docs, list):
        raise TypeError(
            f"'docs' must be a list of Document objects, got {type(docs).__name__}."
        )

    # First (index, element) pair that is not a Document, if any.
    offender = next(
        (pair for pair in enumerate(docs) if not isinstance(pair[1], Document)),
        None,
    )
    if offender is not None:
        bad_index, bad_item = offender
        raise TypeError(
            f"Element at index {bad_index} is not a Document; got {type(bad_item).__name__}."
        )

    if not docs:
        logger.warning("build_context received an empty document list.")
        return ""

    entries: list[str] = []
    for position, document in enumerate(docs):
        asin = document.metadata.get("parent_asin", position)
        entries.append(
            f"ASIN {asin} Description: {document.page_content}\n"
            f"Metadata: {document.metadata}"
        )
    return "\n\n".join(entries)


def _build_llm(
    repo_id: str,
    max_new_tokens: int,
    provider: str,
) -> ChatHuggingFace:
    """Create the chat model used by the RAG chain.

    A ``HuggingFaceEndpoint`` configured for the ``"text-generation"`` task is
    wrapped in ``ChatHuggingFace`` and returned.

    Args:
        repo_id: HuggingFace Hub model id.
        max_new_tokens: Maximum number of tokens the endpoint may generate.
        provider: Inference provider forwarded to the endpoint.
    """
    endpoint_settings = {
        "repo_id": repo_id,
        "task": "text-generation",
        "max_new_tokens": max_new_tokens,
        "provider": provider,
    }
    return ChatHuggingFace(llm=HuggingFaceEndpoint(**endpoint_settings))


def _build_prompt_template(system_prompt: str) -> ChatPromptTemplate:
    """Assemble the two-message chat prompt used by the RAG chain.

    The system message is *system_prompt*; the human message carries the
    ``{context}`` and ``{question}`` placeholders that the chain fills in.

    NOTE(review): the system prompt is passed through the same template
    machinery, so literal ``{``/``}`` characters in it presumably need to be
    doubled — confirm before allowing arbitrary user-supplied prompts.
    """
    human_template = (
        "context:\n{context}\n\nquestion:\n{question}\n\n"
        "Answer based on the Amazon datasets:"
    )
    message_specs = [
        ("system", system_prompt),
        ("human", human_template),
    ]
    return ChatPromptTemplate.from_messages(message_specs)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def run_rag(
    retriever: Any,
    query: str,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    repo_id: str = DEFAULT_REPO_ID,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    provider: str = "auto",
    verbose: bool = False,
) -> tuple[str, list[Document], list[dict]]:
    """Run a Retrieval-Augmented Generation (RAG) chain for a grocery query.

    The query is passed to *retriever*. The returned documents are rendered
    with ``build_context``, and any web-search context from
    ``_maybe_web_search`` is appended. The resulting prompt is answered by a
    HuggingFace chat model using *system_prompt*.

    Args:
        retriever: Object exposing ``invoke(query)`` that returns a list of
            ``Document`` objects (e.g. a LangChain retriever).
        query: The user's grocery question.
        system_prompt: System message for the LLM.
        repo_id: HuggingFace Hub model id for the inference endpoint.
        max_new_tokens: Generation budget passed to the endpoint.
        provider: Inference provider passed to the endpoint.
        verbose: If True, print the retrieved context, the prompt inputs, and
            the rendered prompt as they flow through the chain.

    Returns:
        ``(answer, retrieved_docs, web_sources)``:
        - the model's answer;
        - the documents returned by the retriever;
        - the ``{"title", "url"}`` dicts for any web-search results (empty
          when no search ran).
    """
    # ------------------------------------------------------------------
    # Build chain components
    # ------------------------------------------------------------------
    logger.info("Initialising LLM endpoint: %s", repo_id)
    llm = _build_llm(repo_id, max_new_tokens, provider)
    prompt_template = _build_prompt_template(system_prompt)

    # Decided from the raw query before the chain runs; the context string is
    # appended to the retrieved product context inside the chain below.
    web_context, web_sources = _maybe_web_search(query)

    # Populated as a side effect of the chain so the caller gets the source
    # documents back alongside the answer.
    retrieved_docs: list[Document] = []

    def _retrieve_and_capture(chain_input: str) -> list[Document]:
        """Invoke the retriever and record the documents for the return value."""
        docs = retriever.invoke(chain_input)
        retrieved_docs.extend(docs)
        return docs

    rag_chain = (
        {
            "context": RunnableLambda(_retrieve_and_capture)
                       | RunnableLambda(build_context)
                       | RunnableLambda(lambda ctx: ctx + web_context)
                       | _make_verbose_tap("RETRIEVED CONTEXT", verbose),
            "question": RunnablePassthrough(),
        }
        | _make_verbose_tap("PROMPT INPUTS (context + question)", verbose)
        | prompt_template
        | _make_verbose_tap("RENDERED PROMPT SENT TO LLM", verbose)
        | llm
        | StrOutputParser()
    )

    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------
    logger.info("Invoking RAG chain for query: %r", query)
    answer: str = rag_chain.invoke(query)
    logger.debug("RAG answer: %s", answer)

    return answer, retrieved_docs, web_sources