File size: 1,148 Bytes
c1e438c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class ImageEncoder(nn.Module):
    """Encode images into L2-normalized embedding vectors.

    A torchvision ResNet-18, with its final fully connected classifier
    removed, produces pooled features. A two-layer MLP projection head maps
    those features to ``embed_dim``. The result is L2-normalized along the
    feature dimension.

    Args:
        embed_dim: Size of the output embedding.
        pretrained: If True, build the backbone with
            ``models.ResNet18_Weights.DEFAULT``; otherwise with
            ``weights=None`` (random initialization).
        hidden_dim: Width of the projection head's hidden layer.
    """

    def __init__(
        self,
        embed_dim: int = 128,
        pretrained: bool = True,
        *,
        hidden_dim: int = 256,
    ):
        super().__init__()

        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Read the feature width from the classifier we are about to discard
        # instead of hard-coding 512, so the head always matches the backbone.
        backbone_dim = backbone.fc.in_features

        # Keep every child module except the last one (the `fc` classifier).
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])

        # Two-layer MLP projection head. The layer order is intentionally
        # unchanged so existing checkpoints (projection.0 / projection.2) load.
        self.projection = nn.Sequential(
            nn.Linear(backbone_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch, expected shape ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, embed_dim)``, L2-normalized along dim 1.
        """
        features = self.feature_extractor(x)  # (B, backbone_dim, 1, 1)
        features = torch.flatten(features, 1)  # (B, backbone_dim)

        embeddings = self.projection(features)

        # Unit-normalize so that dot products between embeddings equal
        # their cosine similarity.
        return F.normalize(embeddings, dim=1)