USAMA BHATTI
Feat: Added Visual Search, API Key Auth, and Docker Optimization
ba2fc46
# backend/src/services/visual/engine.py
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from io import BytesIO
# Module-level cache, filled in lazily by get_visual_model() so the ResNet50
# model and its preprocessing pipeline are built once and then reused
# (to save memory). Both start out as None ("not loaded yet").
_visual_model_instance = _preprocess_instance = None
def get_visual_model():
    """
    Return the cached ResNet50 feature extractor and its preprocessing transform.

    The model is built lazily: only the first call constructs it, and later calls
    return the cached module-level objects. Deferring the load keeps memory low
    until the feature is actually used (e.g. on Railway / serverless hosts).

    Returns:
        tuple: ``(feature_extractor, preprocess)``.
            ``feature_extractor`` is a ``torch.nn.Sequential`` of every child
            module of the pretrained ResNet50 except the last one (the
            classification layer). ``preprocess`` is a torchvision ``Compose``
            that resizes, center-crops, converts to a tensor and normalizes a PIL
            image.

    Note:
        The first call loads the pretrained weights, which may involve a
        download, so it can be slow and can raise. The lazy initialisation is
        not protected by any lock.
    """
    global _visual_model_instance, _preprocess_instance

    # Already initialised by an earlier call: hand back the cached pair.
    if _visual_model_instance is not None:
        return _visual_model_instance, _preprocess_instance

    print("👁️ [Visual Engine] Loading ResNet50 AI Model...")

    # Pretrained ResNet50, switched to inference (eval) mode. The child modules
    # reused below share this mode.
    resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    resnet.eval()

    # Drop the final child (the classifier head): we want feature vectors,
    # not class labels such as "cat" / "dog".
    backbone_layers = list(resnet.children())[:-1]
    _visual_model_instance = torch.nn.Sequential(*backbone_layers)

    # Resize -> center-crop -> tensor -> normalize with these per-channel
    # statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    _preprocess_instance = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    print("✅ [Visual Engine] Model Loaded Successfully.")
    return _visual_model_instance, _preprocess_instance
def get_image_embedding(image_bytes: bytes) -> "list[float] | None":
    """
    Convert raw image bytes into an L2-normalized ResNet50 embedding.

    Pipeline: decode bytes -> RGB image -> preprocessing transform -> feature
    extractor -> flatten -> divide by the L2 norm. Normalizing means cosine
    similarity between two embeddings reduces to a dot product.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        The embedding as a list of Python floats (2048 values for the ResNet50
        backbone). An all-zero embedding is returned as-is, without
        normalization, to avoid dividing by zero. Returns ``None`` if decoding
        or inference fails; the error is printed.

    Raises:
        Errors from ``get_visual_model()`` (e.g. while loading weights) are not
        caught here, because that call sits outside the ``try`` block.
    """
    model, preprocess = get_visual_model()
    try:
        # 1. Decode. The context manager releases the decoder's resources;
        #    convert() returns an independent, fully loaded RGB copy, which also
        #    normalizes grayscale / RGBA / palette inputs.
        with Image.open(BytesIO(image_bytes)) as raw_img:
            img = raw_img.convert("RGB")

        # 2. Preprocess and add a leading batch dimension of size 1.
        batch_img_tensor = torch.unsqueeze(preprocess(img), 0)

        # 3. Forward pass. inference_mode() disables autograd tracking like
        #    no_grad(), with less overhead.
        with torch.inference_mode():
            embedding = model(batch_img_tensor)

        # 4. Flatten to 1-D and L2-normalize.
        embedding_np = embedding.flatten().numpy()
        norm = np.linalg.norm(embedding_np)
        if norm == 0:
            # Nothing to normalize; also avoids division by zero.
            return embedding_np.tolist()
        normalized_embedding = (embedding_np / norm).astype('float32')
        return normalized_embedding.tolist()
    except Exception as e:
        # Best-effort contract: callers receive None instead of an exception.
        print(f"❌ [Visual Engine] Error processing image: {e}")
        return None