File size: 647 Bytes
06d2725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torchvision import models


class SpeyeriaClassifier(torch.nn.Module):
    """Classify Speyeria species with a ResNet-50 backbone.

    The ResNet-50 is constructed with ``weights=None``, so no pretrained
    parameters are loaded; every weight starts from torchvision's default
    initialisation. The network's final fully connected layer (``fc``) is
    replaced by a new ``torch.nn.Linear`` with ``num_classes`` outputs. Its
    input width is taken from the original ``fc`` layer.

    Args:
        num_classes: Number of output units produced by the replacement head.
    """

    def __init__(self, num_classes: int = 16):
        super().__init__()
        net = models.resnet50(weights=None)
        # Reuse the original head's input width so the new head plugs into
        # the pooled feature vector unchanged.
        feature_dim = net.fc.in_features
        net.fc = torch.nn.Linear(feature_dim, num_classes)
        # Stored under the name ``backbone``; this name prefixes every
        # parameter key in the module's state_dict.
        self.backbone = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the head's raw outputs.

        No softmax or other normalisation is applied. The expected input
        layout is presumably an image batch of shape (N, 3, H, W), as
        torchvision's ResNet-50 takes; confirm with the callers.
        """
        logits = self.backbone(x)
        return logits


def create_model(num_classes: int = 16) -> SpeyeriaClassifier:
    """Build a new ResNet-50 Speyeria classifier.

    Args:
        num_classes: Forwarded unchanged to ``SpeyeriaClassifier`` as the
            number of outputs of its classification head.

    Returns:
        A freshly constructed ``SpeyeriaClassifier`` instance.
    """
    model = SpeyeriaClassifier(num_classes=num_classes)
    return model