|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import textwrap |
|
|
import logging |
|
|
from typing import List, Dict, Optional |
|
|
|
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
from rag import CLIPEmbedder, ChromaVectorStore, clean_text |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional dependency: the OpenAI SDK is only required when OpenAILLMClient
# is actually used; the HuggingFace-backed LLMClient works without it.
try:
    from openai import OpenAI

    OPENAI_AVAILABLE = True
except ImportError:
    logger.warning("OpenAI package not installed. Install with: pip install openai")
    OPENAI_AVAILABLE = False
|
|
|
|
|
|
|
|
class OpenAILLMClient:
    """
    OpenAI chat-completions client with the same interface as LLMClient.

    Compatible drop-in replacement for the HuggingFace pipeline wrapper:
    exposes generate(prompt) -> str.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        max_tokens: int = 512,
        temperature: float = 0.2,
    ):
        """
        Args:
            api_key: OpenAI API key (required).
            model: Chat model name, e.g. "gpt-4o".
            max_tokens: Maximum tokens to generate per completion.
            temperature: Sampling temperature (lower = more deterministic).

        Raises:
            ImportError: If the openai package is not installed.
            ValueError: If api_key is empty.
        """
        if not OPENAI_AVAILABLE:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")

        if not api_key:
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        logger.info(f"Initialized OpenAI client with model: {model}")

    def generate(self, prompt: str) -> str:
        """
        Generate text using the OpenAI API.

        Interface compatible with LLMClient.generate().

        Returns:
            The model's reply text, stripped; "" if the API returned no content.

        Raises:
            Exception: Re-raises any API error after logging it.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            # message.content may be None (e.g. refusals or tool-call replies);
            # guard before .strip() to avoid an AttributeError.
            content = response.choices[0].message.content
            return content.strip() if content else ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMClient:
    """
    Thin wrapper around a HuggingFace text-generation pipeline.

    Swap model_name for any open-source instruct model you can run, e.g.:
        - "meta-llama/Meta-Llama-3-8B-Instruct"
        - "mistralai/Mixtral-8x7B-Instruct-v0.1"
        - "mistralai/Mistral-7B-Instruct-v0.3"
    """

    def __init__(
        self,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens: int = 512,
        temperature: float = 0.2,
    ):
        # device_map="auto" lets the HF runtime place the model on whatever
        # hardware is available (GPU if present, otherwise CPU).
        self.generator = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto",
        )
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def generate(self, prompt: str) -> str:
        """Run the pipeline on `prompt` and return only the newly generated text."""
        outputs = self.generator(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=self.generator.tokenizer.eos_token_id,
        )
        text = outputs[0]["generated_text"]
        # The pipeline echoes the prompt at the start of generated_text; drop it.
        if text.startswith(prompt):
            return text[len(prompt):].strip()
        return text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_products(
    query_text: Optional[str] = None,
    image_path: Optional[str] = None,
    persist_dir: str = "chromadb_store",
    top_k: int = 5,
) -> List[Dict]:
    """
    Retrieve the top_k most similar products from the vector store.

    Uses the same CLIP + Chroma setup as rag.py, but returns a clean
    Python list of product dicts.

    Args:
        query_text: Optional free-text query.
        image_path: Optional path to a query image.
        persist_dir: Directory of the persisted Chroma store.
        top_k: Number of results to return.

    Returns:
        List of dicts with keys: id, name, category, image_path, distance.

    Raises:
        ValueError: If neither query_text nor image_path is provided.
    """
    if not query_text and not image_path:
        raise ValueError("Provide either query_text or image_path.")

    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir=persist_dir)

    if query_text and image_path:
        # Multimodal query: average the text and image embeddings.
        text_emb = embedder.embed_text(query_text)
        img_emb = embedder.embed_image(image_path)
        emb = (text_emb + img_emb) / 2
    elif query_text:
        emb = embedder.embed_text(query_text)
    else:
        # image_path only (the no-input case was already rejected above,
        # so the former trailing `else: raise` branch was unreachable).
        emb = embedder.embed_image(image_path)

    results = vectorstore.query(emb, top_k=top_k)

    # Chroma returns parallel lists nested one level deep (one per query).
    ids = results["ids"][0]
    metas = results["metadatas"][0]
    dists = results["distances"][0]

    return [
        {
            "id": pid,
            "name": meta.get("name", ""),
            "category": meta.get("category", ""),
            "image_path": meta.get("image_path", None),
            "distance": float(dist),
        }
        for pid, meta, dist in zip(ids, metas, dists)
    ]
