import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class ImageEncoder(nn.Module):
def __init__(self, embed_dim: int = 128, pretrained: bool = True):
super().__init__()
# Load backbone properly
if pretrained:
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
else:
backbone = models.resnet18(weights=None)
# Remove classification head
self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
# Projection head (stronger)
self.projection = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, embed_dim)
)
def forward(self, x):
"""
x: (B, 3, H, W)
returns: (B, embed_dim)
"""
features = self.feature_extractor(x) # (B, 512, 1, 1)
features = features.view(features.size(0), -1)
embeddings = self.projection(features)
# Normalize embeddings (very important)
embeddings = F.normalize(embeddings, dim=1)
return embeddings