import torch import torch.nn as nn import torchvision.models as models import torch.nn.functional as F class ImageEncoder(nn.Module): def __init__(self, embed_dim: int = 128, pretrained: bool = True): super().__init__() # Load backbone properly if pretrained: backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) else: backbone = models.resnet18(weights=None) # Remove classification head self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1]) # Projection head (stronger) self.projection = nn.Sequential( nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, embed_dim) ) def forward(self, x): """ x: (B, 3, H, W) returns: (B, embed_dim) """ features = self.feature_extractor(x) # (B, 512, 1, 1) features = features.view(features.size(0), -1) embeddings = self.projection(features) # Normalize embeddings (very important) embeddings = F.normalize(embeddings, dim=1) return embeddings