# -*- coding: utf-8 -*-
"""
FINAL RAG SYSTEM FOR AMAZON MULTIMODAL DATASET (LOCAL CHROMA DB)
-----------------------------------------------------------------
Features:
- Clean product text before embedding
- CLIP text + image embedding (safe 77-token truncation)
- New Chroma PersistentClient (2025 API)
- CSV loader for Amazon dataset
- Image downloader
- Build vector DB for products
- Query using text or image
"""

import argparse
import csv
import itertools
import logging
import os
import re
import sys
from urllib.parse import unquote

import chromadb
import clip
import numpy as np
import requests
import torch
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ===============================================================
# TEXT CLEANING
# ===============================================================

# Boilerplate fragments stripped from product text. Compiled once at import
# time because clean_text() runs for every row and every query.
_NOISE_PATTERNS = tuple(
    re.compile(pattern, flags=re.IGNORECASE)
    for pattern in (
        r"Make sure this fits.*?model number\.",
        r"Technical details:.*",
        r"Specifications:.*",
        r"ProductDimensions:.*?(?=\|)",
        r"ShippingWeight:.*?(?=\|)",
        r"ASIN:.*?(?=\|)",
        r"Item model number:.*?(?=\|)",
        r"Go to your orders.*",
        r"Learn More.*",
    )
)


def clean_text(text: str, max_chars: int = 400) -> str:
    """Remove Amazon boilerplate from *text* and cap its length.

    Args:
        text: Raw product text. Any non-string value yields ``""``.
        max_chars: Maximum number of characters kept after cleaning.

    Returns:
        The text with noise patterns removed, ``|`` separators replaced by
        spaces, whitespace runs collapsed, truncated to ``max_chars``.
    """
    if not isinstance(text, str):
        return ""

    for pattern in _NOISE_PATTERNS:
        text = pattern.sub("", text)

    text = text.replace("|", " ")
    text = re.sub(r"\s+", " ", text).strip()
    return text[:max_chars]


