# ORBITT / orbiitt_engine.py
# Uploaded by aniketkumar1106 — commit 9182674 ("Upload 4 files", verified).
import os
from urllib.parse import quote

import chromadb
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
class OrbiittEngine:
    """Product image search engine: SigLIP 2 embeddings stored in ChromaDB.

    Images from ``image_folder`` are embedded and stored in a persistent
    Chroma collection configured for cosine distance. :meth:`search` queries
    that collection with a text prompt, an image, or a weighted blend of both.
    """

    # Name of the Chroma collection holding one embedding per product image.
    COLLECTION_NAME = "product_catalog"
    # File extensions (compared case-insensitively) picked up by index_images().
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")
    # The Hugging Face SigLIP 2 docs say the text tower was trained with
    # padding="max_length", max_length=64. Longer prompts are truncated to
    # this length rather than passed through untruncated.
    TEXT_MAX_LENGTH = 64

    def __init__(
        self,
        db_path="./orbiitt_db",
        image_folder="Productimages",
        model_name="google/siglip2-base-patch16-256",
        device=None,
        image_url_prefix="http://localhost:8000/Productimages",
    ):
        """Load the model and open (or create) the vector collection.

        Args:
            db_path: Directory of the persistent ChromaDB store.
            image_folder: Folder scanned by :meth:`index_images`.
            model_name: Hugging Face checkpoint loaded with
                ``AutoModel`` / ``AutoProcessor``.
            device: Torch device string. ``None`` auto-detects CUDA, then
                Apple Silicon MPS, then CPU.
            image_url_prefix: Base URL that search results append the
                URL-quoted file name to.

        Raises:
            SystemExit: If the stored vectors have a different dimension than
                the model and the user declines to wipe the collection.
        """
        self.image_folder = image_folder
        self.db_path = db_path
        self.image_url_prefix = image_url_prefix.rstrip("/")

        # 1. Device selection (an explicit override wins).
        self.device = device or self._detect_device()

        # 2. Load SigLIP 2 in inference mode.
        self.model_name = model_name
        print(f"🧠 Loading SigLIP 2 ({self.model_name}) on {self.device}...")
        self.model = AutoModel.from_pretrained(self.model_name).to(self.device).eval()
        self.processor = AutoProcessor.from_pretrained(self.model_name)

        # 3. Embedding size the stored vectors must match (768 for Base).
        self.expected_dim = self.model.config.vision_config.hidden_size

        # 4. Open the store and resolve dimension mismatches before the
        #    collection is (re)created.
        self.client = chromadb.PersistentClient(path=self.db_path)
        self._check_db_compatibility()
        # Cosine space is required: search() turns cosine distance into a score.
        self.collection = self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    @staticmethod
    def _detect_device():
        """Return "cuda", "mps" or "cpu": the first accelerator torch reports."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _check_db_compatibility(self):
        """Make sure stored vectors match the model's embedding dimension.

        Compares the length of one stored embedding with ``self.expected_dim``.
        On a mismatch, asks on stdin whether to delete the collection. Any
        answer other than "y" (including EOF on non-interactive stdin) aborts
        with ``SystemExit``.
        """
        try:
            col = self.client.get_collection(name=self.COLLECTION_NAME)
        except Exception:
            # The collection does not exist yet, so nothing stored can be
            # incompatible. chromadb raises different exception types here
            # depending on version, hence the broad catch (lookup only).
            return

        sample = col.peek(limit=1)
        # Newer chromadb returns embeddings as a numpy array, whose truth
        # value is ambiguous, so test for None / emptiness explicitly.
        embeddings = sample.get("embeddings") if sample else None
        if embeddings is None or len(embeddings) == 0:
            return

        existing_dim = len(embeddings[0])
        if existing_dim == self.expected_dim:
            return

        print(f"⚠️ Dimension Mismatch: DB is {existing_dim}, Model is {self.expected_dim}")
        try:
            answer = input("Wipe DB and restart? (y/n): ")
        except EOFError:
            # Non-interactive run: treat as "no" rather than keep a broken DB.
            answer = ""
        if answer.strip().lower() == "y":
            self.client.delete_collection(name=self.COLLECTION_NAME)
        else:
            # Raised outside any try block so it actually terminates the run.
            raise SystemExit(
                "Aborted: stored embeddings do not match the loaded model's dimension."
            )

    def get_image_embedding(self, image_path):
        """Embed one image and return it as an L2-normalised list of floats.

        Args:
            image_path: Path or binary file object accepted by ``PIL.Image.open``.

        Returns:
            A unit-length vector (list of floats) produced by
            ``model.get_image_features``.
        """
        # Context manager releases the file handle PIL opened for a path.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        # Normalise to unit length so cosine distance is meaningful.
        features = F.normalize(features, p=2, dim=-1)
        return features[0].cpu().tolist()

    def get_text_embedding(self, text):
        """Embed a text query and return it as an L2-normalised list of floats.

        The query is wrapped in the "this is a photo of ..." template and
        lowercased; the Hugging Face SigLIP docs state the model was trained
        on lowercased text.
        """
        prompt = f"this is a photo of {text}".lower()
        inputs = self.processor(
            text=[prompt],
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        features = F.normalize(features, p=2, dim=-1)
        return features[0].cpu().tolist()

    def index_images(self):
        """Embed every image in ``self.image_folder`` that is not yet indexed.

        Files are keyed by file name, so re-running only processes new files.
        Individual failures are reported and skipped (best effort).
        """
        if not os.path.isdir(self.image_folder):
            print(f"❌ Error: {self.image_folder} not found.")
            return

        files = sorted(
            name
            for name in os.listdir(self.image_folder)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(self.image_folder, name))
        )
        print(f"🏗️ Indexing {len(files)} products...")
        if not files:
            return

        # One round trip for all known ids instead of one `get` per file.
        existing = set(self.collection.get(include=[])["ids"])
        for fname in tqdm(files, desc="SigLIP 2 Processing"):
            if fname in existing:
                continue
            path = os.path.join(self.image_folder, fname)
            try:
                emb = self.get_image_embedding(path)
                self.collection.add(ids=[fname], embeddings=[emb], metadatas=[{"path": path}])
            except Exception as e:
                # Per-file best effort: a corrupt image must not stop the batch.
                tqdm.write(f"⚠️ Skipped {fname}: {e}")

    def search(self, text_query=None, image_file=None, text_weight=0.5, n_results=10):
        """Hybrid search blending visual and text embeddings.

        Args:
            text_query: Optional free-text query.
            image_file: Optional query image (path or binary file object).
            text_weight: Share of the text vector, in [0, 1], when both a text
                and an image query are given; the image gets ``1 - text_weight``.
            n_results: Maximum number of hits to return.

        Returns:
            A list of ``{"id", "url", "score"}`` dicts, best match first.
            ``score`` is ``round((1 - distance) * 100)``: cosine similarity
            times 100, given the cosine-space collection created in ``__init__``.
            Returns ``[]`` when no query is given or the collection is empty.

        Raises:
            ValueError: If ``text_weight`` is outside [0, 1].
        """
        if not 0.0 <= text_weight <= 1.0:
            raise ValueError(f"text_weight must be within [0, 1], got {text_weight!r}")

        img_vec = torch.tensor(self.get_image_embedding(image_file)) if image_file else None
        txt_vec = torch.tensor(self.get_text_embedding(text_query)) if text_query else None

        if img_vec is not None and txt_vec is not None:
            # Weighted sum, then re-normalised back to unit length.
            combined = img_vec * (1.0 - text_weight) + txt_vec * text_weight
            query_emb = F.normalize(combined, p=2, dim=0).tolist()
        elif img_vec is not None:
            query_emb = img_vec.tolist()
        elif txt_vec is not None:
            query_emb = txt_vec.tolist()
        else:
            return []

        if self.collection.count() == 0:
            return []

        results = self.collection.query(query_embeddings=[query_emb], n_results=n_results)
        # For 'cosine' space, distance = 1 - similarity (0 = perfect match).
        return [
            {
                "id": fname,
                # Quote so names with spaces, '#', '?' etc. yield valid URLs.
                "url": f"{self.image_url_prefix}/{quote(fname)}",
                "score": round((1.0 - distance) * 100),
            }
            for fname, distance in zip(results["ids"][0], results["distances"][0])
        ]