|
|
|
|
|
|
|
|
def build_context_block(products: List[Dict]) -> str:
    """
    Turn retrieved products into a readable text block that can be fed
    to the LLM as 'CONTEXT'.

    Each product becomes a small labeled record; records are separated
    by a blank line. similarity_score is derived as 1 - distance.
    """
    records = []
    for idx, prod in enumerate(products, start=1):
        score = 1 - prod.get("distance", 0)
        record = "\n".join(
            [
                f"[Product {idx}]",
                f"id: {prod.get('id')}",
                f"name: {prod.get('name')}",
                f"category: {prod.get('category')}",
                f"image_path: {prod.get('image_path')}",
                f"similarity_score: {score:.4f}",
            ]
        )
        records.append(record)
    return "\n\n".join(records)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _few_shot_examples() -> str:
    """
    Return two short in-context examples using the same USER QUESTION /
    CONTEXT / ASSISTANT ANSWER format as the main prompt.

    This satisfies the 'few-shot' requirement. The text is a fixed
    constant; edit it here to change the demonstration answers.
    """
    # NOTE: the literal has no interpolation, so dedent reliably strips the
    # common indentation; strip() drops the leading/trailing blank lines.
    return textwrap.dedent("""
        ### Example 1
        USER QUESTION:
        "What are the main features of this Bluetooth speaker?"

        CONTEXT:
        [Product 1]
        name: JBL Go 3 Portable Speaker
        category: Electronics
        image_path: images/jbl_go3.jpg

        ASSISTANT ANSWER:
        The JBL Go 3 is a small portable Bluetooth speaker designed for travel.
        It offers wireless Bluetooth audio, IP67 water and dust resistance,
        and up to about 5 hours of playback on a single charge.

        ### Example 2
        USER QUESTION:
        "Can you compare the two smartwatches you found for me?"

        CONTEXT:
        [Product 1]
        name: Apple Watch Series 9 GPS
        category: Wearable Technology

        [Product 2]
        name: Samsung Galaxy Watch 6
        category: Wearable Technology

        ASSISTANT ANSWER:
        Both watches are full-featured smartwatches for fitness and daily use.
        The Apple Watch Series 9 is tightly integrated with the Apple ecosystem
        and works best with iPhones. The Galaxy Watch 6 is built for Android
        phones and integrates with Samsung Health. Choose based on whether
        you mainly use iOS or Android.
    """).strip()
