# llm_agent.py
# ============================================
# LLM layer for Amazon Multimodal RAG project
# - Reuses CLIP + Chroma from rag.py
# - Supports zero-shot / few-shot / multi-shot prompts
# - Exposes generate_answer() for UI team
# ============================================

import textwrap
import logging
from typing import List, Dict, Optional

from transformers import pipeline

# Import teammates' code
from rag import CLIPEmbedder, ChromaVectorStore, clean_text

logger = logging.getLogger(__name__)


# ===========================================================
# 1. LLM CLIENTS
# ===========================================================

# 1a. OpenAI GPT-4 Client
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    logger.warning("OpenAI package not installed. Install with: pip install openai")
    OPENAI_AVAILABLE = False


class OpenAILLMClient:
    """
    OpenAI GPT-4 client with the same interface as LLMClient.
    Drop-in replacement for the HuggingFace pipeline client.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        max_tokens: int = 512,
        temperature: float = 0.2,
    ):
        if not OPENAI_AVAILABLE:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")
        if not api_key:
            raise ValueError("OpenAI API key is required. Set the OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        logger.info(f"Initialized OpenAI client with model: {model}")

    def generate(self, prompt: str) -> str:
        """
        Generate text using the OpenAI API.
        Interface-compatible with LLMClient.generate().
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise


# 1b. HuggingFace Local Model Client
class LLMClient:
    """
    Thin wrapper around a HuggingFace text-generation pipeline.
    Swap model_name for any open-source instruct model you can run.
    Examples:
      - "meta-llama/Meta-Llama-3-8B-Instruct"
      - "mistralai/Mixtral-8x7B-Instruct-v0.1"
      - "mistralai/Mistral-7B-Instruct-v0.3"
    """

    def __init__(
        self,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens: int = 512,
        temperature: float = 0.2,
    ):
        self.generator = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto",
        )
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def generate(self, prompt: str) -> str:
        out = self.generator(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=self.generator.tokenizer.eos_token_id,
        )[0]["generated_text"]
        # Many instruct models echo the prompt; strip it out if needed
        return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()
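
# -----------------------------------------------------------
# Optional convenience helper (a minimal sketch, not part of the
# original UI contract): pick a client automatically. The helper name
# and the reliance on the OPENAI_API_KEY environment variable are
# assumptions made here for illustration; adjust to match how the
# project actually configures credentials.
# -----------------------------------------------------------
import os


def get_default_llm_client(
    hf_model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
):
    """Return an OpenAILLMClient when an API key is configured, else a local LLMClient."""
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if OPENAI_AVAILABLE and api_key:
        return OpenAILLMClient(api_key=api_key)
    return LLMClient(model_name=hf_model_name)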
""" if not query_text and not image_path: raise ValueError("Provide either query_text or image_path.") embedder = CLIPEmbedder() vectorstore = ChromaVectorStore(persist_dir=persist_dir) # True multimodal fusion: combine text + image when both are provided if query_text and image_path: # Both text and image provided: fuse embeddings (matches rag.py:229) text_emb = embedder.embed_text(query_text) img_emb = embedder.embed_image(image_path) emb = (text_emb + img_emb) / 2 # Simple averaging, consistent with index building elif query_text: # Text-only query emb = embedder.embed_text(query_text) elif image_path: # Image-only query emb = embedder.embed_image(image_path) else: raise ValueError("Provide either query_text or image_path.") results = vectorstore.query(emb, top_k=top_k) products = [] ids = results["ids"][0] metas = results["metadatas"][0] dists = results["distances"][0] for pid, meta, dist in zip(ids, metas, dists): products.append( { "id": pid, "name": meta.get("name", ""), "category": meta.get("category", ""), "image_path": meta.get("image_path", None), "distance": float(dist), } ) return products def build_context_block(products: List[Dict]) -> str: """ Turns retrieved products into a readable text block that we can feed to the LLM as 'CONTEXT'. """ lines = [] for i, p in enumerate(products, start=1): snippet = textwrap.dedent(f""" [Product {i}] id: {p.get("id")} name: {p.get("name")} category: {p.get("category")} image_path: {p.get("image_path")} similarity_score: {1 - p.get("distance", 0):.4f} """).strip() lines.append(snippet) return "\n\n".join(lines) # =========================================================== # 3. PROMPT TEMPLATES # (Zero-shot / Few-shot / Multi-shot) # =========================================================== def _few_shot_examples() -> str: """ Two short in-context examples using the same format. This satisfies the 'few-shot' requirement. """ return textwrap.dedent(""" ### Example 1 USER QUESTION: "What are the main features of this Bluetooth speaker?" CONTEXT: [Product 1] name: JBL Go 3 Portable Speaker category: Electronics image_path: images/jbl_go3.jpg ASSISTANT ANSWER: The JBL Go 3 is a small portable Bluetooth speaker designed for travel. It offers wireless Bluetooth audio, IP67 water and dust resistance, and up to about 5 hours of playback on a single charge. ### Example 2 USER QUESTION: "Can you compare the two smartwatches you found for me?" CONTEXT: [Product 1] name: Apple Watch Series 9 GPS category: Wearable Technology [Product 2] name: Samsung Galaxy Watch 6 category: Wearable Technology ASSISTANT ANSWER: Both watches are full-featured smartwatches for fitness and daily use. The Apple Watch Series 9 is tightly integrated with the Apple ecosystem and works best with iPhones. The Galaxy Watch 6 is built for Android phones and integrates with Samsung Health. Choose based on whether you mainly use iOS or Android. 
""").strip() def build_prompt( user_question: str, context_block: str, mode: str = "zero-shot", chat_history: Optional[List[Dict[str, str]]] = None, is_image_query: bool = False, ) -> str: """ mode: "zero-shot" | "few-shot" | "multi-shot" chat_history: list of {"role": "user"/"assistant", "content": "..."} is_image_query: True if user uploaded an image (changes prompt strategy) """ history_str = "" if chat_history: formatted_turns = [] for turn in chat_history: role = turn.get("role", "user").upper() content = turn.get("content", "") formatted_turns.append(f"{role}: {content}") history_str = "\n".join(formatted_turns) # Different instructions for image vs text queries if is_image_query: base_instructions = textwrap.dedent(""" You are a helpful e-commerce assistant for an Amazon-like store. IMPORTANT: The user uploaded an image, and our visual similarity search system (powered by CLIP) has retrieved the most visually similar products from our database. You are given: 1) The user's question about the uploaded image. 2) A CONTEXT block with retrieved products ranked by visual similarity. - similarity_score: Higher scores (closer to 1.0) mean the product looks more similar to the uploaded image. - Each product includes: id, name, category, image_path, similarity_score. RULES FOR IMAGE-BASED QUERIES: - The products in CONTEXT were selected because they visually resemble the uploaded image. - Trust the similarity_score: products with scores > 0.8 are highly similar to the uploaded image. - Describe the retrieved products based on their names, categories, and similarity scores. - If the top result has high similarity (>0.8), you can confidently say "This appears to be..." or "The uploaded image shows...". - If similarity scores are moderate (0.6-0.8), say "This looks similar to..." and list top matches. - Compare multiple products if their similarity scores are close. - You can infer product characteristics from the product name and category. - Be helpful and descriptive based on the retrieved product information. - Do NOT say you cannot see the image - the visual search has already been performed for you. """).strip() else: base_instructions = textwrap.dedent(""" You are a helpful e-commerce assistant for an Amazon-like store. You are given: 1) The user's question. 2) A CONTEXT block with retrieved products (id, name, category, image_path, similarity_score). RULES: - Use ONLY the information in CONTEXT plus general consumer knowledge. - Prefer products with higher similarity_score. - Be concise and factual. - If the context does not contain enough information, say that you are not sure. - If multiple products are relevant, compare them clearly. - Do NOT invent product names or specs that are not implied by the context. """).strip() prompt_parts = [base_instructions] # Add chat history (for multi-turn conversations) if history_str: prompt_parts.append("\n---\nCHAT HISTORY (previous turns):\n" + history_str) # Add few-shot or multi-shot examples if mode == "few-shot": prompt_parts.append("\n---\nFEW-SHOT EXAMPLES:\n" + _few_shot_examples()) elif mode == "multi-shot": # For simplicity, reuse the same examples but label as "multi-shot" # (You could easily extend with 3+ examples here.) prompt_parts.append("\n---\nMULTI-SHOT EXAMPLES:\n" + _few_shot_examples()) # Finally, add current question + context prompt_parts.append(textwrap.dedent(f""" --- CURRENT QUESTION: {user_question} CONTEXT: {context_block} Now generate a helpful answer for the CURRENT QUESTION based on the CONTEXT. 
""").strip()) return "\n\n".join(prompt_parts) # =========================================================== # 4. MAIN ENTRYPOINT FOR YOUR GROUP: generate_answer() # =========================================================== def generate_answer( user_question: Optional[str] = None, image_path: Optional[str] = None, mode: str = "zero-shot", chat_history: Optional[List[Dict[str, str]]] = None, persist_dir: str = "chromadb_store", model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct", llm_client: Optional["LLMClient"] = None, ) -> Dict: """ High-level function your Streamlit UI can call. Args: llm_client: Optional pre-initialized LLM client (for performance optimization) Returns: { "answer": str, "products": [ {...}, ... ] # retrieved products for display } """ if not user_question and not image_path: raise ValueError("You must provide either user_question or image_path.") # 1. Retrieve products (text or image query) products = retrieve_products( query_text=user_question, image_path=image_path, persist_dir=persist_dir, top_k=5, ) # 2. Build context text for the LLM context_block = build_context_block(products) # 3. Build prompt with desired mode # Detect if this is an image-based query is_image_query = image_path is not None prompt = build_prompt( user_question=user_question or "User uploaded an image and asked about the product.", context_block=context_block, mode=mode, chat_history=chat_history, is_image_query=is_image_query, ) # 4. Run open-source LLM (reuse client if provided, otherwise create new) if llm_client is None: llm = LLMClient(model_name=model_name) else: llm = llm_client answer = llm.generate(prompt) return { "answer": answer, "products": products, } # =========================================================== # 5. Small CLI demo (optional) # =========================================================== if __name__ == "__main__": # Example: text-only question q = "What are the main features of the Samsung Galaxy phone you find?" result = generate_answer(user_question=q, mode="few-shot") print("\n=== ASSISTANT ANSWER ===\n") print(result["answer"]) print("\n=== TOP PRODUCTS (for debugging) ===\n") for p in result["products"]: print(p)