# src/model.py
# (snapshot of commit 02ac88d, 1,587 bytes)
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class EmbeddingNet(nn.Module):
    """Map images to L2-normalised embedding vectors using a ResNet-18 trunk.

    The final classification layer of ResNet-18 is removed; the pooled backbone
    features go through a small MLP head
    (backbone features -> ``hidden_dim`` -> ``embedding_dim``) and the result is
    projected onto the unit sphere.

    Args:
        embedding_dim: Size of each output embedding vector.
        hidden_dim: Width of the hidden layer of the head (keyword-only).
        weights: Forwarded to ``torchvision.models.resnet18`` (keyword-only).
            The default ``"DEFAULT"`` selects the best available pretrained
            weights, matching the previous hard-wired behaviour; pass ``None``
            for a randomly initialised backbone (e.g. offline use or tests).

    Note:
        The head contains ``BatchNorm1d``; in training mode PyTorch rejects a
        batch of size 1 there, so train with batches of at least 2 samples.
    """

    def __init__(self, embedding_dim=128, *, hidden_dim=256, weights="DEFAULT"):
        super().__init__()
        backbone = models.resnet18(weights=weights)
        # Read the feature width off the classifier we are about to discard,
        # rather than hard-coding 512, so the head always matches the trunk.
        in_features = backbone.fc.in_features
        # Keep every child except the final FC layer; the trailing avgpool
        # yields [B, in_features, 1, 1].
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        # Module order/indices are unchanged, so existing checkpoints still load.
        self.head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch whose dim 1 is channels; single-channel input is
               replicated to 3 channels for the RGB backbone.

        Returns:
            Tensor of shape [B, embedding_dim] whose rows have unit L2 norm.
        """
        # Grayscale (e.g. Omniglot) -> fake RGB: [B, 1, H, W] -> [B, 3, H, W].
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x = self.backbone(x)
        # flatten (unlike view) also copes with non-contiguous features.
        x = torch.flatten(x, 1)
        x = self.head(x)
        return F.normalize(x, p=2, dim=1)
class SiameseNet(nn.Module):
    """Twin-tower network whose two branches share a single ``EmbeddingNet``.

    Both inputs are encoded by the same ``embedding_net`` instance, so the
    towers have tied parameters.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        # A single encoder reused for both inputs -> shared weights.
        self.embedding_net = EmbeddingNet(embedding_dim)

    def forward(self, img1, img2):
        """Encode a pair of image batches.

        Returns:
            Tuple ``(emb1, emb2)`` holding the embeddings of ``img1`` and
            ``img2``; ``img1`` is encoded first.
        """
        encode = self.embedding_net
        return encode(img1), encode(img2)

    def get_embedding(self, img):
        """Encode one image batch with the shared encoder."""
        return self.embedding_net(img)