|
|
|
|
|
|
|
|
def build_prompt(
    user_question: str,
    context_block: str,
    mode: str = "zero-shot",
    chat_history: Optional[List[Dict[str, str]]] = None,
    is_image_query: bool = False,
) -> str:
    """
    Assemble the full LLM prompt: instructions, optional chat history,
    optional examples, then the current question and retrieved context.

    Args:
        user_question: The current user question.
        context_block: Retrieved-products text (from build_context_block).
        mode: "zero-shot" | "few-shot" | "multi-shot".
        chat_history: list of {"role": "user"/"assistant", "content": "..."}.
        is_image_query: True if user uploaded an image (changes prompt strategy).

    Returns:
        The prompt string, sections joined by blank lines.
    """
    history_str = ""
    if chat_history:
        history_str = "\n".join(
            f"{turn.get('role', 'user').upper()}: {turn.get('content', '')}"
            for turn in chat_history
        )

    if is_image_query:
        base_instructions = textwrap.dedent("""
            You are a helpful e-commerce assistant for an Amazon-like store.

            IMPORTANT: The user uploaded an image, and our visual similarity search system (powered by CLIP)
            has retrieved the most visually similar products from our database.

            You are given:
            1) The user's question about the uploaded image.
            2) A CONTEXT block with retrieved products ranked by visual similarity.
            - similarity_score: Higher scores (closer to 1.0) mean the product looks more similar to the uploaded image.
            - Each product includes: id, name, category, image_path, similarity_score.

            RULES FOR IMAGE-BASED QUERIES:
            - The products in CONTEXT were selected because they visually resemble the uploaded image.
            - Trust the similarity_score: products with scores > 0.8 are highly similar to the uploaded image.
            - Describe the retrieved products based on their names, categories, and similarity scores.
            - If the top result has high similarity (>0.8), you can confidently say "This appears to be..." or "The uploaded image shows...".
            - If similarity scores are moderate (0.6-0.8), say "This looks similar to..." and list top matches.
            - Compare multiple products if their similarity scores are close.
            - You can infer product characteristics from the product name and category.
            - Be helpful and descriptive based on the retrieved product information.
            - Do NOT say you cannot see the image - the visual search has already been performed for you.
        """).strip()
    else:
        base_instructions = textwrap.dedent("""
            You are a helpful e-commerce assistant for an Amazon-like store.
            You are given:
            1) The user's question.
            2) A CONTEXT block with retrieved products (id, name, category, image_path, similarity_score).

            RULES:
            - Use ONLY the information in CONTEXT plus general consumer knowledge.
            - Prefer products with higher similarity_score.
            - Be concise and factual.
            - If the context does not contain enough information, say that you are not sure.
            - If multiple products are relevant, compare them clearly.
            - Do NOT invent product names or specs that are not implied by the context.
        """).strip()

    prompt_parts = [base_instructions]

    if history_str:
        prompt_parts.append("\n---\nCHAT HISTORY (previous turns):\n" + history_str)

    if mode == "few-shot":
        prompt_parts.append("\n---\nFEW-SHOT EXAMPLES:\n" + _few_shot_examples())
    elif mode == "multi-shot":
        # Multi-shot currently reuses the same example set as few-shot,
        # only the section label differs.
        prompt_parts.append("\n---\nMULTI-SHOT EXAMPLES:\n" + _few_shot_examples())

    # BUG FIX: the original applied textwrap.dedent() to an f-string AFTER
    # interpolation. A multi-line context_block has unindented lines, which
    # defeats dedent's common-prefix detection and leaked the template's own
    # indentation into the prompt. Assemble the tail section explicitly so
    # its layout never depends on the interpolated values.
    prompt_parts.append(
        "---\n"
        "CURRENT QUESTION:\n"
        f"{user_question}\n"
        "\n"
        "CONTEXT:\n"
        f"{context_block}\n"
        "\n"
        "Now generate a helpful answer for the CURRENT QUESTION based on the CONTEXT."
    )

    return "\n\n".join(prompt_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_answer(
    user_question: Optional[str] = None,
    image_path: Optional[str] = None,
    mode: str = "zero-shot",
    chat_history: Optional[List[Dict[str, str]]] = None,
    persist_dir: str = "chromadb_store",
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    llm_client: Optional["LLMClient"] = None,
    top_k: int = 5,
) -> Dict:
    """
    High-level function your Streamlit UI can call.

    Args:
        user_question: Optional text question.
        image_path: Optional path to an uploaded query image.
        mode: Prompting strategy: "zero-shot" | "few-shot" | "multi-shot".
        chat_history: Previous turns as {"role": ..., "content": ...} dicts.
        persist_dir: Directory of the persisted Chroma store.
        model_name: HuggingFace model used only when llm_client is None.
        llm_client: Optional pre-initialized LLM client (for performance
            optimization; avoids reloading the model on every call).
        top_k: Number of products to retrieve for context (was hard-coded to 5).

    Returns:
        {
            "answer": str,
            "products": [ {...}, ... ]  # retrieved products for display
        }

    Raises:
        ValueError: If neither user_question nor image_path is provided.
    """
    if not user_question and not image_path:
        raise ValueError("You must provide either user_question or image_path.")

    products = retrieve_products(
        query_text=user_question,
        image_path=image_path,
        persist_dir=persist_dir,
        top_k=top_k,
    )

    context_block = build_context_block(products)

    # An uploaded image switches the prompt to the image-specific strategy.
    is_image_query = image_path is not None

    prompt = build_prompt(
        user_question=user_question or "User uploaded an image and asked about the product.",
        context_block=context_block,
        mode=mode,
        chat_history=chat_history,
        is_image_query=is_image_query,
    )

    # Instantiating LLMClient loads the model — expensive; prefer passing
    # a cached llm_client from the caller.
    llm = llm_client if llm_client is not None else LLMClient(model_name=model_name)
    answer = llm.generate(prompt)

    return {
        "answer": answer,
        "products": products,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Quick manual smoke test: run a few-shot text query end-to-end and
    # dump the answer plus the retrieved products.
    demo_question = "What are the main features of the Samsung Galaxy phone you find?"
    response = generate_answer(user_question=demo_question, mode="few-shot")

    print("\n=== ASSISTANT ANSWER ===\n")
    print(response["answer"])

    print("\n=== TOP PRODUCTS (for debugging) ===\n")
    for product in response["products"]:
        print(product)
|
|
|