Spaces:
Sleeping
Sleeping
File size: 1,914 Bytes
c65e61c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import torch
import torch.nn as nn
from torchvision import models
def load_resnet18(num_classes=10, pretrained=False):
    """Build a torchvision ResNet18 whose classifier head has ``num_classes`` outputs.

    Only the final fully connected layer (``fc``) is swapped out; every other
    layer is torchvision's standard ResNet18.

    Args:
        num_classes: Output width of the new classifier (default: 10, e.g. CIFAR-10).
        pretrained: If True, start from ``models.ResNet18_Weights.IMAGENET1K_V1``;
            if False (default), the backbone is randomly initialised, which keeps
            comparisons against from-scratch models fair.

    Returns:
        The ResNet18 ``nn.Module`` with its ``fc`` layer replaced.
    """
    if pretrained:
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    else:
        backbone = models.resnet18(weights=None)

    # New head keeps the backbone's feature width and maps it to num_classes.
    head = nn.Linear(backbone.fc.in_features, num_classes)
    # Small-std normal weights and a zero bias for the freshly created head.
    nn.init.normal_(head.weight, mean=0.0, std=0.01)
    nn.init.zeros_(head.bias)
    backbone.fc = head
    return backbone
def get_resnet18_info(num_classes=10):
    """Return summary statistics for a ResNet18 built by ``load_resnet18``.

    A fresh model is constructed with ``load_resnet18(num_classes=num_classes)``
    (no pretrained weights) and its parameters are inspected.

    Args:
        num_classes: Number of classifier outputs to build the model with
            (default: 10, matching ``load_resnet18``'s default).

    Returns:
        dict with keys:
            'total_params': total number of parameter elements.
            'trainable_params': elements in parameters with ``requires_grad=True``.
            'model_size_mb': parameter storage in MiB, from each parameter's
                element size (module buffers are not counted).
            'architecture': short description string.
            'original_fc_features': input width of the classifier head.
            'modified_fc_classes': output width of the classifier head.
    """
    model = load_resnet18(num_classes=num_classes)
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    # Use the real per-element storage size instead of assuming float32 (4 bytes).
    size_bytes = sum(p.numel() * p.element_size() for p in params)
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'model_size_mb': size_bytes / (1024 * 1024),
        'architecture': 'ResNet18 with modified classifier',
        # Read from the model instead of hard-coding 512 / 10 so these values
        # stay correct if load_resnet18 or num_classes change.
        'original_fc_features': model.fc.in_features,
        'modified_fc_classes': model.fc.out_features,
    }
def freeze_backbone(model, freeze=True):
    """Freeze or unfreeze every parameter outside the ``fc`` classifier head.

    Intended for transfer-learning experiments: with ``freeze=True`` only the
    ``fc`` head keeps its current trainability. Parameters whose dotted name
    contains a path component exactly equal to ``fc`` (e.g. ``fc.weight``, or
    ``module.fc.weight`` for a wrapped model) are left untouched.

    Args:
        model: ResNet18 model (any ``nn.Module`` whose head is named ``fc``).
        freeze: If True, set ``requires_grad=False`` on backbone parameters;
            if False, set it back to True.

    Returns:
        The same ``model`` object, modified in place.
    """
    for name, param in model.named_parameters():
        # Compare whole path components: a plain substring test ('fc' in name)
        # would also skip unrelated layers such as "prefc" or "fc1".
        if 'fc' in name.split('.'):
            continue
        param.requires_grad = not freeze
    return model