# mohammedelabbas's picture
# Add Speyeria classifier files
# 06d2725 verified
import torch
from torchvision import models
class SpeyeriaClassifier(torch.nn.Module):
    """Speyeria species classifier built on a torchvision ResNet-50.

    The ResNet-50 is created with ``weights=None`` and its final ``fc``
    layer is replaced by a ``torch.nn.Linear`` layer that keeps the same
    input width and produces ``num_classes`` outputs. The network is
    stored on the ``backbone`` attribute.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
            Defaults to 16.
    """

    def __init__(self, num_classes: int = 16):
        super().__init__()
        self.backbone = models.resnet50(weights=None)
        # Reuse the input width of the stock head so the new layer plugs
        # into the same position of the network.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = torch.nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the backbone and return its output."""
        outputs = self.backbone(x)
        return outputs
def create_model(num_classes: int = 16) -> SpeyeriaClassifier:
    """Build and return a ``SpeyeriaClassifier``.

    Args:
        num_classes: Forwarded to ``SpeyeriaClassifier`` as its
            ``num_classes`` argument. Defaults to 16.

    Returns:
        A newly constructed ``SpeyeriaClassifier`` instance.
    """
    model = SpeyeriaClassifier(num_classes=num_classes)
    return model