Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| class ResNet50_ArcFace(nn.Module): | |
| """ | |
| ResNet-50 model modified for ArcFace loss. | |
| Outputs embeddings that can be used for feature extraction. | |
| """ | |
| def __init__(self, embedding_size=512, pretrained=True): | |
| super(ResNet50_ArcFace, self).__init__() | |
| self.embedding_size = embedding_size | |
| # Load a pre-trained ResNet-50 model | |
| self.backbone = models.resnet50(pretrained=pretrained) | |
| # Modify the final fully connected layer | |
| # Replace the last fully connected layer with a linear layer to get embeddings | |
| in_features = self.backbone.fc.in_features | |
| self.backbone.fc = nn.Linear(in_features, self.embedding_size) | |
| # Normalize the embedding vectors | |
| self.l2_norm = nn.functional.normalize | |
| def forward(self, x): | |
| x = self.backbone(x) | |
| # Normalize embeddings to have unit length | |
| x = self.l2_norm(x, p=2, dim=1) | |
| return x | |
| # Example usage | |
| if __name__ == "__main__": | |
| # Load config parameters if needed | |
| import yaml | |
| with open('config.yml', 'r') as f: | |
| config = yaml.safe_load(f) | |
| device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu') | |
| model = ResNet50_ArcFace( | |
| embedding_size=config['embedding_size'], | |
| pretrained=True | |
| ).to(device) | |
| # Print model architecture | |
| print(model) | |
| # Test with a random input | |
| dummy_input = torch.randn(1, 3, config['image_height'], config['image_width']).to(device) | |
| embeddings = model(dummy_input) | |
| print("Embeddings shape:", embeddings.shape) |