# mvi-ai-engine/vision/image_encoder.py
# Uploaded by Musombi via huggingface_hub (commit 0045f6d, verified).
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class ImageEncoder(nn.Module):
    """Encode RGB images into L2-normalised embedding vectors.

    A torchvision ResNet-18 backbone, with its final classification layer
    removed, produces globally pooled features. A two-layer MLP projects
    them to ``embed_dim``, and the result is L2-normalised.

    Args:
        embed_dim: Size of the output embedding.
        pretrained: If True, initialise the backbone with
            ``models.ResNet18_Weights.DEFAULT``; otherwise use random weights.
        hidden_dim: Width of the projection head's hidden layer.
    """

    def __init__(
        self,
        embed_dim: int = 128,
        pretrained: bool = True,
        hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Read the pooled-feature width from the backbone's own classifier
        # instead of hard-coding 512, so the projection head cannot drift
        # out of sync with the backbone.
        feature_dim = backbone.fc.in_features

        # Drop the last child (the ``fc`` classifier) and keep everything up
        # to the global average pool, which yields (B, feature_dim, 1, 1).
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])

        self.projection = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch of shape (B, 3, H, W).
                NOTE(review): with pretrained weights the backbone presumably
                expects the matching ImageNet preprocessing
                (``ResNet18_Weights.DEFAULT.transforms()``); confirm with callers.

        Returns:
            Tensor of shape (B, embed_dim). Each row is L2-normalised, i.e. has
            unit norm unless the row is all zeros.
        """
        features = self.feature_extractor(x)  # (B, feature_dim, 1, 1)
        # torch.flatten handles any memory layout; .view needs compatible strides.
        features = torch.flatten(features, 1)
        embeddings = self.projection(features)
        # Unit-normalise so that dot products between embeddings are cosine
        # similarities.
        return F.normalize(embeddings, dim=1)