Spaces:
Sleeping
Sleeping
File size: 1,699 Bytes
24c8da0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet50_ArcFace(nn.Module):
"""
ResNet-50 model modified for ArcFace loss.
Outputs embeddings that can be used for feature extraction.
"""
def __init__(self, embedding_size=512, pretrained=True):
super(ResNet50_ArcFace, self).__init__()
self.embedding_size = embedding_size
# Load a pre-trained ResNet-50 model
self.backbone = models.resnet50(pretrained=pretrained)
# Modify the final fully connected layer
# Replace the last fully connected layer with a linear layer to get embeddings
in_features = self.backbone.fc.in_features
self.backbone.fc = nn.Linear(in_features, self.embedding_size)
# Normalize the embedding vectors
self.l2_norm = nn.functional.normalize
def forward(self, x):
x = self.backbone(x)
# Normalize embeddings to have unit length
x = self.l2_norm(x, p=2, dim=1)
return x
# Example usage
if __name__ == "__main__":
# Load config parameters if needed
import yaml
with open('config.yml', 'r') as f:
config = yaml.safe_load(f)
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
model = ResNet50_ArcFace(
embedding_size=config['embedding_size'],
pretrained=True
).to(device)
# Print model architecture
print(model)
# Test with a random input
dummy_input = torch.randn(1, 3, config['image_height'], config['image_width']).to(device)
embeddings = model(dummy_input)
print("Embeddings shape:", embeddings.shape) |