# ===============================================================
# CLIP EMBEDDER
# ===============================================================
class CLIPEmbedder:
    """Multimodal embedder using OpenAI CLIP with safe truncation."""

    # Word budget applied before tokenizing; keeps typical inputs well inside
    # CLIP's 77-token context so the tokenizer rarely has to truncate.
    MAX_WORDS = 50

    def __init__(self, model_name="ViT-B/32"):
        """Load the CLIP model and its image preprocessing pipeline.

        Args:
            model_name: Name passed to ``clip.load``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[CLIP] Loading model on {self.device} ...")
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        logger.info(f"[CLIP] Model {model_name} loaded successfully")

    def _truncate_tokens(self, text: str):
        """Tokenize *text* into a ``(1, context_length)`` tensor on the device.

        ``truncate=True`` lets the tokenizer cut over-long input down to the
        model's context length. Without it ``clip.tokenize`` raises on long
        text, so slicing the result afterwards would never be reached.
        """
        return clip.tokenize([text], truncate=True).to(self.device)

    def embed_text(self, text: str) -> np.ndarray:
        """Return the L2-normalized CLIP text embedding of *text*.

        The text is cleaned with ``clean_text`` and limited to the first
        ``MAX_WORDS`` words before tokenization.

        Returns:
            A 1-D ``float32`` numpy array.
        """
        # 1. Clean text
        text = clean_text(text)

        # 2. Hard word cap before tokenizing
        text = " ".join(text.split()[: self.MAX_WORDS])

        # 3. Tokenize (the tokenizer truncates if still too long)
        tokens = self._truncate_tokens(text)

        # 4. Encode
        with torch.no_grad():
            emb = self.model.encode_text(tokens)[0]
            emb = emb / emb.norm()

        return emb.cpu().numpy().astype("float32")

    def embed_image(self, path: str) -> np.ndarray:
        """Return the L2-normalized CLIP image embedding of the file at *path*.

        Returns:
            A 1-D ``float32`` numpy array.

        Raises:
            OSError: If the file cannot be opened or decoded as an image.
        """
        # Preprocess inside the context manager: PIL loads pixel data lazily,
        # and the handle must be closed afterwards when indexing many files.
        with Image.open(path) as img:
            image = self.preprocess(img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            vec = self.model.encode_image(image)[0]
            vec = vec / vec.norm()

        return vec.cpu().numpy().astype("float32")


# ===============================================================
# LOCAL CHROMA VECTORSTORE (NEW API)
# ===============================================================
class ChromaVectorStore:
    """Uses new Chroma PersistentClient."""

    def __init__(self, persist_dir="chromadb_store"):
        """Open (or create) the ``amazon_products`` collection in *persist_dir*."""
        logger.info(f"[Chroma] Initializing DB at: {persist_dir}")

        self.client = chromadb.PersistentClient(path=persist_dir)

        # "hnsw:space": "cosine" selects cosine distance for this collection.
        self.collection = self.client.get_or_create_collection(
            name="amazon_products",
            metadata={"hnsw:space": "cosine"},
        )

    @staticmethod
    def _as_vector(embedding):
        """Convert an embedding (numpy array or sequence) to a list of floats."""
        return np.asarray(embedding, dtype=np.float32).tolist()

    def add_item(self, item_id: str, embedding, metadata: dict):
        """Store one embedding with its metadata under *item_id*."""
        self.collection.add(
            ids=[item_id],
            embeddings=[self._as_vector(embedding)],
            metadatas=[metadata],
        )

    def query(self, embedding, top_k=5):
        """Return the raw Chroma query result for the *top_k* nearest items."""
        return self.collection.query(
            query_embeddings=[self._as_vector(embedding)],
            n_results=top_k,
        )


# ===============================================================
# DATASET LOADING / IMAGE DOWNLOADING
# ===============================================================
def download_first_image(urls: str, save_dir="images"):
    """Download the first URL of a ``|``-separated list into *save_dir*.

    Only the first entry is attempted; later URLs are not used as fallbacks.

    Args:
        urls: ``|``-separated image URLs.
        save_dir: Directory the image is written to (created if missing).

    Returns:
        The path of the saved file, or ``None`` if the input is empty, the
        first entry is not an http(s) URL, or the download/save fails.
    """
    if not urls or not isinstance(urls, str):
        return None

    os.makedirs(save_dir, exist_ok=True)

    first_url = urls.split("|")[0].strip()
    if not first_url.lower().startswith("http"):
        return None

    # Decode URL-encoded characters in filename to avoid mismatch with
    # FastAPI StaticFiles.
    file_name = unquote(os.path.basename(first_url)[:50])
    # Decoding can introduce path separators (e.g. "%2F"); neutralize them so
    # the file can never be written outside save_dir.
    file_name = file_name.replace("/", "_").replace("\\", "_")
    img_name = os.path.join(save_dir, file_name + ".jpg")

    try:
        response = requests.get(first_url, timeout=5)
        if response.status_code == 200:
            with open(img_name, "wb") as f:
                f.write(response.content)
            return img_name
        logger.debug(
            f"Failed to download image (status {response.status_code}): {first_url}"
        )
    except requests.RequestException as e:
        logger.debug(f"Image download error for {first_url}: {e}")
    except Exception as e:
        # Deliberately broad: one bad URL or unwritable filename must not
        # abort a long indexing run. The failure is logged, never swallowed.
        logger.warning(f"Unexpected error downloading image {first_url}: {e}")

    return None


# ===============================================================
# BUILD INDEX
# ===============================================================
def build_index(csv_path, persist_dir, max_items=None):
    """Embed products from a CSV file and store them in the Chroma collection.

    Each row needs a ``uniq_id`` and a downloadable image. The stored vector
    is the mean of the CLIP text and image embeddings.

    Args:
        csv_path: Path to the dataset CSV (columns read: ``uniq_id``,
            ``product_name``, ``product_text``, ``main_category``, ``image``).
        persist_dir: Chroma persistence directory.
        max_items: Maximum number of CSV rows to read; ``None`` reads all.

    Returns:
        The ``ChromaVectorStore`` the items were added to.
    """
    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    logger.info(f"📄 Loading dataset: {csv_path}")

    # Statistics tracking
    stats = {
        "total_processed": 0,
        "skipped_no_id": 0,
        "text_embed_failures": 0,
        "image_download_failures": 0,
        "image_embed_failures": 0,
        "skipped_no_image": 0,
    }

    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)

        for i, row in enumerate(reader):
            # "is not None" so that max_items=0 means "read nothing" instead
            # of being mistaken for "no limit".
            if max_items is not None and i >= max_items:
                break

            # DictReader yields None for columns missing from short rows;
            # normalize so the literal text "None" is never embedded.
            pid = row.get("uniq_id")
            name = row.get("product_name") or ""
            desc = row.get("product_text") or ""
            cat = row.get("main_category") or ""
            img_urls = row.get("image") or ""

            # Chroma requires a string id; check before any embedding work.
            if not pid:
                logger.warning(f"Skipping row {i} - missing uniq_id")
                stats["skipped_no_id"] += 1
                continue

            full_text = f"{name} | {cat} | {clean_text(desc)}"

            try:
                t_emb = embedder.embed_text(full_text)
            except Exception as e:
                logger.error(f"Could not embed text for {pid}: {e}")
                stats["text_embed_failures"] += 1
                continue

            img_path = download_first_image(img_urls)
            if not img_path:
                logger.info(f"Skipping product {pid} - no valid image")
                stats["image_download_failures"] += 1
                stats["skipped_no_image"] += 1
                continue

            try:
                img_emb = embedder.embed_image(img_path)
            except Exception as e:
                # Logged at warning (not debug) so skipped products are
                # visible at the configured INFO level.
                logger.warning(f"Could not embed image for {pid}: {e}")
                stats["image_embed_failures"] += 1
                stats["skipped_no_image"] += 1
                continue

            final_emb = (t_emb + img_emb) / 2

            # ChromaDB doesn't accept None values in metadata
            metadata = {
                "id": pid,
                "name": name,
                "category": cat,
                "image_path": img_path,
            }

            vectorstore.add_item(pid, final_emb, metadata)
            stats["total_processed"] += 1

            if i % 20 == 0:
                logger.info(f"Indexed {i} items...")

    logger.info("✔️ Index build complete.")
    logger.info(f"Statistics: {stats}")
    return vectorstore


# ===============================================================
# QUERY FUNCTION
# ===============================================================
def run_query(query_text=None, image_path=None, persist_dir="chromadb_store", top_k=5):
    """Query the index with text or an image and print the ranked results.

    Args:
        query_text: Text query; takes precedence when both inputs are given.
        image_path: Path to a query image, used when no text is given.
        persist_dir: Chroma persistence directory.
        top_k: Number of results to retrieve.

    Returns:
        The raw Chroma query result.

    Raises:
        ValueError: If neither ``query_text`` nor ``image_path`` is provided.
    """
    # Validate first so a bad call fails before the CLIP model is loaded.
    if not query_text and not image_path:
        raise ValueError("Provide query text or image")

    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    if query_text:
        emb = embedder.embed_text(query_text)
    else:
        emb = embedder.embed_image(image_path)

    results = vectorstore.query(emb, top_k=top_k)

    print("\n🔍 QUERY RESULTS")
    print("------------------------")

    # Chroma returns one result list per query embedding; one was sent.
    hits = zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
    for rank, (pid, meta, dist) in enumerate(hits, start=1):
        meta = meta or {}
        print(f"\nRank {rank}")
        print(f"Product ID: {pid}")
        print(f"Name: {meta.get('name')}")
        print(f"Category: {meta.get('category')}")
        print(f"Distance: {dist:.4f}")

    return results


# ===============================================================
# RETRIEVAL EVALUATION (Recall@K)
# ===============================================================
def evaluate_retrieval(csv_path, persist_dir="chromadb_store", max_eval=50):
    """
    Evaluate retrieval performance using category match as ground truth.

    The first ``max_eval`` CSV rows are used as text queries. A query counts
    as a hit at K when any of its top-K results has the same
    ``main_category``. Accuracy@1 and Recall@1 use the same criterion and are
    therefore always equal.

    NOTE(review): if these rows were also indexed by ``build_index``, each
    query may retrieve its own entry, which would inflate the scores. Confirm
    the intended train/eval split.

    Computes:
    - Accuracy@1
    - Recall@1
    - Recall@5
    - Recall@10

    Returns:
        Dict mapping metric name to a fraction in ``[0, 1]``. All metrics are
        ``0.0`` when there are no queries to evaluate.
    """
    print("\n🔎 Starting retrieval evaluation...\n")

    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    with open(csv_path, newline="", encoding="utf-8") as f:
        queries = list(itertools.islice(csv.DictReader(f), max(max_eval, 0)))

    total = len(queries)
    hits = {1: 0, 5: 0, 10: 0}

    for idx, row in enumerate(queries):
        gt_category = row.get("main_category") or ""
        text_query = clean_text(
            (row.get("product_name") or "") + " " + (row.get("product_text") or "")
        )
        query_emb = embedder.embed_text(text_query)

        # Retrieve top-10 results
        results = vectorstore.query(query_emb, top_k=10)
        retrieved_categories = [
            (meta or {}).get("category") for meta in results["metadatas"][0]
        ]

        # An empty result list (e.g. empty collection) is a miss at every K.
        for k in hits:
            if gt_category in retrieved_categories[:k]:
                hits[k] += 1

        if idx % 10 == 0:
            print(f"Evaluated {idx}/{total} queries...")

    if total == 0:
        # Avoid dividing by zero on an empty CSV or max_eval <= 0.
        logger.warning(f"No evaluation queries found in {csv_path}")

    def _fraction(count):
        return count / total if total else 0.0

    # Convert counts to fractions
    accuracy_at_1 = _fraction(hits[1])
    recall_1 = _fraction(hits[1])
    recall_5 = _fraction(hits[5])
    recall_10 = _fraction(hits[10])

    print("\n📊 RETRIEVAL EVALUATION RESULTS")
    print("-----------------------------------")
    print(f"Accuracy@1: {accuracy_at_1:.3f}")
    print(f"Recall@1: {recall_1:.3f}")
    print(f"Recall@5: {recall_5:.3f}")
    print(f"Recall@10: {recall_10:.3f}")

    return {
        "Accuracy@1": accuracy_at_1,
        "Recall@1": recall_1,
        "Recall@5": recall_5,
        "Recall@10": recall_10,
    }


# ===============================================================
# CLI
# ===============================================================
def main(argv=None) -> int:
    """Parse command-line arguments and run the selected mode.

    Args:
        argv: Argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: ``0`` on success, ``1`` when no action was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--build", action="store_true")
    parser.add_argument("--csv", type=str)
    parser.add_argument("--max", type=int)
    parser.add_argument("--text", type=str)
    parser.add_argument("--image", type=str)
    parser.add_argument("--db", type=str, default="chromadb_store")
    parser.add_argument("--eval", action="store_true")

    args = parser.parse_args(argv)

    # -------------------------------
    # MODE 1: Build Index
    # -------------------------------
    if args.build:
        if not args.csv:
            # Fail with a usage message instead of crashing in open(None).
            parser.error("--build requires --csv")
        build_index(args.csv, args.db, args.max)
        return 0

    # -------------------------------
    # MODE 2: Evaluate Retrieval
    # -------------------------------
    if args.eval:
        if not args.csv:
            parser.error("--eval requires --csv")
        evaluate_retrieval(args.csv, persist_dir=args.db, max_eval=50)
        return 0

    # -------------------------------
    # MODE 3: Query (text or image)
    # -------------------------------
    if args.text or args.image:
        run_query(args.text, args.image, persist_dir=args.db)
        return 0

    # -------------------------------
    # If no arguments provided
    # -------------------------------
    print("❌ No action specified. Use one of:")
    print(" --build --csv yourfile.csv")
    print(" --eval --csv yourfile.csv")
    print(" --text \"your query\"")
    print(" --image path_to_image")
    return 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is not guaranteed
    # to exist (e.g. under "python -S").
    sys.exit(main())