# llm_agent.py
# ============================================
# LLM layer for Amazon Multimodal RAG project
# - Reuses CLIP + Chroma from rag.py
# - Supports zero-shot / few-shot / multi-shot prompts
# - Exposes generate_answer() for the UI team
# ============================================
import textwrap
import logging
from typing import List, Dict, Optional
from transformers import pipeline
# Import teammates' code
from rag import CLIPEmbedder, ChromaVectorStore, clean_text
logger = logging.getLogger(__name__)
# ===========================================================
# 1. LLM CLIENTS
# ===========================================================
# 1a. OpenAI GPT-4 Client
try:
from openai import OpenAI
OPENAI_AVAILABLE = True
except ImportError:
logger.warning("OpenAI package not installed. Install with: pip install openai")
OPENAI_AVAILABLE = False
class OpenAILLMClient:
"""
OpenAI GPT-4 client with the same interface as LLMClient.
Drop-in replacement for the HuggingFace pipeline client.
"""
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
max_tokens: int = 512,
temperature: float = 0.2,
):
if not OPENAI_AVAILABLE:
raise ImportError("OpenAI package not installed. Install with: pip install openai")
if not api_key:
raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")
self.client = OpenAI(api_key=api_key)
self.model = model
self.max_tokens = max_tokens
self.temperature = temperature
logger.info(f"Initialized OpenAI client with model: {model}")
def generate(self, prompt: str) -> str:
"""
Generate text using OpenAI API.
Interface compatible with LLMClient.generate()
"""
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=self.max_tokens,
temperature=self.temperature
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
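# Usage sketch (assumes the openai package is installed and a valid key is
# available; the environment-variable name below is only an illustration):
#     import os
#     client = OpenAILLMClient(api_key=os.environ["OPENAI_API_KEY"])
#     print(client.generate("Summarize this product in one sentence."))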
# 1b. HuggingFace Local Model Client
class LLMClient:
"""
Thin wrapper around a HuggingFace text-generation pipeline.
Swap model_name for any open-source instruct model you can run.
Examples:
- "meta-llama/Meta-Llama-3-8B-Instruct"
- "mistralai/Mixtral-8x7B-Instruct-v0.1"
- "mistralai/Mistral-7B-Instruct-v0.3
"""
def __init__(
self,
model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
max_new_tokens: int = 512,
temperature: float = 0.2,
):
self.generator = pipeline(
"text-generation",
model=model_name,
device_map="auto",
)
self.max_new_tokens = max_new_tokens
self.temperature = temperature
def generate(self, prompt: str) -> str:
out = self.generator(
prompt,
max_new_tokens=self.max_new_tokens,
do_sample=True,
temperature=self.temperature,
pad_token_id=self.generator.tokenizer.eos_token_id,
)[0]["generated_text"]
# Many instruct models echo the prompt; strip it out if needed
return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()
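# Usage sketch (assumes the model weights can be downloaded and fit on the
# available hardware; any other HuggingFace instruct model works the same way):
#     llm = LLMClient(model_name="mistralai/Mistral-7B-Instruct-v0.3")
#     print(llm.generate("Write one sentence about portable speakers."))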
# ===========================================================
# 2. RETRIEVAL → CONTEXT BUILDING
# ===========================================================
def retrieve_products(
query_text: Optional[str] = None,
image_path: Optional[str] = None,
persist_dir: str = "chromadb_store",
top_k: int = 5,
) -> List[Dict]:
"""
Uses the same CLIP + Chroma setup as rag.py,
but returns a clean Python list of product dicts.
"""
if not query_text and not image_path:
raise ValueError("Provide either query_text or image_path.")
embedder = CLIPEmbedder()
vectorstore = ChromaVectorStore(persist_dir=persist_dir)
# True multimodal fusion: combine text + image when both are provided
if query_text and image_path:
# Both text and image provided: fuse embeddings (matches rag.py:229)
text_emb = embedder.embed_text(query_text)
img_emb = embedder.embed_image(image_path)
emb = (text_emb + img_emb) / 2 # Simple averaging, consistent with index building
elif query_text:
# Text-only query
emb = embedder.embed_text(query_text)
elif image_path:
# Image-only query
emb = embedder.embed_image(image_path)
else:
raise ValueError("Provide either query_text or image_path.")
results = vectorstore.query(emb, top_k=top_k)
products = []
ids = results["ids"][0]
metas = results["metadatas"][0]
dists = results["distances"][0]
for pid, meta, dist in zip(ids, metas, dists):
products.append(
{
"id": pid,
"name": meta.get("name", ""),
"category": meta.get("category", ""),
"image_path": meta.get("image_path", None),
"distance": float(dist),
}
)
return products
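# Usage sketch (assumes a populated Chroma index in "chromadb_store";
# the query text below is illustrative):
#     hits = retrieve_products(query_text="waterproof bluetooth speaker", top_k=3)
#     for h in hits:
#         print(h["name"], h["category"], h["distance"])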
def build_context_block(products: List[Dict]) -> str:
"""
Turns retrieved products into a readable text block
that we can feed to the LLM as 'CONTEXT'.
"""
lines = []
for i, p in enumerate(products, start=1):
snippet = textwrap.dedent(f"""
[Product {i}]
id: {p.get("id")}
name: {p.get("name")}
category: {p.get("category")}
image_path: {p.get("image_path")}
similarity_score: {1 - p.get("distance", 0):.4f}
""").strip()
lines.append(snippet)
return "\n\n".join(lines)
# ===========================================================
# 3. PROMPT TEMPLATES
# (Zero-shot / Few-shot / Multi-shot)
# ===========================================================
def _few_shot_examples() -> str:
"""
Two short in-context examples using the same format.
This satisfies the 'few-shot' requirement.
"""
return textwrap.dedent("""
### Example 1
USER QUESTION:
"What are the main features of this Bluetooth speaker?"
CONTEXT:
[Product 1]
name: JBL Go 3 Portable Speaker
category: Electronics
image_path: images/jbl_go3.jpg
ASSISTANT ANSWER:
The JBL Go 3 is a small portable Bluetooth speaker designed for travel.
It offers wireless Bluetooth audio, IP67 water and dust resistance,
and up to about 5 hours of playback on a single charge.
### Example 2
USER QUESTION:
"Can you compare the two smartwatches you found for me?"
CONTEXT:
[Product 1]
name: Apple Watch Series 9 GPS
category: Wearable Technology
[Product 2]
name: Samsung Galaxy Watch 6
category: Wearable Technology
ASSISTANT ANSWER:
Both watches are full-featured smartwatches for fitness and daily use.
The Apple Watch Series 9 is tightly integrated with the Apple ecosystem
and works best with iPhones. The Galaxy Watch 6 is built for Android
phones and integrates with Samsung Health. Choose based on whether
you mainly use iOS or Android.
""").strip()
def build_prompt(
user_question: str,
context_block: str,
mode: str = "zero-shot",
chat_history: Optional[List[Dict[str, str]]] = None,
is_image_query: bool = False,
) -> str:
"""
mode: "zero-shot" | "few-shot" | "multi-shot"
chat_history: list of {"role": "user"/"assistant", "content": "..."}
is_image_query: True if user uploaded an image (changes prompt strategy)
"""
history_str = ""
if chat_history:
formatted_turns = []
for turn in chat_history:
role = turn.get("role", "user").upper()
content = turn.get("content", "")
formatted_turns.append(f"{role}: {content}")
history_str = "\n".join(formatted_turns)
# Different instructions for image vs text queries
if is_image_query:
base_instructions = textwrap.dedent("""
You are a helpful e-commerce assistant for an Amazon-like store.
IMPORTANT: The user uploaded an image, and our visual similarity search system (powered by CLIP)
has retrieved the most visually similar products from our database.
You are given:
1) The user's question about the uploaded image.
2) A CONTEXT block with retrieved products ranked by visual similarity.
- similarity_score: Higher scores (closer to 1.0) mean the product looks more similar to the uploaded image.
- Each product includes: id, name, category, image_path, similarity_score.
RULES FOR IMAGE-BASED QUERIES:
- The products in CONTEXT were selected because they visually resemble the uploaded image.
- Trust the similarity_score: products with scores > 0.8 are highly similar to the uploaded image.
- Describe the retrieved products based on their names, categories, and similarity scores.
- If the top result has high similarity (>0.8), you can confidently say "This appears to be..." or "The uploaded image shows...".
- If similarity scores are moderate (0.6-0.8), say "This looks similar to..." and list top matches.
- Compare multiple products if their similarity scores are close.
- You can infer product characteristics from the product name and category.
- Be helpful and descriptive based on the retrieved product information.
- Do NOT say you cannot see the image - the visual search has already been performed for you.
""").strip()
else:
base_instructions = textwrap.dedent("""
You are a helpful e-commerce assistant for an Amazon-like store.
You are given:
1) The user's question.
2) A CONTEXT block with retrieved products (id, name, category, image_path, similarity_score).
RULES:
- Use ONLY the information in CONTEXT plus general consumer knowledge.
- Prefer products with higher similarity_score.
- Be concise and factual.
- If the context does not contain enough information, say that you are not sure.
- If multiple products are relevant, compare them clearly.
- Do NOT invent product names or specs that are not implied by the context.
""").strip()
prompt_parts = [base_instructions]
# Add chat history (for multi-turn conversations)
if history_str:
prompt_parts.append("\n---\nCHAT HISTORY (previous turns):\n" + history_str)
# Add few-shot or multi-shot examples
if mode == "few-shot":
prompt_parts.append("\n---\nFEW-SHOT EXAMPLES:\n" + _few_shot_examples())
elif mode == "multi-shot":
# For simplicity, reuse the same few-shot examples but label them as "multi-shot".
# (This could easily be extended with three or more examples.)
prompt_parts.append("\n---\nMULTI-SHOT EXAMPLES:\n" + _few_shot_examples())
# Finally, add current question + context
prompt_parts.append(textwrap.dedent(f"""
---
CURRENT QUESTION:
{user_question}
CONTEXT:
{context_block}
Now generate a helpful answer for the CURRENT QUESTION based on the CONTEXT.
""").strip())
return "\n\n".join(prompt_parts)
# ===========================================================
# 4. MAIN ENTRYPOINT FOR YOUR GROUP: generate_answer()
# ===========================================================
def generate_answer(
user_question: Optional[str] = None,
image_path: Optional[str] = None,
mode: str = "zero-shot",
chat_history: Optional[List[Dict[str, str]]] = None,
persist_dir: str = "chromadb_store",
model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
llm_client: Optional["LLMClient"] = None,
) -> Dict:
"""
High-level function your Streamlit UI can call.
Args:
    user_question / image_path: at least one of the two must be provided.
    mode: "zero-shot" | "few-shot" | "multi-shot".
    llm_client: optional pre-initialized LLM client (reusing it avoids reloading the model on every call).
Returns:
{
"answer": str,
"products": [ {...}, ... ] # retrieved products for display
}
"""
if not user_question and not image_path:
raise ValueError("You must provide either user_question or image_path.")
# 1. Retrieve products (text or image query)
products = retrieve_products(
query_text=user_question,
image_path=image_path,
persist_dir=persist_dir,
top_k=5,
)
# 2. Build context text for the LLM
context_block = build_context_block(products)
# 3. Build prompt with desired mode
# Detect if this is an image-based query
is_image_query = image_path is not None
prompt = build_prompt(
user_question=user_question or "User uploaded an image and asked about the product.",
context_block=context_block,
mode=mode,
chat_history=chat_history,
is_image_query=is_image_query,
)
# 4. Run open-source LLM (reuse client if provided, otherwise create new)
if llm_client is None:
llm = LLMClient(model_name=model_name)
else:
llm = llm_client
answer = llm.generate(prompt)
return {
"answer": answer,
"products": products,
}
# ===========================================================
# 5. Small CLI demo (optional)
# ===========================================================
if __name__ == "__main__":
# Example: text-only question
q = "What are the main features of the Samsung Galaxy phone you find?"
result = generate_answer(user_question=q, mode="few-shot")
print("\n=== ASSISTANT ANSWER ===\n")
print(result["answer"])
print("\n=== TOP PRODUCTS (for debugging) ===\n")
for p in result["products"]:
print(p)
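# Example: image-based question (sketch; the image path below is hypothetical
# and must point to a real file for CLIP retrieval to work):
#     result = generate_answer(
#         user_question="What product is shown in this photo?",
#         image_path="images/example_query.jpg",
#         mode="zero-shot",
#     )
#     print(result["answer"])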