LETTER / src /model.py
Sharath33's picture
Upload folder using huggingface_hub
02ac88d verified
# src/model.py
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class EmbeddingNet(nn.Module):
    """Map an image batch to L2-normalised embedding vectors.

    A pretrained ResNet-18 with its last child (the classification FC
    layer) removed acts as the feature extractor; a small MLP head then
    projects the 512-dim features down to ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding (default 128).
    """

    def __init__(self, embedding_dim=128):
        super().__init__()

        # ImageNet-pretrained ResNet-18; keep every child except the final FC
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        feature_layers = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*feature_layers)  # -> [B, 512, 1, 1]

        # Projection head: 512 -> 256 -> embedding_dim
        projection_layers = [
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, embedding_dim),
        ]
        self.head = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Return unit-length embeddings for the image batch ``x``.

        Args:
            x: Image tensor indexed as [B, C, H, W]. A single-channel
                batch is tiled to three channels before the backbone.

        Returns:
            Tensor of shape [B, embedding_dim], L2-normalised along dim 1.
        """
        single_channel = x.shape[1] == 1
        if single_channel:
            # Omniglot is grayscale: copy the channel three times so the
            # RGB-trained ResNet accepts it ([B, 1, H, W] -> [B, 3, H, W])
            x = x.repeat(1, 3, 1, 1)

        features = self.backbone(x)                      # [B, 512, 1, 1]
        features = features.view(features.size(0), -1)   # [B, 512]
        embedding = self.head(features)                  # [B, embedding_dim]

        # Project onto the unit sphere
        return F.normalize(embedding, p=2, dim=1)
class SiameseNet(nn.Module):
    """Twin network that embeds two inputs with one shared ``EmbeddingNet``.

    Args:
        embedding_dim: Embedding size passed through to ``EmbeddingNet``
            (default 128).
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        # One encoder instance, so both branches share the same weights
        self.embedding_net = EmbeddingNet(embedding_dim)

    def get_embedding(self, img):
        """Embed a single batch ``img`` with the shared encoder."""
        encoder = self.embedding_net
        return encoder(img)

    def forward(self, img1, img2):
        """Embed both inputs and return the pair ``(emb1, emb2)``.

        ``img1`` is encoded before ``img2``; each goes through the encoder
        in its own forward pass.
        """
        encoder = self.embedding_net
        return encoder(img1), encoder(img2)