# resnet50-flowers102-classifier / resnet_classifier.py
# Uploaded by sukinggg - commit 23f55fe (verified):
# "Upload Final ResNet50 with Class Mapping (id2label) and Code"
import torch
import torch.nn as nn
import torchvision.models as models
from huggingface_hub import PyTorchModelHubMixin
class ResNetClassifier(nn.Module, PyTorchModelHubMixin):
    """Image classifier: a torchvision ResNet whose ``fc`` head is replaced.

    The backbone is built without pretrained weights, and its final fully
    connected layer is swapped for an ``nn.Linear`` that emits ``num_classes``
    logits. Inheriting ``PyTorchModelHubMixin`` adds the Hugging Face Hub
    save/load helpers; the constructor arguments are what the mixin records.
    """

    # Supported backbone names -> torchvision constructors. Every torchvision
    # ResNet exposes its classifier as ``.fc`` (with ``in_features``), so the
    # head replacement in ``__init__`` works unchanged for each of them.
    _BACKBONES = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
    }

    def __init__(self, num_classes: int = 102, model_name: str = 'resnet50',
                 freeze_backbone: bool = True) -> None:
        """Build the backbone, replace its head and optionally freeze it.

        Args:
            num_classes: Number of output logits produced by the new ``fc``
                layer. Must be at least 1.
            model_name: Key into ``_BACKBONES`` selecting the torchvision
                ResNet variant (default ``'resnet50'``).
            freeze_backbone: If True, set ``requires_grad=False`` on every
                backbone parameter except those of the new ``fc`` head.

        Raises:
            ValueError: If ``num_classes`` is less than 1 or ``model_name``
                is not a supported backbone.
        """
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {num_classes}"
            )
        self.num_classes = num_classes
        self.model_name = model_name
        self.freeze_backbone = freeze_backbone

        constructor = self._BACKBONES.get(model_name)
        if constructor is None:
            raise ValueError(f"Unsupported model: {model_name}")
        # NOTE: We load weights=None here as the trained weights will be loaded later
        self.backbone = constructor(weights=None)
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

        if freeze_backbone:
            print(f"Freezing all layers except the final classification layer for {model_name}.")
            # Single pass: only parameters of the replaced head stay trainable.
            # NOTE(review): requires_grad=False does not stop BatchNorm layers
            # from updating their running statistics while the module is in
            # train() mode - confirm that is acceptable for the training recipe.
            for name, param in self.backbone.named_parameters():
                param.requires_grad = name.startswith('fc.')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's class logits for the input batch ``x``.

        Assumes ``x`` is an image batch shaped (N, 3, H, W) and preprocessed
        the way the weights were trained - TODO confirm against the
        training/inference pipeline.
        """
        return self.backbone(x)