import os

import chromadb
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor


class OrbiittEngine:
    """Image/text similarity search over a product image folder.

    Embeddings are produced by a SigLIP 2 checkpoint and stored in a
    persistent ChromaDB collection configured for cosine distance.
    """

    # Name of the ChromaDB collection holding the product embeddings.
    COLLECTION_NAME = "product_catalog"
    # Hugging Face checkpoint used for both image and text embeddings.
    MODEL_NAME = "google/siglip2-base-patch16-256"
    # Text sequence length recommended by the SigLIP 2 documentation.
    TEXT_MAX_LENGTH = 64
    # File extensions (lower-case) that index_images() will pick up.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(
        self,
        db_path="./orbiitt_db",
        image_folder="Productimages",
        url_prefix="http://localhost:8000/Productimages",
    ):
        """Load the model and open (or create) the vector collection.

        Args:
            db_path: Directory used by the persistent ChromaDB client.
            image_folder: Folder scanned by ``index_images``.
            url_prefix: Prefix prepended to a file name to build the ``url``
                field of each search result (no trailing slash).
        """
        self.image_folder = image_folder
        self.db_path = db_path
        self.url_prefix = url_prefix.rstrip("/")

        # 1. Device detection: prefer Apple Silicon (MPS), then CUDA, then CPU.
        if torch.backends.mps.is_available():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        # 2. Load SigLIP 2.
        self.model_name = self.MODEL_NAME
        print(f"🧠 Loading SigLIP 2 ({self.model_name}) on {self.device}...")
        self.model = AutoModel.from_pretrained(self.model_name).to(self.device).eval()
        self.processor = AutoProcessor.from_pretrained(self.model_name)

        # 3. Embedding width the database is expected to contain; read from
        #    the vision tower's configured hidden size.
        self.expected_dim = self.model.config.vision_config.hidden_size

        # 4. Connect to the database and make sure any existing collection
        #    was built with vectors of the same width.
        self.client = chromadb.PersistentClient(path=self.db_path)
        self._check_db_compatibility()

        # The collection must use cosine space: search() converts the
        # returned distance to a similarity via ``1 - distance``.
        self.collection = self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def _check_db_compatibility(self):
        """Verify that stored vectors match the model's embedding width.

        If an existing collection holds vectors of a different dimension the
        user is asked whether to wipe it. Declining (or having no interactive
        stdin) aborts the program.

        Raises:
            SystemExit: When a mismatch is found and the wipe is not confirmed.
        """
        try:
            col = self.client.get_collection(name=self.COLLECTION_NAME)
        except Exception:
            # The exception type for a missing collection differs between
            # chromadb releases; either way there is nothing to verify yet.
            return

        try:
            sample = col.peek(limit=1)
        except Exception as exc:
            # Best-effort check: report the problem instead of hiding it.
            print(f"⚠️ Could not inspect existing collection: {exc}")
            return

        embeddings = sample.get("embeddings") if sample else None
        # Use len() rather than truthiness: embeddings may be a numpy array,
        # whose truth value is ambiguous.
        if embeddings is None or len(embeddings) == 0:
            return

        existing_dim = len(embeddings[0])
        if existing_dim == self.expected_dim:
            return

        print(f"⚠️ Dimension Mismatch: DB is {existing_dim}, Model is {self.expected_dim}")
        try:
            answer = input("Wipe DB and restart? (y/n): ").strip().lower()
        except EOFError:
            # No interactive stdin: never wipe data without confirmation.
            answer = ""

        if answer == 'y':
            self.client.delete_collection(name=self.COLLECTION_NAME)
        else:
            # Raised outside any try-block so it really terminates the program.
            raise SystemExit("Aborted: existing database is incompatible with the loaded model.")

    def get_image_embedding(self, image_path):
        """Return the L2-normalized image embedding as a list of floats.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
        """
        # The context manager closes the underlying file promptly; convert()
        # returns an independent in-memory copy.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)

        # Normalize to unit length so cosine similarity is well-behaved.
        features = F.normalize(features, p=2, dim=-1)
        return features.squeeze().cpu().numpy().tolist()

    def get_text_embedding(self, text):
        """Return the L2-normalized text embedding as a list of floats.

        Args:
            text: Free-text description; it is wrapped in a fixed prompt
                template before tokenization.
        """
        prompt = f"this is a photo of {text}"
        # Pad/truncate to the fixed text length recommended for SigLIP 2 so
        # long queries cannot overflow the text tower.
        inputs = self.processor(
            text=[prompt],
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)

        features = F.normalize(features, p=2, dim=-1)
        return features.squeeze().cpu().numpy().tolist()

    def index_images(self):
        """Embed every not-yet-indexed image in ``self.image_folder``.

        Files are keyed by file name; names already present in the collection
        are skipped. A failure on one file is reported and does not stop the
        run.
        """
        if not os.path.exists(self.image_folder):
            print(f"❌ Error: {self.image_folder} not found.")
            return

        files = [
            f for f in os.listdir(self.image_folder)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        print(f"🏗️ Indexing {len(files)} products...")

        for fname in tqdm(files, desc="SigLIP 2 Processing"):
            path = os.path.join(self.image_folder, fname)

            # Skip files that were indexed in a previous run.
            if self.collection.get(ids=[fname])['ids']:
                continue

            try:
                emb = self.get_image_embedding(path)
                self.collection.add(
                    ids=[fname],
                    embeddings=[emb],
                    metadatas=[{"path": path}],
                )
            except Exception as e:
                # Deliberately broad: one unreadable image must not abort the
                # whole indexing run. tqdm.write keeps the progress bar intact.
                tqdm.write(f"⚠️ Skipped {fname}: {e}")

    def search(self, text_query=None, image_file=None, text_weight=0.5, n_results=10):
        """Search the collection by text, by image, or by a blend of both.

        Args:
            text_query: Optional text description to search for.
            image_file: Optional image path/file object to search with.
            text_weight: Weight of the text vector when both inputs are
                given; the image vector gets ``1 - text_weight``.
            n_results: Maximum number of hits to return.

        Returns:
            A list of dicts with keys ``id`` (file name), ``url`` and
            ``score`` (cosine similarity as a rounded percentage). Empty when
            neither a text nor an image query is supplied.
        """
        img_vec = None
        txt_vec = None

        if image_file:
            img_vec = torch.tensor(self.get_image_embedding(image_file))
        if text_query:
            txt_vec = torch.tensor(self.get_text_embedding(text_query))

        if img_vec is not None and txt_vec is not None:
            # Weighted sum of two unit vectors is generally not unit length,
            # so re-normalize the blend.
            combined = (img_vec * (1.0 - text_weight)) + (txt_vec * text_weight)
            query_emb = F.normalize(combined, p=2, dim=0).tolist()
        elif img_vec is not None:
            query_emb = img_vec.tolist()
        elif txt_vec is not None:
            query_emb = txt_vec.tolist()
        else:
            return []

        results = self.collection.query(query_embeddings=[query_emb], n_results=n_results)

        # Only one query embedding was sent, so the hits are in slot 0.
        ids = results['ids'][0]
        distances = results['distances'][0]

        output = []
        for fname, distance in zip(ids, distances):
            # Cosine distance is 1 - similarity; 0 distance = perfect match.
            score = round((1.0 - distance) * 100)
            output.append({
                "id": fname,
                "url": f"{self.url_prefix}/{fname}",
                "score": score
            })

        return output