# llm_agent.py
# ============================================
# LLM layer for Amazon Multimodal RAG project
# - Reuses CLIP + Chroma from rag.py
# - Supports zero-shot / few-shot / multi-shot prompts
# - Exposes generate_answer() for UI team
# ============================================

import textwrap
import logging
from typing import List, Dict, Optional

from transformers import pipeline

# Import teammates' code
from rag import CLIPEmbedder, ChromaVectorStore, clean_text

logger = logging.getLogger(__name__)


# ===========================================================
# 1. LLM CLIENTS
# ===========================================================

# 1a. OpenAI GPT-4 Client
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    logger.warning("OpenAI package not installed. Install with: pip install openai")
    OPENAI_AVAILABLE = False


class OpenAILLMClient:
    """
    OpenAI GPT-4 client with the same interface as LLMClient.
    Drop-in replacement for the local HuggingFace pipeline.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        max_tokens: int = 512,
        temperature: float = 0.2,
    ):
        if not OPENAI_AVAILABLE:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")

        if not api_key:
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        logger.info(f"Initialized OpenAI client with model: {model}")

    def generate(self, prompt: str) -> str:
        """
        Generate text using OpenAI API.
        Interface compatible with LLMClient.generate()
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise
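
# Usage sketch (illustrative, not executed on import; assumes OPENAI_API_KEY is set):
#
#   import os
#   client = OpenAILLMClient(api_key=os.environ["OPENAI_API_KEY"])
#   print(client.generate("In one sentence, describe a portable Bluetooth speaker."))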


# 1b. HuggingFace Local Model Client

class LLMClient:
    """
    Thin wrapper around a HuggingFace text-generation pipeline.
    Swap model_name for any open-source instruct model you can run.
    Examples:
      - "meta-llama/Meta-Llama-3-8B-Instruct"
      - "mistralai/Mixtral-8x7B-Instruct-v0.1"
      - "mistralai/Mistral-7B-Instruct-v0.3"
    """

    def __init__(
        self,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens: int = 512,
        temperature: float = 0.2,
    ):
        self.generator = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto",
        )
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def generate(self, prompt: str) -> str:
        out = self.generator(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=self.generator.tokenizer.eos_token_id,
        )[0]["generated_text"]
        # Many instruct models echo the prompt; strip it out if needed
        return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()
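
# Usage sketch (illustrative; requires enough GPU/CPU memory to load the model):
#
#   llm = LLMClient(model_name="mistralai/Mistral-7B-Instruct-v0.3")
#   print(llm.generate("List two features of a typical smartwatch."))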


# ===========================================================
# 2. RETRIEVAL → CONTEXT BUILDING
# ===========================================================

def retrieve_products(
    query_text: Optional[str] = None,
    image_path: Optional[str] = None,
    persist_dir: str = "chromadb_store",
    top_k: int = 5,
) -> List[Dict]:
    """
    Uses the same CLIP + Chroma setup as rag.py,
    but returns a clean Python list of product dicts.
    """
    if not query_text and not image_path:
        raise ValueError("Provide either query_text or image_path.")

    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir=persist_dir)

    # True multimodal fusion: combine text + image when both are provided
    if query_text and image_path:
        # Both text and image provided: fuse embeddings (matches rag.py:229)
        text_emb = embedder.embed_text(query_text)
        img_emb = embedder.embed_image(image_path)
        emb = (text_emb + img_emb) / 2  # Simple averaging, consistent with index building
    elif query_text:
        # Text-only query
        emb = embedder.embed_text(query_text)
    else:
        # Image-only query (guaranteed by the guard at the top of the function)
        emb = embedder.embed_image(image_path)

    results = vectorstore.query(emb, top_k=top_k)

    products = []
    ids = results["ids"][0]
    metas = results["metadatas"][0]
    dists = results["distances"][0]

    for pid, meta, dist in zip(ids, metas, dists):
        products.append(
            {
                "id": pid,
                "name": meta.get("name", ""),
                "category": meta.get("category", ""),
                "image_path": meta.get("image_path", None),
                "distance": float(dist),
            }
        )

    return products
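
# Example call (sketch; assumes rag.py has already built an index in "chromadb_store"):
#
#   products = retrieve_products(query_text="portable bluetooth speaker", top_k=3)
#   for p in products:
#       print(p["name"], round(p["distance"], 4))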


def build_context_block(products: List[Dict]) -> str:
    """
    Turns retrieved products into a readable text block
    that we can feed to the LLM as 'CONTEXT'.
    """
    lines = []
    for i, p in enumerate(products, start=1):
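        # 1 - distance converts Chroma's distance into a similarity score
        # (assumes a cosine-style distance, so a larger score means more similar).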
        snippet = textwrap.dedent(f"""
        [Product {i}]
        id: {p.get("id")}
        name: {p.get("name")}
        category: {p.get("category")}
        image_path: {p.get("image_path")}
        similarity_score: {1 - p.get("distance", 0):.4f}
        """).strip()
        lines.append(snippet)
    return "\n\n".join(lines)
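
# Shape of the output for one product (values below are illustrative only):
#
#   [Product 1]
#   id: B0EXAMPLE01
#   name: JBL Go 3 Portable Speaker
#   category: Electronics
#   image_path: images/jbl_go3.jpg
#   similarity_score: 0.8731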


# ===========================================================
# 3. PROMPT TEMPLATES
#    (Zero-shot / Few-shot / Multi-shot)
# ===========================================================

def _few_shot_examples() -> str:
    """
    Two short in-context examples using the same format.
    This satisfies the 'few-shot' requirement.
    """
    return textwrap.dedent("""
    ### Example 1
    USER QUESTION:
    "What are the main features of this Bluetooth speaker?"

    CONTEXT:
    [Product 1]
    name: JBL Go 3 Portable Speaker
    category: Electronics
    image_path: images/jbl_go3.jpg

    ASSISTANT ANSWER:
    The JBL Go 3 is a small portable Bluetooth speaker designed for travel.
    It offers wireless Bluetooth audio, IP67 water and dust resistance,
    and up to about 5 hours of playback on a single charge.

    ### Example 2
    USER QUESTION:
    "Can you compare the two smartwatches you found for me?"

    CONTEXT:
    [Product 1]
    name: Apple Watch Series 9 GPS
    category: Wearable Technology

    [Product 2]
    name: Samsung Galaxy Watch 6
    category: Wearable Technology

    ASSISTANT ANSWER:
    Both watches are full-featured smartwatches for fitness and daily use.
    The Apple Watch Series 9 is tightly integrated with the Apple ecosystem
    and works best with iPhones. The Galaxy Watch 6 is built for Android
    phones and integrates with Samsung Health. Choose based on whether
    you mainly use iOS or Android.
    """).strip()


def build_prompt(
    user_question: str,
    context_block: str,
    mode: str = "zero-shot",
    chat_history: Optional[List[Dict[str, str]]] = None,
    is_image_query: bool = False,
) -> str:
    """
    mode: "zero-shot" | "few-shot" | "multi-shot"
    chat_history: list of {"role": "user"/"assistant", "content": "..."}
    is_image_query: True if user uploaded an image (changes prompt strategy)
    """

    history_str = ""
    if chat_history:
        formatted_turns = []
        for turn in chat_history:
            role = turn.get("role", "user").upper()
            content = turn.get("content", "")
            formatted_turns.append(f"{role}: {content}")
        history_str = "\n".join(formatted_turns)

    # Different instructions for image vs text queries
    if is_image_query:
        base_instructions = textwrap.dedent("""
        You are a helpful e-commerce assistant for an Amazon-like store.

        IMPORTANT: The user uploaded an image, and our visual similarity search system (powered by CLIP)
        has retrieved the most visually similar products from our database.

        You are given:
        1) The user's question about the uploaded image.
        2) A CONTEXT block with retrieved products ranked by visual similarity.
           - similarity_score: Higher scores (closer to 1.0) mean the product looks more similar to the uploaded image.
           - Each product includes: id, name, category, image_path, similarity_score.

        RULES FOR IMAGE-BASED QUERIES:
        - The products in CONTEXT were selected because they visually resemble the uploaded image.
        - Trust the similarity_score: products with scores > 0.8 are highly similar to the uploaded image.
        - Describe the retrieved products based on their names, categories, and similarity scores.
        - If the top result has high similarity (>0.8), you can confidently say "This appears to be..." or "The uploaded image shows...".
        - If similarity scores are moderate (0.6-0.8), say "This looks similar to..." and list top matches.
        - Compare multiple products if their similarity scores are close.
        - You can infer product characteristics from the product name and category.
        - Be helpful and descriptive based on the retrieved product information.
        - Do NOT say you cannot see the image - the visual search has already been performed for you.
        """).strip()
    else:
        base_instructions = textwrap.dedent("""
        You are a helpful e-commerce assistant for an Amazon-like store.
        You are given:
        1) The user's question.
        2) A CONTEXT block with retrieved products (id, name, category, image_path, similarity_score).

        RULES:
        - Use ONLY the information in CONTEXT plus general consumer knowledge.
        - Prefer products with higher similarity_score.
        - Be concise and factual.
        - If the context does not contain enough information, say that you are not sure.
        - If multiple products are relevant, compare them clearly.
        - Do NOT invent product names or specs that are not implied by the context.
        """).strip()

    prompt_parts = [base_instructions]

    # Add chat history (for multi-turn conversations)
    if history_str:
        prompt_parts.append("\n---\nCHAT HISTORY (previous turns):\n" + history_str)

    # Add few-shot or multi-shot examples
    if mode == "few-shot":
        prompt_parts.append("\n---\nFEW-SHOT EXAMPLES:\n" + _few_shot_examples())
    elif mode == "multi-shot":
        # For simplicity, reuse the same examples but label as "multi-shot"
        # (You could easily extend with 3+ examples here.)
        prompt_parts.append("\n---\nMULTI-SHOT EXAMPLES:\n" + _few_shot_examples())

    # Finally, add current question + context
    prompt_parts.append(textwrap.dedent(f"""
    ---
    CURRENT QUESTION:
    {user_question}

    CONTEXT:
    {context_block}

    Now generate a helpful answer for the CURRENT QUESTION based on the CONTEXT.
    """).strip())

    return "\n\n".join(prompt_parts)
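
# Example (sketch): building a few-shot prompt from earlier retrieval results.
#
#   prompt = build_prompt(
#       user_question="Which speaker is best for travel?",
#       context_block=build_context_block(products),  # products from retrieve_products()
#       mode="few-shot",
#   )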


# ===========================================================
# 4. MAIN ENTRYPOINT FOR YOUR GROUP: generate_answer()
# ===========================================================

def generate_answer(
    user_question: Optional[str] = None,
    image_path: Optional[str] = None,
    mode: str = "zero-shot",
    chat_history: Optional[List[Dict[str, str]]] = None,
    persist_dir: str = "chromadb_store",
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    llm_client: Optional["LLMClient"] = None,
) -> Dict:
    """
    High-level function your Streamlit UI can call.

    Args:
        llm_client: Optional pre-initialized LLM client (for performance optimization)

    Returns:
      {
        "answer": str,
        "products": [ {...}, ... ]  # retrieved products for display
      }
    """
    if not user_question and not image_path:
        raise ValueError("You must provide either user_question or image_path.")

    # 1. Retrieve products (text or image query)
    products = retrieve_products(
        query_text=user_question,
        image_path=image_path,
        persist_dir=persist_dir,
        top_k=5,
    )

    # 2. Build context text for the LLM
    context_block = build_context_block(products)

    # 3. Build prompt with desired mode
    # Detect if this is an image-based query
    is_image_query = image_path is not None

    prompt = build_prompt(
        user_question=user_question or "User uploaded an image and asked about the product.",
        context_block=context_block,
        mode=mode,
        chat_history=chat_history,
        is_image_query=is_image_query,
    )

    # 4. Run open-source LLM (reuse client if provided, otherwise create new)
    if llm_client is None:
        llm = LLMClient(model_name=model_name)
    else:
        llm = llm_client
    answer = llm.generate(prompt)

    return {
        "answer": answer,
        "products": products,
    }
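
# Typical UI call (sketch; reusing one client avoids reloading model weights per request):
#
#   shared_llm = LLMClient()  # load once at app startup
#   out = generate_answer(
#       user_question="Compare the two smartwatches you found.",
#       mode="few-shot",
#       llm_client=shared_llm,
#   )
#   print(out["answer"])      # or st.write(...) in the Streamlit UI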


# ===========================================================
# 5. Small CLI demo (optional)
# ===========================================================

if __name__ == "__main__":
    # Example: text-only question
    q = "What are the main features of the Samsung Galaxy phone you find?"
    result = generate_answer(user_question=q, mode="few-shot")
    print("\n=== ASSISTANT ANSWER ===\n")
    print(result["answer"])

    print("\n=== TOP PRODUCTS (for debugging) ===\n")
    for p in result["products"]:
        print(p)
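
    # Example: image-only query (commented out; "query.jpg" is a hypothetical file)
    # result_img = generate_answer(image_path="query.jpg", mode="zero-shot")
    # print(result_img["answer"])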