# src/model.py import torch import torch.nn as nn import torchvision.models as models import torch.nn.functional as F class EmbeddingNet(nn.Module): def __init__(self, embedding_dim=128): super().__init__() # Pretrained ResNet-18, strip the final FC layer backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) self.backbone = nn.Sequential(*list(backbone.children())[:-1]) # → [B, 512, 1, 1] # Embedding head: 512 → 256 → 128, L2-normalised output self.head = nn.Sequential( nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, embedding_dim), ) def forward(self, x): # Omniglot is grayscale — replicate channel to fake RGB for ResNet if x.shape[1] == 1: x = x.repeat(1, 3, 1, 1) # [B, 1, H, W] → [B, 3, H, W] x = self.backbone(x) # [B, 512, 1, 1] x = x.view(x.size(0), -1) # [B, 512] x = self.head(x) # [B, 128] x = F.normalize(x, p=2, dim=1) # L2 normalise → unit sphere return x class SiameseNet(nn.Module): def __init__(self, embedding_dim=128): super().__init__() self.embedding_net = EmbeddingNet(embedding_dim) def forward(self, img1, img2): emb1 = self.embedding_net(img1) emb2 = self.embedding_net(img2) return emb1, emb2 def get_embedding(self, img): return self.embedding_net(img)