File size: 2,683 Bytes
ba2fc46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# backend/src/services/visual/engine.py

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from io import BytesIO

# Module-level cache so the model is loaded only once per process (saves memory).
# Both globals start empty and are filled lazily by get_visual_model().
_visual_model_instance: "torch.nn.Module | None" = None  # ResNet50 without its final layer, once loaded
_preprocess_instance: "transforms.Compose | None" = None  # image preprocessing pipeline, once loaded

def get_visual_model():
    """
    Lazily build and cache the ResNet50 feature extractor and its preprocessing.

    Singleton pattern: ResNet50 is loaded only the first time it is needed,
    which matters for keeping RAM low on Railway/serverless hosts. Subsequent
    calls return the cached objects without reloading.

    Returns:
        tuple: ``(model, preprocess)`` where ``model`` is ResNet50 with its last
        child (the classification layer) removed, in eval mode, and
        ``preprocess`` is the torchvision pipeline (resize, center-crop,
        to-tensor, ImageNet mean/std normalisation) to apply to a PIL image
        before passing it to ``model``.

    Raises:
        Any exception raised while constructing the model (e.g. if the
        pretrained weights cannot be obtained) propagates to the caller. In
        that case nothing is cached, so the next call retries the load.
    """
    global _visual_model_instance, _preprocess_instance

    if _visual_model_instance is None:
        print("👁️ [Visual Engine] Loading ResNet50 AI Model...")
        # 1. Load standard ResNet50 with its default pretrained weights.
        full_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # 2. Remove the last (classification) layer.
        #    We don't want an ImageNet "cat/dog" label, we want the feature vector.
        feature_extractor = torch.nn.Sequential(*list(full_model.children())[:-1])
        # Inference mode (training off). Calling eval() on the wrapper itself
        # also marks the returned module as not training; the reused child
        # modules are switched recursively.
        feature_extractor.eval()

        # 3. Define image preprocessing steps.
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Publish only after both objects were built. The preprocess is assigned
        # first because the guard above checks the model, so a caller that sees
        # a non-None model never gets (model, None). A failure midway leaves the
        # cache empty instead of permanently half-initialised.
        _preprocess_instance = preprocess
        _visual_model_instance = feature_extractor
        print("✅ [Visual Engine] Model Loaded Successfully.")

    return _visual_model_instance, _preprocess_instance

def get_image_embedding(image_bytes: bytes) -> "list[float] | None":
    """
    Turn raw image bytes into an L2-normalised ResNet50 feature vector.

    The bytes are decoded with PIL and converted to RGB. The image is then run
    through the preprocessing pipeline and the headless ResNet50 returned by
    ``get_visual_model()``. Finally the output is flattened and scaled to unit
    length, so that a dot product between two embeddings equals their cosine
    similarity. For ResNet50 this yields a list of 2048 numbers.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        The embedding as a list of Python floats, or ``None`` if decoding,
        preprocessing or inference fails (the error is printed). An all-zero
        embedding is returned as-is, without normalisation.

    Raises:
        Exceptions from loading the model in ``get_visual_model()`` are not
        caught here and propagate to the caller.
    """
    model, preprocess = get_visual_model()

    try:
        # 1. Convert bytes to an RGB image. The context manager closes the
        #    decoded source image; convert() returns a separate, loaded copy.
        with Image.open(BytesIO(image_bytes)) as source_img:
            img = source_img.convert("RGB")

        # 2. Preprocess and add a leading batch dimension of size 1.
        img_tensor = preprocess(img)
        batch_img_tensor = torch.unsqueeze(img_tensor, 0)

        # 3. Generate the embedding without autograd bookkeeping.
        with torch.no_grad():
            embedding = model(batch_img_tensor)

        # 4. Flatten and L2-normalise (required for cosine similarity).
        embedding_np = embedding.flatten().numpy()
        norm = np.linalg.norm(embedding_np)

        # Avoid division by zero: hand back the zero vector unchanged.
        if norm == 0:
            return embedding_np.tolist()

        normalized_embedding = (embedding_np / norm).astype('float32')
        return normalized_embedding.tolist()

    except Exception as e:
        # Deliberate best-effort: any per-image failure (corrupt/unsupported
        # bytes, etc.) is reported and mapped to None for the caller.
        print(f"❌ [Visual Engine] Error processing image: {e}")
        return None