Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class CustomCNN(nn.Module): | |
| """ | |
| Custom CNN optimized for CIFAR-10 with ~3-4M parameters. | |
| Designed to compete with ResNet18 while being lightweight. | |
| """ | |
| def __init__(self, num_classes=10, dropout_rate=0.4): | |
| super(CustomCNN, self).__init__() | |
| # Feature extractor with efficient blocks | |
| self.features = nn.Sequential( | |
| # Block 1: 32x32 -> 32x32 | |
| self._conv_block(3, 64, stride=1), | |
| self._conv_block(64, 64, stride=1), | |
| # Block 2: 32x32 -> 16x16 | |
| self._conv_block(64, 128, stride=2), | |
| self._conv_block(128, 128, stride=1), | |
| # Block 3: 16x16 -> 8x8 | |
| self._conv_block(128, 256, stride=2), | |
| self._conv_block(256, 256, stride=1), | |
| # Block 4: 8x8 -> 4x4 (deep feature extraction) | |
| self._conv_block(256, 512, stride=2), | |
| ) | |
| # Global pooling and classifier | |
| self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |
| self.dropout = nn.Dropout(dropout_rate) | |
| self.classifier = nn.Linear(512, num_classes) | |
| # Initialize weights | |
| self._initialize_weights() | |
| def _conv_block(self, in_channels, out_channels, stride=1): | |
| """Efficient conv block with BatchNorm and ReLU.""" | |
| return nn.Sequential( | |
| nn.Conv2d(in_channels, out_channels, kernel_size=3, | |
| stride=stride, padding=1, bias=False), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU(inplace=True) | |
| ) | |
| def _initialize_weights(self): | |
| """Initialize weights using He initialization for ReLU.""" | |
| for m in self.modules(): | |
| if isinstance(m, nn.Conv2d): | |
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | |
| elif isinstance(m, nn.BatchNorm2d): | |
| nn.init.constant_(m.weight, 1) | |
| nn.init.constant_(m.bias, 0) | |
| elif isinstance(m, nn.Linear): | |
| nn.init.normal_(m.weight, 0, 0.01) | |
| nn.init.constant_(m.bias, 0) | |
| def forward(self, x): | |
| x = self.features(x) | |
| x = self.avgpool(x) | |
| x = torch.flatten(x, 1) | |
| x = self.dropout(x) | |
| x = self.classifier(x) | |
| return x | |
| def get_model_info(self): | |
| """Return model architecture info.""" | |
| total_params = sum(p.numel() for p in self.parameters()) | |
| trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) | |
| return { | |
| 'total_params': total_params, | |
| 'trainable_params': trainable_params, | |
| 'model_size_mb': total_params * 4 / (1024 * 1024), # Assuming float32 | |
| 'architecture': '7-layer CNN with BatchNorm and Dropout' | |
| } | |
| # Factory function for easy instantiation | |
| def create_custom_cnn(num_classes=10, dropout_rate=0.4): | |
| """Create and return a CustomCNN instance.""" | |
| return CustomCNN(num_classes=num_classes, dropout_rate=dropout_rate) | |