File size: 7,871 Bytes
b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed b9a3037 a23e6ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
"""
CIFAR-100 ResNet-34 model definition (built from BasicBlock residual stages;
a 1x1 BottleneckBlock is also provided). Contains the model architecture
classes for CIFAR-100 classification.
This module provides:
- ModelConfig: Configuration for model architecture
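- BottleneckBlock: 1x1 -> 3x3 -> 1x1 bottleneck residual block (not used by CIFAR100ResNet34)
- BasicBlock: two-layer 3x3 residual block
- CIFAR100ResNet34: ResNet-34 architecture for CIFAR-100 (also exported as CIFAR100Model and CIFAR100ResNet18)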
Author: Krishnakanth
Date: 2025-10-10
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from dataclasses import dataclass
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
@dataclass
class ModelConfig:
    """
    Hyper-parameters that shape the CIFAR-100 network.

    Attributes:
        input_channels: Channels per input image (default 3).
        input_size: Spatial size of the input images (default 32x32). The
            model classes in this module do not read this field; the
            network ends in adaptive average pooling.
        num_classes: Width of the final classifier layer.
        dropout_rate: Probability shared by the per-block ``Dropout2d``
            layers and the classifier ``Dropout``.
    """

    # Image channels fed to the stem convolution.
    input_channels: int = 3
    # Expected input resolution (kept for reference / callers).
    input_size: Tuple[int, int] = (32, 32)
    # Number of output classes.
    num_classes: int = 100
    # Dropout probability used throughout the network.
    dropout_rate: float = 0.05
# =============================================================================
# BOTTLENECK BLOCK WITH 1x1 CONVOLUTIONS
# =============================================================================
class BottleneckBlock(nn.Module):
    """
    1x1 -> 3x3 -> 1x1 bottleneck residual block.

    The first 1x1 convolution squeezes ``in_channels`` down to
    ``out_channels // reduction`` channels. The 3x3 convolution, which carries
    ``stride``, works in that narrow space. The final 1x1 convolution expands
    back to ``out_channels``. The result is added to the shortcut, passed
    through ReLU, and then through optional channel dropout.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the
            shortcut. It must be supplied whenever ``stride != 1`` or
            ``in_channels != out_channels``; otherwise the residual addition
            fails on a shape mismatch.
        dropout_rate: ``Dropout2d`` probability applied to the block output.
            The dropout layer is only created when ``dropout_rate > 0``.
        reduction: Channel reduction factor of the bottleneck. The default of 4
            matches the previously hard-coded ``out_channels // 4``.

    Raises:
        ValueError: If ``reduction < 1``, or if ``out_channels // reduction``
            is 0. The latter would otherwise build a silently degenerate
            zero-width bottleneck.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 dropout_rate=0.0, reduction=4):
        super(BottleneckBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        mid_channels = out_channels // reduction
        if mid_channels < 1:
            raise ValueError(
                f"out_channels ({out_channels}) must be >= reduction ({reduction}) "
                "so the bottleneck keeps at least one channel"
            )

        # Submodules are created in the same order as before: conv constructors
        # draw from the global RNG, so reordering would change seeded weights.
        # First 1x1 conv: reduces channels.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        # 3x3 conv: main (possibly strided) convolution in the reduced space.
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        # Second 1x1 conv: expands channels back to out_channels.
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        """Apply the bottleneck path, add the shortcut, then ReLU and optional dropout."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Project the input when a downsample module is given, so shapes match.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
class BasicBlock(nn.Module):
    """
    Two-layer 3x3 residual block (the building block of ResNet-18/34).

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: the input itself, or ``downsample(x)`` when a projection module
    is supplied. The sum goes through ReLU and, when ``dropout_rate > 0``,
    through ``Dropout2d``.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, dropout_rate=0.0):
        super().__init__()
        # Keep this registration order: every conv constructor draws from the
        # global RNG, so reordering would change seeded initial weights.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        """Run the residual block on ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut is evaluated after the main path, as in the original design.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.relu(out)

        if self.dropout is not None:
            out = self.dropout(out)
        return out
# =============================================================================
# RESNET-34 FOR CIFAR-100
# =============================================================================
class CIFAR100ResNet34(nn.Module):
    """
    ResNet-34 classifier adapted to 32x32 CIFAR-100 images.

    The stem is a single stride-1 3x3 convolution with no max-pooling, so the
    first residual stage keeps the full input resolution. Four ``BasicBlock``
    stages follow, with 3, 4, 6 and 3 blocks and widths 64, 128, 256 and 512.
    Stages 2-4 use stride 2. Features are global-average-pooled, passed
    through dropout, and fed to a linear classifier.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), so the
    output pairs with ``nn.NLLLoss`` rather than ``nn.CrossEntropyLoss``.

    Args:
        config: Provides ``input_channels``, ``num_classes`` and ``dropout_rate``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        rate = config.dropout_rate

        # CIFAR stem: no stride and no pooling, so 32x32 inputs keep full resolution.
        self.conv1 = nn.Conv2d(config.input_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages with ResNet-34's [3, 4, 6, 3] block counts. They are
        # built in this order because conv constructors draw from the global RNG.
        self.layer1 = self._make_layer(BasicBlock, 64, 64, 3, stride=1, dropout_rate=rate)
        self.layer2 = self._make_layer(BasicBlock, 64, 128, 4, stride=2, dropout_rate=rate)
        self.layer3 = self._make_layer(BasicBlock, 128, 256, 6, stride=2, dropout_rate=rate)
        self.layer4 = self._make_layer(BasicBlock, 256, 512, 3, stride=2, dropout_rate=rate)

        # Classifier head: global average pool -> dropout -> linear.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(rate)
        self.fc = nn.Linear(512, config.num_classes)

        self._initialize_weights()

    def _make_layer(self, block, in_channels, out_channels, blocks, stride=1, dropout_rate=0.0):
        """
        Build one residual stage as an ``nn.Sequential``.

        The first block receives ``stride`` and a 1x1 conv + BN projection
        whenever the stride or channel count changes. The remaining
        ``blocks - 1`` blocks map ``out_channels`` to ``out_channels`` at
        stride 1.
        """
        needs_projection = stride != 1 or in_channels != out_channels
        # The projection is constructed before any block, keeping RNG order stable.
        downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )
        head = block(in_channels, out_channels, stride, downsample, dropout_rate)
        tail = [
            block(out_channels, out_channels, dropout_rate=dropout_rate)
            for _ in range(1, blocks)
        ]
        return nn.Sequential(head, *tail)

    def _initialize_weights(self):
        """
        Initialize weights for every conv and BatchNorm module.

        Convs get Kaiming-normal weights (fan_out, ReLU gain). BatchNorm layers
        get weight 1 and bias 0. The final ``nn.Linear`` keeps PyTorch's default
        initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        out = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = torch.flatten(self.avgpool(out), 1)
        logits = self.fc(self.dropout(out))
        return F.log_softmax(logits, dim=1)
# Aliases kept so older import sites keep working.
# NOTE: CIFAR100ResNet18 is a legacy name only; it builds the same ResNet-34
# ([3, 4, 6, 3] BasicBlock) network as CIFAR100ResNet34.
CIFAR100Model = CIFAR100ResNet18 = CIFAR100ResNet34