Spaces:
No application file
No application file
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from transformers import AutoModel, AutoProcessor | |
| import chromadb | |
| from tqdm import tqdm | |
class OrbiittEngine:
    """SigLIP 2 powered product search over a persistent Chroma vector store.

    Indexes images from the local ``Productimages`` folder and supports
    text, image, and blended (hybrid) similarity queries.
    """

    def __init__(self, db_path="./orbiitt_db"):
        self.image_folder = "Productimages"
        self.db_path = db_path

        # 1. Device detection (native Apple Silicon support).
        self.device = "mps" if torch.backends.mps.is_available() else "cpu"

        # 2. Load SigLIP 2 (vision+text dual encoder).
        print(f"🧠 Loading SigLIP 2 (google/siglip2-base-patch16-256) on {self.device}...")
        self.model_name = "google/siglip2-base-patch16-256"
        self.model = AutoModel.from_pretrained(self.model_name).to(self.device).eval()
        self.processor = AutoProcessor.from_pretrained(self.model_name)

        # 3. Embedding dimensionality (768 for the Base checkpoint).
        self.expected_dim = self.model.config.vision_config.hidden_size

        # 4. Connect to the database, refusing to mix vector dimensions.
        self.client = chromadb.PersistentClient(path=self.db_path)
        self._check_db_compatibility()

        # CRITICAL: 'cosine' space so query distances map to 1 - similarity.
        self.collection = self.client.get_or_create_collection(
            name="product_catalog",
            metadata={"hnsw:space": "cosine"},
        )

    def _check_db_compatibility(self):
        """Ensure stored vectors match the model's embedding dimension.

        On a mismatch the user is prompted to wipe the collection; declining
        aborts startup via SystemExit so incompatible vectors never mix.
        """
        try:
            col = self.client.get_collection(name="product_catalog")
        except Exception:
            # Collection does not exist yet — nothing to validate. Only this
            # lookup is guarded; validation errors below are NOT swallowed.
            return

        sample = col.peek(limit=1)
        # NOTE(review): recent chromadb returns numpy arrays from peek(),
        # whose truth value is ambiguous — test None/length explicitly
        # rather than relying on truthiness.
        embeddings = sample.get("embeddings") if sample else None
        if embeddings is None or len(embeddings) == 0:
            return

        existing_dim = len(embeddings[0])
        if existing_dim != self.expected_dim:
            print(f"⚠️ Dimension Mismatch: DB is {existing_dim}, Model is {self.expected_dim}")
            if input("Wipe DB and restart? (y/n): ").lower() == 'y':
                self.client.delete_collection(name="product_catalog")
            else:
                # Non-zero exit: refusing the wipe is an error condition,
                # not a clean shutdown (bare exit() reported status 0).
                raise SystemExit("Aborted: embedding dimension mismatch.")

    def get_image_embedding(self, image_path):
        """Process an image file and return a normalized embedding list.

        Returns the L2-normalized vector (expected_dim floats) as a plain
        Python list suitable for Chroma storage.
        """
        # Context manager closes the underlying file handle promptly
        # (the original leaked it until GC).
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        # Normalize to unit length so cosine distance is meaningful.
        features = F.normalize(features, p=2, dim=-1)
        return features.squeeze().cpu().numpy().tolist()

    def get_text_embedding(self, text):
        """Process a text query and return a normalized embedding list."""
        # SigLIP 2 standard prompt template improves zero-shot alignment.
        prompt = f"this is a photo of {text}"
        inputs = self.processor(text=[prompt], padding="max_length", return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        features = F.normalize(features, p=2, dim=-1)
        return features.squeeze().cpu().numpy().tolist()

    def index_images(self):
        """Scan the Productimages folder and index any not-yet-stored files.

        File names double as Chroma IDs, so re-running is idempotent.
        """
        if not os.path.exists(self.image_folder):
            print(f"❌ Error: {self.image_folder} not found.")
            return

        exts = ('.jpg', '.jpeg', '.png', '.webp')
        files = [f for f in os.listdir(self.image_folder) if f.lower().endswith(exts)]
        print(f"🏗️ Indexing {len(files)} products...")

        # One batched existence lookup instead of a DB round-trip per file.
        already_indexed = set(self.collection.get(ids=files)['ids']) if files else set()

        for fname in tqdm(files, desc="SigLIP 2 Processing"):
            if fname in already_indexed:
                continue
            path = os.path.join(self.image_folder, fname)
            try:
                emb = self.get_image_embedding(path)
                self.collection.add(ids=[fname], embeddings=[emb], metadatas=[{"path": path}])
            except Exception as e:
                # Best-effort: a single corrupt image must not stop the run.
                tqdm.write(f"⚠️ Skipped {fname}: {e}")

    def search(self, text_query=None, image_file=None, text_weight=0.5):
        """Hybrid search blending visual and text embeddings.

        text_weight: 0.0 = pure image, 1.0 = pure text; only used when both
        modalities are supplied. Returns up to 10 hits as dicts with
        ``id``, ``url``, and an integer percentage ``score``.
        """
        img_vec = torch.tensor(self.get_image_embedding(image_file)) if image_file else None
        txt_vec = torch.tensor(self.get_text_embedding(text_query)) if text_query else None

        # BLENDING LOGIC
        if img_vec is not None and txt_vec is not None:
            # Blend, then re-normalize so the query stays on the unit sphere.
            combined = img_vec * (1.0 - text_weight) + txt_vec * text_weight
            query_emb = F.normalize(combined, p=2, dim=0).tolist()
        elif img_vec is not None:
            query_emb = img_vec.tolist()
        elif txt_vec is not None:
            query_emb = txt_vec.tolist()
        else:
            return []

        results = self.collection.query(query_embeddings=[query_emb], n_results=10)
        output = []
        for fname, dist in zip(results['ids'][0], results['distances'][0]):
            # Cosine distance is 1 - similarity and can exceed 1.0 for
            # anti-correlated vectors — clamp so scores never go negative.
            score = max(0, round((1.0 - dist) * 100))
            output.append({
                "id": fname,
                "url": f"http://localhost:8000/Productimages/{fname}",
                "score": score,
            })
        return output