File size: 1,914 Bytes
c65e61c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
from torchvision import models

def load_resnet18(num_classes=10, pretrained=False):
    """Build a torchvision ResNet18 whose classifier head emits ``num_classes`` logits.

    Only the final fully connected layer is swapped out. The convolutional
    stem is left untouched: there is no CIFAR-style 3x3 ``conv1`` and no
    max-pool removal.

    Args:
        num_classes: Output width of the new classifier (default: 10, e.g. CIFAR-10).
        pretrained: If True, initialize the backbone from torchvision's
            ``ResNet18_Weights.IMAGENET1K_V1`` weights. If False (default),
            the backbone is randomly initialized, which gives a from-scratch
            baseline.

    Returns:
        The torchvision ResNet18 model, with ``fc`` replaced by a freshly
        initialized ``nn.Linear(in_features, num_classes)``.
    """
    if pretrained:
        weights = models.ResNet18_Weights.IMAGENET1K_V1
    else:
        weights = None
    net = models.resnet18(weights=weights)

    # Size the new head from the existing layer's input width instead of
    # hard-coding it.
    head = nn.Linear(net.fc.in_features, num_classes)
    # Override nn.Linear's default init: small-std normal weights, zero bias.
    nn.init.normal_(head.weight, 0, 0.01)
    nn.init.constant_(head.bias, 0)
    net.fc = head

    return net

def get_resnet18_info(num_classes=10):
    """Summarize parameter counts and classifier shape of a ``load_resnet18`` model.

    Args:
        num_classes: Classifier width to build the model with (default: 10,
            the same default as ``load_resnet18``).

    Returns:
        dict with keys:
            total_params: total number of parameter elements.
            trainable_params: parameter elements with ``requires_grad`` set.
            model_size_mb: parameter storage in MiB, computed from each
                tensor's actual element size. Registered buffers, such as
                BatchNorm running statistics, are not counted.
            architecture: human-readable description string.
            original_fc_features: input width of the classifier layer.
            modified_fc_classes: output width of the classifier layer.
    """
    model = load_resnet18(num_classes=num_classes)
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    # Use the real per-element byte size rather than assuming 4-byte float32,
    # so the figure stays correct if torch's default dtype has been changed.
    size_bytes = sum(p.numel() * p.element_size() for p in params)

    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'model_size_mb': size_bytes / (1024 * 1024),
        'architecture': 'ResNet18 with modified classifier',
        # Read from the model instead of hard-coding 512 / 10, so these
        # values track num_classes and the actual backbone.
        'original_fc_features': model.fc.in_features,
        'modified_fc_classes': model.fc.out_features,
    }

def freeze_backbone(model, freeze=True, head_name='fc'):
    """Freeze or unfreeze every parameter outside the classifier head.

    Sets ``requires_grad = not freeze`` on every parameter except those of
    the top-level child module named ``head_name``. For torchvision ResNets
    that child is ``model.fc``. The head's parameters are not modified.

    Args:
        model: Module to modify in place, e.g. the ResNet18 from ``load_resnet18``.
        freeze: True to freeze the backbone, False to make it trainable again.
        head_name: Name of the top-level child module treated as the
            classifier head (default: 'fc').

    Returns:
        The same ``model`` object, so the call can be chained.

    Note:
        This only toggles gradient tracking. BatchNorm layers in the backbone
        still update their running statistics while the model is in
        ``train()`` mode. Put those layers in ``eval()`` if they should stay
        fixed.
    """
    for name, param in model.named_parameters():
        # Compare the top-level module name exactly. A substring test such as
        # ``'fc' not in name`` would also skip any parameter whose path merely
        # contains "fc" (e.g. "layer1.0.se.fc1.weight").
        if name.split('.', 1)[0] != head_name:
            param.requires_grad = not freeze

    return model