|
|
""" |
|
|
CIFAR100 Custom ResNet Model Definition |
|
|
Matches the architecture used for training cifar100_model.pth |
|
|
|
|
|
Author: Krishnakanth |
|
|
Date: 2025-10-10 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Tuple |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration for model architecture."""
    # Number of channels in the input image (3 = RGB).
    input_channels: int = 3
    # Expected (height, width) of the input image.
    # NOTE(review): not read anywhere in this file — kept for callers/docs.
    input_size: Tuple[int, int] = (32, 32)
    # Number of output classes (100 for CIFAR-100).
    num_classes: int = 100
    # NOTE(review): no layer in this file uses dropout; this field is
    # currently unreferenced — confirm before relying on it.
    dropout_rate: float = 0.05
|
|
|
|
|
|
|
|
class CIFAR100Model(nn.Module):
    """
    CIFAR100 classification model - Custom ResNet architecture.

    This matches the architecture of your trained model (23M parameters).

    Attribute names, layer shapes, and module registration order are kept
    exactly as in the trained checkpoint (``cifar100_model.pth``) so that
    ``load_state_dict`` keeps working — do not rename or resize layers.

    Note: the final Linear layer has 1000 outputs (checkpoint layout); only
    the first ``config.num_classes`` logits are used (see ``forward``).
    """

    def __init__(self, config: "ModelConfig"):
        """
        Args:
            config: Architecture configuration. ``input_channels`` feeds the
                stem convolution and ``num_classes`` selects how many of the
                final logits are used. ``input_size`` and ``dropout_rate``
                are not read by this class.
        """
        super(CIFAR100Model, self).__init__()
        self.config = config

        # Stem: 7x7 conv at stride 1. padding=4 (rather than the usual 3 for
        # "same" size) is unusual but matches the trained checkpoint.
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.config.input_channels, 64, kernel_size=7, stride=1, padding=4),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Stage 1 (64 channels): identity residual blocks (see forward).
        self.res1 = nn.Sequential(self.conv_block(64, 64, activation=True), self.conv_block(64, 64))
        self.res2 = nn.Sequential(self.conv_block(64, 64, activation=True), self.conv_block(64, 64))
        self.res3 = nn.Sequential(self.conv_block(64, 64, activation=True), self.conv_block(64, 64))

        # Stage 2 (64 -> 128 channels, 2x spatial downsample): projection
        # shortcut (downsample1) is summed with the residual branch (res4).
        self.downsample1 = nn.Sequential(self.conv_block(64, 128, pool=True))
        self.res4 = nn.Sequential(
            self.conv_block(64, 128, activation=True, pool=True),
            self.conv_block(128, 128)
        )
        self.res5 = nn.Sequential(self.conv_block(128, 128, activation=True), self.conv_block(128, 128))
        self.res6 = nn.Sequential(self.conv_block(128, 128, activation=True), self.conv_block(128, 128))
        self.res7 = nn.Sequential(self.conv_block(128, 128, activation=True), self.conv_block(128, 128))

        # Stage 3 (128 -> 256 channels, 2x spatial downsample).
        # NOTE: res8 is registered before downsample2 — keep this order; it
        # determines the weight-init RNG sequence under self.apply below.
        self.res8 = nn.Sequential(
            self.conv_block(128, 256, activation=True, pool=True),
            self.conv_block(256, 256)
        )
        self.downsample2 = nn.Sequential(self.conv_block(128, 256, pool=True))
        self.res9 = nn.Sequential(self.conv_block(256, 256, activation=True), self.conv_block(256, 256))
        self.res10 = nn.Sequential(self.conv_block(256, 256, activation=True), self.conv_block(256, 256))
        self.res11 = nn.Sequential(self.conv_block(256, 256, activation=True), self.conv_block(256, 256))
        self.res12 = nn.Sequential(self.conv_block(256, 256, activation=True), self.conv_block(256, 256))
        self.res13 = nn.Sequential(self.conv_block(256, 256, activation=True), self.conv_block(256, 256))

        # Stage 4 (256 -> 512 channels, 2x spatial downsample).
        self.res14 = nn.Sequential(
            self.conv_block(256, 512, activation=True, pool=True),
            self.conv_block(512, 512)
        )
        self.downsample3 = nn.Sequential(self.conv_block(256, 512, pool=True))
        self.res15 = nn.Sequential(self.conv_block(512, 512, activation=True), self.conv_block(512, 512))
        # res16 deliberately has ReLU on BOTH conv blocks (unlike the other
        # identity blocks) — matches the checkpoint.
        self.res16 = nn.Sequential(
            self.conv_block(512, 512, activation=True),
            self.conv_block(512, 512, activation=True)
        )

        # Head: global max-pool to 1x1, then an oversized 1000-way Linear
        # (checkpoint layout); forward() slices out the first num_classes.
        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 1000)
        )

        self.apply(self.init_weights)

    def conv_block(self, in_channels, out_channels, activation=True, pool=False):
        """Create a conv(3x3, same) -> BN [-> ReLU] [-> MaxPool(2)] block."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels)
        ]
        if activation:
            layers.append(nn.ReLU(inplace=True))
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def init_weights(self, m):
        """Kaiming-normal init for conv layers; BN/Linear keep PyTorch defaults."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, xb):
        """
        Args:
            xb: Input batch of shape (N, input_channels, H, W); trained on
                32x32 CIFAR-100 images.

        Returns:
            Log-probabilities of shape (N, num_classes).
        """
        out = self.conv1(xb)
        # Stage 1: identity residual connections.
        out = self.res1(out) + out
        out = self.res2(out) + out
        out = self.res3(out) + out
        # Stage 2: projection shortcut + residual branch, then identity blocks.
        out = self.downsample1(out) + self.res4(out)
        out = self.res5(out) + out
        out = self.res6(out) + out
        out = self.res7(out) + out
        # Stage 3.
        out = self.downsample2(out) + self.res8(out)
        out = self.res9(out) + out
        out = self.res10(out) + out
        out = self.res11(out) + out
        out = self.res12(out) + out
        out = self.res13(out) + out
        # Stage 4.
        out = self.downsample3(out) + self.res14(out)
        out = self.res15(out) + out
        out = self.res16(out) + out
        out = self.classifier(out)
        # The Linear layer is oversized to 1000 outputs to match the trained
        # checkpoint; only the first num_classes logits are meaningful.
        # (Was a hardcoded 100 — now driven by the config, identical for the
        # default ModelConfig. Requires num_classes <= 1000.)
        return F.log_softmax(out[:, :self.config.num_classes], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
# Backward-compatible aliases for older import sites.
# NOTE(review): both names bind to the SAME class — they are not distinct
# ResNet-34 / ResNet-18 depth variants; confirm callers do not expect
# different architectures from these two names.
CIFAR100ResNet34 = CIFAR100Model
CIFAR100ResNet18 = CIFAR100Model