Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torchvision | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| # Add more imports if required | |
# Image preprocessing pipeline, kept identical to the training notebook:
# resize every image to 100x100 pixels, then convert it with ToTensor().
# NOTE(review): SiameseNetwork's first conv takes 1 input channel, so images
# are presumably converted to grayscale before this pipeline — confirm at the
# call site.
trnscm = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
    ]
)
# Complete Siamese Network from notebook
class SiameseNetwork(nn.Module):
    """Siamese network that maps images to small embedding vectors.

    Both branches of a pair share one set of weights (``cnn1`` followed by
    ``fc1``): :meth:`forward_once` embeds a single batch, :meth:`forward`
    embeds two batches with the same weights.

    Args:
        image_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` pair. Defaults to 100, matching
            the ``Resize((100, 100))`` preprocessing.
        in_channels: Number of channels of the input images. Defaults to 1.
        embedding_dim: Length of each output embedding. Defaults to 5.

    With the default arguments the layer structure (and therefore the
    state-dict keys and parameter shapes) is exactly the original one, so
    previously trained weights still load.
    """

    def __init__(self, image_size=100, in_channels=1, embedding_dim=5):
        super().__init__()
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # Each 3x3 conv is preceded by a 1-pixel reflection pad, so the
        # spatial size is unchanged through the whole stack. The
        # ReLU -> BatchNorm order is kept as trained.
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),  # pad using the reflection of the input boundary
            nn.Conv2d(in_channels, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )

        # cnn1 emits 8 channels at the input resolution, so the flattened
        # feature vector has 8 * height * width entries.
        self.fc1 = nn.Sequential(
            nn.Linear(8 * height * width, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, embedding_dim),
        )

    # forward_once is for one image. This can be used while classifying the face images
    def forward_once(self, x):
        """Embed one batch of images.

        Args:
            x: Image batch; presumably shaped ``(N, in_channels, H, W)`` with
               ``(H, W)`` equal to ``image_size`` — confirm at the caller.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        output = self.cnn1(x)
        # flatten (unlike view) also works when the feature map is not
        # contiguous, e.g. for channels_last inputs.
        output = torch.flatten(output, start_dim=1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Embed both members of a pair with the shared weights.

        Returns:
            Tuple ``(embedding1, embedding2)``.
        """
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2


# Backward compatibility alias: older code refers to the model as ``Siamese``.
Siamese = SiameseNetwork
# Contrastive Loss for reference
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    With ``d`` the Euclidean distance between the two embeddings of a pair:

    * ``label == 0`` contributes ``d ** 2``;
    * ``label == 1`` contributes ``max(margin - d, 0) ** 2``.

    The returned value is the mean of these terms over the batch.

    Args:
        margin: Distance beyond which ``label == 1`` pairs stop contributing.
            Defaults to 2.0.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Compute the mean contrastive loss for a batch of pairs.

        Args:
            output1: Embeddings of the first members, shape ``(N, D)``.
            output2: Embeddings of the second members, shape ``(N, D)``.
            label: Pair labels (0 or 1), either a tensor of shape ``(N,)``
                or ``(N, 1)``, or a scalar applied to every pair.

        Returns:
            0-dim tensor holding the mean loss.
        """
        # keepdim=True gives distances of shape (N, 1).
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # A 1-D label of shape (N,) would broadcast against (N, 1) into an
        # (N, N) matrix and average the wrong terms; align it to (N, 1).
        if isinstance(label, torch.Tensor) and label.dim() == 1:
            label = label.unsqueeze(1)
        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss_contrastive = torch.mean(similar_term + dissimilar_term)
        return loss_contrastive
# Identity names for the model's outputs (five entries, the same count as the
# network's final Linear(500, 5) layer). Presumably list index i is the class
# id used during training — confirm against the training notebook.
classes = [
    'Aayush',
    'Aditya',
    'Vikram',
    'Aditi',
    'Suchitra',
]