# -*- coding: utf-8 -*-
"""
FINAL RAG SYSTEM FOR AMAZON MULTIMODAL DATASET (LOCAL CHROMA DB)
-----------------------------------------------------------------
Features:
- Clean product text before embedding
- CLIP text + image embedding (safe 77-token truncation)
- New Chroma PersistentClient (2025 API)
- CSV loader for Amazon dataset
- Image downloader
- Build vector DB for products
- Query using text or image
"""
import os
import csv
import re
import logging
import requests
import torch
import clip
from PIL import Image
import chromadb
import argparse
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ===============================================================
# TEXT CLEANING
# ===============================================================
def clean_text(text: str, max_chars: int = 400) -> str:
"""Removes Amazon noise text and limits size."""
if not isinstance(text, str):
return ""
patterns = [
r"Make sure this fits.*?model number\.",
r"Technical details:.*",
r"Specifications:.*",
r"ProductDimensions:.*?(?=\|)",
r"ShippingWeight:.*?(?=\|)",
r"ASIN:.*?(?=\|)",
r"Item model number:.*?(?=\|)",
r"Go to your orders.*",
r"Learn More.*"
]
for p in patterns:
text = re.sub(p, "", text, flags=re.IGNORECASE)
text = text.replace("|", " ")
text = re.sub(r"\s+", " ", text).strip()
return text[:max_chars]
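# Illustrative example (hypothetical input string, not from the dataset):
#   clean_text("Make sure this fits by entering your model number. "
#              "Foam blaster | ASIN: B00EXAMPLE | fun toy")
#   -> "Foam blaster fun toy"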
# ===============================================================
# CLIP EMBEDDER
# ===============================================================
class CLIPEmbedder:
"""Multimodal embedder using OpenAI CLIP with safe truncation."""
def __init__(self, model_name="ViT-B/32"):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"[CLIP] Loading model on {self.device} ...")
self.model, self.preprocess = clip.load(model_name, device=self.device)
logger.info(f"[CLIP] Model {model_name} loaded successfully")
    def _truncate_tokens(self, text: str):
        # Unused legacy helper: clip.tokenize pads/truncates to CLIP's 77-token
        # context when truncate=True; without it, over-long text raises a
        # RuntimeError before any manual slice could help.
        tokens = clip.tokenize([text], truncate=True)[0]
        return tokens.unsqueeze(0).to(self.device)
def embed_text(self, text: str):
# 1. Clean text
text = clean_text(text)
# 2. HARD truncate before tokenizing (guaranteed safe limit)
words = text.split()
text = " ".join(words[:50]) # keep only first 50 words
# 3. Now tokenize safely (will NEVER exceed context length)
tokens = clip.tokenize([text], truncate=True).to(self.device)
# 4. Encode
with torch.no_grad():
emb = self.model.encode_text(tokens)[0]
emb = emb / emb.norm()
return emb.cpu().numpy().astype("float32")
def embed_image(self, path: str):
image = self.preprocess(Image.open(path)).unsqueeze(0).to(self.device)
with torch.no_grad():
vec = self.model.encode_image(image)[0]
vec = vec / vec.norm()
return vec.cpu().numpy().astype("float32")
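# Usage sketch ("local_cat.jpg" is a hypothetical file). Both embeddings are
# L2-normalized, so their dot product is the cosine similarity:
#   embedder = CLIPEmbedder()
#   t = embedder.embed_text("a red ceramic coffee mug")
#   v = embedder.embed_image("local_cat.jpg")
#   similarity = float(t @ v)  # in [-1, 1]; higher means more similar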
# ===============================================================
# LOCAL CHROMA VECTORSTORE (NEW API)
# ===============================================================
class ChromaVectorStore:
"""Uses new Chroma PersistentClient."""
def __init__(self, persist_dir="chromadb_store"):
print(f"[Chroma] Initializing DB at: {persist_dir}")
self.client = chromadb.PersistentClient(path=persist_dir)
self.collection = self.client.get_or_create_collection(
name="amazon_products",
metadata={"hnsw:space": "cosine"}
)
def add_item(self, item_id: str, embedding, metadata: dict):
self.collection.add(
ids=[item_id],
embeddings=[embedding],
metadatas=[metadata]
)
def query(self, embedding, top_k=5):
return self.collection.query(
query_embeddings=[embedding],
n_results=top_k
)
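# Usage sketch (illustrative; assumes `emb` is a 512-dim CLIP embedding):
#   store = ChromaVectorStore("chromadb_store")
#   store.add_item("demo-1", emb, {"name": "demo product"})
#   hits = store.query(emb, top_k=1)
#   hits["ids"][0]       -> ["demo-1"]
#   hits["distances"][0] -> [~0.0]  (cosine distance to itself)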
# ===============================================================
# DATASET LOADING / IMAGE DOWNLOADING
# ===============================================================
def download_first_image(urls: str, save_dir="images"):
"""Downloads the first valid image from the |-separated list."""
if not urls or not isinstance(urls, str):
return None
os.makedirs(save_dir, exist_ok=True)
first_url = urls.split("|")[0].strip()
if not first_url.lower().startswith("http"):
return None
    # Decode URL-encoded characters in the filename to avoid mismatches with
    # FastAPI StaticFiles; strip any path separators decoding may reintroduce.
    from urllib.parse import unquote
    base = unquote(os.path.basename(first_url))[:50]
    base = base.replace("/", "_").replace("\\", "_")
    img_name = os.path.join(save_dir, base + ".jpg")
try:
r = requests.get(first_url, timeout=5)
if r.status_code == 200:
with open(img_name, "wb") as f:
f.write(r.content)
return img_name
else:
logger.debug(f"Failed to download image (status {r.status_code}): {first_url}")
except requests.RequestException as e:
logger.debug(f"Image download error for {first_url}: {e}")
except Exception as e:
logger.warning(f"Unexpected error downloading image {first_url}: {e}")
return None
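# Usage sketch (hypothetical URLs): only the first |-separated URL is fetched.
#   download_first_image("https://example.com/photo_01|https://example.com/alt_view")
#   -> "images/photo_01.jpg" on success, or None on any failure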
# ===============================================================
# BUILD INDEX
# ===============================================================
def build_index(csv_path, persist_dir, max_items=None):
embedder = CLIPEmbedder()
vectorstore = ChromaVectorStore(persist_dir)
    logger.info(f"📄 Loading dataset: {csv_path}")
# Statistics tracking
stats = {
"total_processed": 0,
"text_embed_failures": 0,
"image_download_failures": 0,
"image_embed_failures": 0,
"skipped_no_image": 0
}
with open(csv_path, newline='', encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if max_items and i >= max_items:
break
pid = row.get("uniq_id")
name = row.get("product_name", "")
desc = row.get("product_text", "")
cat = row.get("main_category", "")
img_urls = row.get("image", "")
full_text = f"{name} | {cat} | {clean_text(desc)}"
try:
t_emb = embedder.embed_text(full_text)
except Exception as e:
logger.error(f"Could not embed text for {pid}: {e}")
stats["text_embed_failures"] += 1
continue
img_path = download_first_image(img_urls)
if not img_path:
logger.info(f"Skipping product {pid} - no valid image")
stats["image_download_failures"] += 1
stats["skipped_no_image"] += 1
continue
try:
img_emb = embedder.embed_image(img_path)
except Exception as e:
logger.debug(f"Could not embed image for {pid}: {e}")
stats["image_embed_failures"] += 1
stats["skipped_no_image"] += 1
continue
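            # The mean of two unit vectors is generally not unit-norm itself,
            # but Chroma's cosine distance ignores magnitude, so renormalizing
            # the fused embedding is optional here.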
final_emb = (t_emb + img_emb) / 2
# ChromaDB doesn't accept None values in metadata
metadata = {
"id": pid or "",
"name": name or "",
"category": cat or "",
"image_path": img_path or ""
}
            # Chroma requires a string id; fall back to the row index if missing.
            vectorstore.add_item(pid or str(i), final_emb, metadata)
stats["total_processed"] += 1
if i % 20 == 0:
logger.info(f"Indexed {i} items...")
logger.info("βœ”οΈ Index build complete.")
logger.info(f"Statistics: {stats}")
return vectorstore
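# Usage sketch (file names are illustrative; assumes the CSV has the columns
# read above: uniq_id, product_name, product_text, main_category, image):
#   build_index("amazon_products.csv", "chromadb_store", max_items=200)
# or via the CLI defined at the bottom of this file:
#   python rag_pipeline.py --build --csv amazon_products.csv --max 200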
# ===============================================================
# QUERY FUNCTION
# ===============================================================
def run_query(query_text=None, image_path=None, persist_dir="chromadb_store"):
embedder = CLIPEmbedder()
vectorstore = ChromaVectorStore(persist_dir)
if query_text:
emb = embedder.embed_text(query_text)
elif image_path:
emb = embedder.embed_image(image_path)
else:
raise ValueError("Provide query text or image")
results = vectorstore.query(emb, top_k=5)
print("\nπŸ” QUERY RESULTS")
print("------------------------")
for i in range(len(results["ids"][0])):
pid = results["ids"][0][i]
meta = results["metadatas"][0][i]
dist = results["distances"][0][i]
print(f"\nRank {i+1}")
print(f"Product ID: {pid}")
print(f"Name: {meta.get('name')}")
print(f"Category: {meta.get('category')}")
print(f"Distance: {dist:.4f}")
return results
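# Usage sketch (query text and paths are illustrative):
#   run_query(query_text="wireless gaming mouse", persist_dir="chromadb_store")
#   run_query(image_path="query.jpg", persist_dir="chromadb_store")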
# ===============================================================
# RETRIEVAL EVALUATION (Recall@K)
# ===============================================================
def evaluate_retrieval(csv_path, persist_dir="chromadb_store", max_eval=50):
"""
Evaluate retrieval performance using category match as ground truth.
Computes:
- Accuracy@1
- Recall@1
- Recall@5
- Recall@10
"""
print("\nπŸ”Ž Starting retrieval evaluation...\n")
embedder = CLIPEmbedder()
vectorstore = ChromaVectorStore(persist_dir)
queries = []
with open(csv_path, newline='', encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if i >= max_eval:
break
queries.append(row)
total = len(queries)
correct_at_1 = 0
recall_at_1 = 0
recall_at_5 = 0
recall_at_10 = 0
for idx, row in enumerate(queries):
pid = row["uniq_id"]
category = row["main_category"]
text_query = clean_text(row["product_name"] + " " + row["product_text"])
query_emb = embedder.embed_text(text_query)
# Retrieve top-10 results
results = vectorstore.query(query_emb, top_k=10)
retrieved_ids = results["ids"][0]
retrieved_metas = results["metadatas"][0]
retrieved_categories = [m.get("category") for m in retrieved_metas]
# Ground truth: category match
gt_category = category
# Accuracy@1 + Recall@1
if retrieved_categories[0] == gt_category:
correct_at_1 += 1
recall_at_1 += 1
# Recall@5
if gt_category in retrieved_categories[:5]:
recall_at_5 += 1
# Recall@10
if gt_category in retrieved_categories[:10]:
recall_at_10 += 1
if idx % 10 == 0:
print(f"Evaluated {idx}/{total} queries...")
    # Convert counts to fractions in [0, 1]
accuracy_at_1 = correct_at_1 / total
recall_1 = recall_at_1 / total
recall_5 = recall_at_5 / total
recall_10 = recall_at_10 / total
print("\nπŸ“Š RETRIEVAL EVALUATION RESULTS")
print("-----------------------------------")
print(f"Accuracy@1: {accuracy_at_1:.3f}")
print(f"Recall@1: {recall_1:.3f}")
print(f"Recall@5: {recall_5:.3f}")
print(f"Recall@10: {recall_10:.3f}")
return {
"Accuracy@1": accuracy_at_1,
"Recall@1": recall_1,
"Recall@5": recall_5,
"Recall@10": recall_10
}
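# Usage sketch (illustrative file names). Queries are drawn from the same CSV
# that built the index, so a query's own entry can rank first; category-level
# Recall@K is therefore an optimistic proxy, not a held-out measure.
#   metrics = evaluate_retrieval("amazon_products.csv", "chromadb_store", max_eval=50)
#   metrics["Recall@5"]  # fraction of queries with a category match in the top 5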
# ===============================================================
# CLI
# ===============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--build", action="store_true")
parser.add_argument("--csv", type=str)
parser.add_argument("--max", type=int)
parser.add_argument("--text", type=str)
parser.add_argument("--image", type=str)
parser.add_argument("--db", type=str, default="chromadb_store")
parser.add_argument("--eval", action="store_true")
args = parser.parse_args()
# -------------------------------
# MODE 1: Build Index
# -------------------------------
    if args.build:
        if not args.csv:
            parser.error("--build requires --csv")
        build_index(args.csv, args.db, args.max)
        exit()
# -------------------------------
# MODE 2: Evaluate Retrieval
# -------------------------------
    if args.eval:
        if not args.csv:
            parser.error("--eval requires --csv")
        evaluate_retrieval(args.csv, persist_dir=args.db, max_eval=50)
        exit()
# -------------------------------
# MODE 3: Query (text or image)
# -------------------------------
if args.text or args.image:
run_query(args.text, args.image, persist_dir=args.db)
exit()
# -------------------------------
# If no arguments provided
# -------------------------------
print("❌ No action specified. Use one of:")
print(" --build --csv yourfile.csv")
print(" --eval --csv yourfile.csv")
print(" --text \"your query\"")
print(" --image path_to_image")