File size: 1,279 Bytes
4716563
 
 
 
24ea486
4716563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ea486
 
4716563
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm


class ResNetItemEmbedder(nn.Module):
    """Embed images into L2-normalized vectors using a ResNet backbone.

    The classifier head (``fc``) is dropped; the global-average-pooled
    features are linearly projected to ``embedding_dim`` and unit-normalized,
    so dot products between outputs are cosine similarities.
    """

    def __init__(self, embedding_dim: int = 512, backbone: str = "resnet50", pretrained: bool = True) -> None:
        """
        Args:
            embedding_dim: Size of the output embedding vector.
            backbone: Any torchvision ResNet variant name, e.g. "resnet18",
                "resnet34", "resnet50", "resnet101", "resnet152".
            pretrained: If True, load the DEFAULT (ImageNet) weights.

        Raises:
            ValueError: If ``backbone`` is not a known ResNet variant.
        """
        super().__init__()
        depth = backbone[len("resnet"):]
        if not (backbone.startswith("resnet") and depth.isdigit() and hasattr(tvm, backbone)):
            raise ValueError(f"Unsupported backbone: {backbone}")

        # torchvision exposes constructors as tvm.resnet{N} and weight enums
        # as tvm.ResNet{N}_Weights, so both can be resolved from the name.
        weights = getattr(tvm, f"ResNet{depth}_Weights").DEFAULT if pretrained else None
        model = getattr(tvm, backbone)(weights=weights)

        # Read the feature width from the model instead of hard-coding 2048,
        # so shallower variants (resnet18/34 produce 512-dim features) work too.
        feat_dim = model.fc.in_features

        # Drop the fc classifier; keep everything through global average pooling.
        self.backbone = nn.Sequential(*list(model.children())[:-1])
        self.proj = nn.Linear(feat_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unit-norm embeddings of shape (B, embedding_dim).

        Args:
            x: Image batch of shape (B, 3, H, W).
        """
        feats = self.backbone(x)  # (B, C, 1, 1) after global average pooling
        feats = feats.flatten(1)  # (B, C)
        emb = self.proj(feats)    # (B, D)
        # L2-normalize as specified in requirements (cosine-ready embeddings).
        return F.normalize(emb, p=2, dim=1)