|
|
import torch |
|
|
from torch import nn |
|
|
from torchvision import transforms |
|
|
from torchvision.models import resnet18,resnet50,ResNet50_Weights |
|
|
|
|
|
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for square images.

    The feature extractor is two ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``
    stages (32 then 64 channels).  The classifier flattens the feature map and
    passes it through a 512-unit hidden layer with batch norm, ReLU and
    dropout before the final linear layer.

    Args:
        num_inputs: Number of channels in the input images.
        input_size: Height/width of the (square) input images; used to size
            the first linear layer.
        num_classes: Number of output scores produced per sample.
        dropout_rate: Drop probability used before the final linear layer.
    """

    def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
        super().__init__()

        # (in_channels, out_channels) per convolutional stage.  The stages are
        # unpacked into one flat Sequential so sub-module indices stay 0..7.
        stage_channels = ((num_inputs, 32), (32, 64))
        stage_layers = []
        for in_ch, out_ch in stage_channels:
            stage_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stage_layers)

        # The padded 3x3 convolutions keep the spatial size; the two 2x2
        # max-pools together shrink each side by a factor of four.
        side = input_size // 4
        hidden_units = 512
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * side * side, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_units, num_classes),
        )

        # Visit every sub-module once and apply the custom initialisation.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single sub-module in place.

        Batch-norm layers get weight 1 and bias 0.  Convolution and linear
        weights get Kaiming-normal initialisation (``fan_out``, ReLU gain);
        their biases are left as constructed.  Other module types are skipped.
        """
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Return per-class scores for a batch of images (no softmax applied)."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)
|
|
|
|
|
class ResNet18_CIFAR(nn.Module):
    """ResNet-18 (randomly initialised) adapted to small 32x32 inputs.

    Compared with the stock torchvision model, the stem convolution is
    replaced by a 3x3 / stride-1 convolution and the stem max-pool is removed,
    so the stem no longer downsamples the input.  The final fully connected
    layer is replaced to produce ``num_classes`` outputs.

    In training mode the input batch is passed through a random horizontal
    flip and a random 32x32 crop (4-pixel reflect padding) before the network.

    Args:
        num_inputs: Number of channels in the input images.
        num_classes: Number of output scores produced per sample.
        dropout_rate: Drop probability applied to the pooled features right
            before the final linear layer.  ``0.0`` (the default) adds no
            dropout module at all.
    """

    def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
        super().__init__()

        # NOTE(review): per the torchvision docs, random transforms draw one
        # set of parameters per call, so every image in a batch gets the same
        # flip/crop.  Kept as is; confirm this is acceptable for training.
        self.aug = nn.Sequential(
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        )

        self.net = resnet18(weights=None)

        # Small-image stem: 3x3 stride-1 convolution and no max-pool.  The
        # output width is read from the stock stem instead of hard-coding 64.
        self.net.conv1 = nn.Conv2d(
            num_inputs,
            self.net.conv1.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.net.maxpool = nn.Identity()

        # Honour ``dropout_rate`` (previously accepted but ignored).  Dropout
        # is attached after global average pooling rather than inside ``fc``
        # so parameter names (``net.fc.weight`` / ``net.fc.bias``) are
        # unchanged and existing checkpoints still load.
        if dropout_rate:
            self.net.avgpool = nn.Sequential(
                self.net.avgpool,
                nn.Dropout(dropout_rate),
            )

        # Take the classifier input width from the backbone instead of
        # hard-coding 512.
        self.net.fc = nn.Linear(self.net.fc.in_features, num_classes)

    def forward(self, x):
        """Return per-class scores; augmentation runs only in training mode."""
        if self.training:
            x = self.aug(x)
        return self.net(x)
|
|
|
|
|
class TransferResNet50(nn.Module):
    """Pre-trained ResNet-50 with a freshly initialised classification head.

    The torchvision ResNet-50 is loaded with ``ResNet50_Weights.DEFAULT`` and
    its final fully connected layer is replaced by dropout followed by a
    linear layer with ``num_classes`` outputs.  No backbone parameters are
    frozen here, so the whole network remains trainable.

    Args:
        num_classes: Number of output scores produced per sample.
        dropout_rate: Drop probability applied before the new linear layer.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0):
        super().__init__()

        # Announce the load first: fetching the weights may take a while.
        print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)

        # New head sized from the width of the layer it replaces.
        head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(backbone.fc.in_features, num_classes),
        )
        backbone.fc = head

        self.net = backbone

    def forward(self, x):
        """Return per-class scores produced by the wrapped network."""
        scores = self.net(x)
        return scores