# NOTE(review): removed extraction residue ("Spaces:" / "Paused" lines) that was
# not part of the module source.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
class ResNetItemEmbedder(nn.Module):
    """Embed item images as L2-normalized vectors using a ResNet backbone.

    The torchvision ResNet's classification head (``fc``) is dropped; the
    globally pooled convolutional features are linearly projected to
    ``embedding_dim`` and unit-normalized, so downstream dot products are
    cosine similarities.
    """

    # ResNet variants with a matching torchvision constructor and weights enum
    # (e.g. tvm.resnet50 / tvm.ResNet50_Weights).
    _SUPPORTED_BACKBONES = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")

    def __init__(self, embedding_dim: int = 512, backbone: str = "resnet50", pretrained: bool = True) -> None:
        """
        Args:
            embedding_dim: Dimensionality of the output embedding.
            backbone: ResNet variant name; one of ``_SUPPORTED_BACKBONES``.
                (Previously only "resnet50"/"resnet101"; the extra variants are
                a backward-compatible generalization.)
            pretrained: If True, load torchvision's default ImageNet weights.

        Raises:
            ValueError: If ``backbone`` is not a supported ResNet variant.
        """
        super().__init__()
        if backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(f"Unsupported backbone: {backbone}")
        # Resolve the constructor and its weights enum by name:
        # "resnet50" -> tvm.resnet50 and tvm.ResNet50_Weights.
        ctor = getattr(tvm, backbone)
        weights_enum = getattr(tvm, f"ResNet{backbone[len('resnet'):]}_Weights")
        model = ctor(weights=weights_enum.DEFAULT if pretrained else None)
        # Read the feature width from the model rather than hard-coding it
        # (2048 for resnet50/101/152, 512 for resnet18/34).
        feat_dim = model.fc.in_features
        # Keep everything up to and including global average pooling; drop fc.
        self.backbone = nn.Sequential(*list(model.children())[:-1])
        self.proj = nn.Linear(feat_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unit-norm embeddings of shape (B, embedding_dim).

        Args:
            x: Image batch of shape (B, 3, H, W).
        """
        feats = self.backbone(x)   # (B, feat_dim, 1, 1) after global avg pool
        feats = feats.flatten(1)   # (B, feat_dim)
        emb = self.proj(feats)     # (B, embedding_dim)
        # L2-normalize so embeddings live on the unit hypersphere.
        return F.normalize(emb, p=2, dim=1)