import torch
import torch.nn as nn
import torchvision.models as models


def _build_resnet50(pretrained: bool) -> nn.Module:
    """Build a torchvision ResNet-50, optionally loading ImageNet weights.

    Uses the ``weights=`` API (torchvision >= 0.13) when it exists and falls
    back to the legacy ``pretrained=`` flag on older torchvision releases.

    Args:
        pretrained: If True, load the ImageNet-1k checkpoint; otherwise the
            network is randomly initialised (no download).

    Returns:
        An ``nn.Module`` ResNet-50 with its stock ``fc`` head.
    """
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is None:
        # Older torchvision: only the boolean flag exists.
        return models.resnet50(pretrained=pretrained)
    # IMAGENET1K_V1 is exactly what the deprecated ``pretrained=True`` loaded;
    # ``.DEFAULT`` would silently switch to the V2 checkpoint (different
    # weights and recommended preprocessing), so pin V1 to keep behaviour.
    weights = weights_enum.IMAGENET1K_V1 if pretrained else None
    return models.resnet50(weights=weights)


class CatvsDogResNet50(nn.Module):
    """ResNet-50 backbone with a single-logit head for cat-vs-dog classification.

    The stock 1000-way ImageNet ``fc`` layer is replaced by
    ``Dropout -> Linear(in_features, 1)``, so ``forward`` returns one raw
    (un-sigmoided) logit per sample with shape ``(N, 1)``.
    NOTE(review): which class the positive logit denotes is not defined here;
    presumably it is paired with ``BCEWithLogitsLoss`` -- confirm with the
    training code.
    """

    def __init__(
        self,
        freeze_backbone: bool = True,
        *,
        pretrained: bool = True,
        dropout: float = 0.5,
    ):
        """Build the model.

        Args:
            freeze_backbone: If True, set ``requires_grad=False`` on every
                pre-existing backbone parameter so only the new head trains.
            pretrained: If True (default, matching the original behaviour),
                start from the ImageNet-1k ResNet-50 checkpoint. Pass False to
                skip the download, e.g. when loading a saved state dict.
            dropout: Dropout probability applied before the final linear layer.
        """
        super().__init__()
        self.backbone = _build_resnet50(pretrained)

        if freeze_backbone:
            # NOTE(review): requires_grad=False stops gradient updates but not
            # BatchNorm running-mean/var updates while the module is in
            # train() mode. If fully frozen features are intended, keep the
            # backbone's BatchNorm layers in eval() during training.
            for param in self.backbone.parameters():
                param.requires_grad = False

        num_ftrs = self.backbone.fc.in_features
        # The head is created *after* freezing, so its freshly constructed
        # parameters keep the default requires_grad=True and remain trainable.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_ftrs, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per input image.

        Args:
            x: Image batch; ResNet-50's first conv expects 3 input channels,
                i.e. ``(N, 3, H, W)``. Presumably normalised with ImageNet
                mean/std when using pretrained weights -- confirm in the data
                pipeline.

        Returns:
            Tensor of shape ``(N, 1)`` containing logits (no sigmoid applied).
        """
        return self.backbone(x)