| import torch | |
| from torchvision import models | |
class SpeyeriaClassifier(torch.nn.Module):
    """Speyeria species classifier built on a torchvision ResNet-50.

    The ResNet-50 is constructed with ``weights=None`` (no pretrained
    checkpoint is loaded). Its final ``fc`` layer is then replaced by a new
    ``torch.nn.Linear`` whose output width is ``num_classes``.

    Args:
        num_classes: Number of outputs produced by the replacement head.
    """

    def __init__(self, num_classes: int = 16):
        super().__init__()
        net = models.resnet50(weights=None)
        # Keep the input width of the stock head so the new layer accepts
        # exactly the features the rest of the network produces.
        feature_dim = net.fc.in_features
        net.fc = torch.nn.Linear(feature_dim, num_classes)
        # Registered under the name ``backbone`` so parameter/state_dict keys
        # are prefixed ``backbone.``.
        self.backbone = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the head's raw outputs.

        No softmax or other activation is applied after the final linear
        layer.

        NOTE(review): ``x`` is presumably an image batch shaped
        (batch, 3, H, W). Confirm this against the data pipeline.
        """
        logits = self.backbone(x)
        return logits
def create_model(num_classes: int = 16) -> SpeyeriaClassifier:
    """Construct a new :class:`SpeyeriaClassifier`.

    Args:
        num_classes: Output width of the classification head. It is passed
            through unchanged to ``SpeyeriaClassifier``.

    Returns:
        A freshly built ``SpeyeriaClassifier`` instance.
    """
    model = SpeyeriaClassifier(num_classes=num_classes)
    return model