import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EmbeddingNet(nn.Module):
    """Map images to L2-normalised embedding vectors.

    A ResNet-18 trunk (with its final fully connected classifier removed)
    extracts pooled features, and a small MLP head projects them to
    ``embedding_dim``. Every output row is rescaled to unit L2 norm.

    Args:
        embedding_dim: Size of the output embedding.
        hidden_dim: Width of the hidden layer in the projection head.
        pretrained: If True, load torchvision's default weights for the
            ResNet-18 trunk (this may download them on first use). If False,
            the trunk is randomly initialised and no network access is needed.

    Note:
        The head contains ``BatchNorm1d``, so in training mode each batch must
        contain more than one sample; switch to ``.eval()`` before embedding a
        single image.
    """

    def __init__(self, embedding_dim=128, *, hidden_dim=256, pretrained=True):
        super().__init__()

        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        # Read the trunk's feature width from the classifier we are about to
        # drop instead of hard-coding 512, so the head always matches the trunk.
        in_features = backbone.fc.in_features
        # Keep every child up to and including the global average pool; drop fc.
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        self.head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(N, C, H, W)``. Single-channel input
                (``C == 1``) is replicated to three channels so it can pass
                through the 3-channel ResNet trunk.

        Returns:
            Tensor of shape ``(N, embedding_dim)`` whose rows have unit L2 norm.
        """
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        x = self.backbone(x)
        # Collapse pooled features (N, C, 1, 1) -> (N, C); flatten, unlike
        # view, also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.head(x)
        return F.normalize(x, p=2, dim=1)
class SiameseNet(nn.Module):
    """Twin-branch wrapper that embeds two images with one shared network.

    Both branches run through the same ``EmbeddingNet`` instance, so the two
    inputs are always embedded with identical weights.

    Args:
        embedding_dim: Output size passed on to ``EmbeddingNet``.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.embedding_net = EmbeddingNet(embedding_dim)

    def get_embedding(self, img):
        """Return the embedding of a single input batch ``img``."""
        return self.embedding_net(img)

    def forward(self, img1, img2):
        """Embed both inputs; return ``(embedding of img1, embedding of img2)``."""
        encoder = self.embedding_net
        # Two separate passes (img1 first, then img2) through the shared encoder.
        return tuple(map(encoder, (img1, img2)))