FrAnKu34t23's picture
Update models.py
40348b6 verified
"""
Bird classification model architectures with overfitting prevention.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from typing import Optional
# Try to import EfficientNet
try:
from efficientnet_pytorch import EfficientNet
EFFICIENTNET_AVAILABLE = True
except ImportError:
EFFICIENTNET_AVAILABLE = False
print("EfficientNet not available. Install with: pip install efficientnet-pytorch")
class BirdClassifier(nn.Module):
    """
    Bird classification model: a CNN backbone followed by a regularized MLP head.

    The backbone's own classification layer is replaced with ``nn.Identity`` so
    that the backbone emits feature vectors. Those features are passed through
    a head made of two (Dropout -> Linear -> BatchNorm1d -> ReLU) stages
    (512 and 256 units) and a final Dropout -> Linear layer that produces
    ``num_classes`` logits.
    """

    # ResNet variants constructed through ``torchvision.models.<name>``.
    _RESNET_ARCHITECTURES = ('resnet18', 'resnet50', 'resnet101')
    # EfficientNet variants that need the optional ``efficientnet_pytorch`` package.
    _EFFICIENTNET_LIB_ARCHITECTURES = (
        'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4',
    )

    def __init__(self, num_classes: int, architecture: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False):
        """
        Initialize the bird classifier.

        Args:
            num_classes: Number of bird classes (size of the output layer).
            architecture: Backbone architecture. One of 'resnet18', 'resnet50',
                'resnet101', 'efficientnet_b0' (torchvision), or
                'efficientnet_b1' .. 'efficientnet_b4' (these require the
                optional efficientnet-pytorch package).
            pretrained: Whether to load pretrained backbone weights.
            dropout_rate: Base dropout rate. The three head dropout layers use
                0.6x, 0.5x and 0.3x of this value respectively.
            freeze_backbone: If True, gradients are disabled for every
                backbone parameter.

        Raises:
            ValueError: If ``architecture`` is not supported, or if it needs
                the efficientnet-pytorch package and that package is missing.
        """
        super().__init__()
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

        self.backbone, num_features = self._build_backbone(architecture, pretrained)

        # NOTE: this only turns off gradient computation for the backbone
        # parameters; it does not switch the backbone to eval mode.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Classifier head with batch normalization and progressive dimension
        # reduction. Dropout decreases towards the output layer: each layer
        # uses a fixed fraction of ``dropout_rate`` (for example,
        # dropout_rate=0.3 gives 0.18 / 0.15 / 0.09).
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate * 0.6),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.3),
            nn.Linear(256, num_classes)
        )

        # Only the new head is re-initialized; backbone weights are left as loaded.
        self._initialize_weights()

    @staticmethod
    def _build_backbone(architecture: str, pretrained: bool):
        """
        Build the backbone and strip its classification layer.

        Args:
            architecture: Backbone name (see ``__init__``).
            pretrained: Whether to load pretrained weights.

        Returns:
            A ``(backbone, num_features)`` pair, where ``num_features`` is the
            input width of the classification layer that was removed.

        Raises:
            ValueError: For an unknown architecture, or when an
                efficientnet-pytorch architecture is requested but the
                package is not installed.
        """
        if architecture in BirdClassifier._RESNET_ARCHITECTURES:
            backbone = getattr(models, architecture)(pretrained=pretrained)
            num_features = backbone.fc.in_features
            backbone.fc = nn.Identity()  # Remove original classifier
            return backbone, num_features

        if architecture == 'efficientnet_b0':
            backbone = models.efficientnet_b0(pretrained=pretrained)
            # torchvision's EfficientNet classifier is (Dropout, Linear).
            num_features = backbone.classifier[1].in_features
            backbone.classifier = nn.Identity()
            return backbone, num_features

        if architecture in BirdClassifier._EFFICIENTNET_LIB_ARCHITECTURES:
            if not EFFICIENTNET_AVAILABLE:
                # Report the missing dependency instead of claiming the
                # architecture itself is unsupported.
                raise ValueError(
                    f"Architecture '{architecture}' requires the "
                    "efficientnet-pytorch package. "
                    "Install with: pip install efficientnet-pytorch"
                )
            # efficientnet_pytorch names its models with a dash: 'efficientnet-b1'.
            model_name = architecture.replace('_', '-')
            if pretrained:
                backbone = EfficientNet.from_pretrained(model_name)
            else:
                backbone = EfficientNet.from_name(model_name)
            num_features = backbone._fc.in_features
            backbone._fc = nn.Identity()
            return backbone, num_features

        raise ValueError(f"Unsupported architecture: {architecture}")

    def _initialize_weights(self):
        """Initialize the classifier head: Xavier-uniform (ReLU gain) for Linear
        weights, zero biases, and identity scale/shift for BatchNorm1d."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the backbone on ``x`` and return the head's class logits."""
        features = self.backbone(x)
        return self.classifier(features)
class LightweightBirdClassifier(nn.Module):
"""
Lightweight CNN model for bird classification with batch normalization.
"""
def __init__(self, num_classes: int, dropout_rate: float = 0.5):
"""
Initialize lightweight classifier.
Args:
num_classes: Number of bird classes
dropout_rate: Dropout rate for regularization
"""
super(LightweightBirdClassifier, self).__init__()
self.features = nn.Sequential(
# Block 1
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout_rate/2),
# Block 2
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout_rate/2),
# Block 3
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout_rate/2),
# Block 4
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Dropout(p=dropout_rate),
nn.Linear(256, 128),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout_rate),
nn.Linear(128, num_classes)
)
self._initialize_weights()
def _initialize_weights(self):
"""Initialize model weights."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x):
"""Forward pass."""
x = self.features(x)
x = self.classifier(x)
return x
def create_model(num_classes: int, model_type: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False) -> nn.Module:
    """
    Factory for bird classification models.

    Args:
        num_classes: Number of bird classes.
        model_type: 'lightweight' selects LightweightBirdClassifier; any other
            value is forwarded to BirdClassifier as its backbone architecture
            name (e.g. 'resnet50', 'resnet18', 'efficientnet_b0').
        pretrained: Whether to use pretrained weights (not used by the
            lightweight model).
        dropout_rate: Dropout rate for regularization.
        freeze_backbone: Whether to freeze backbone weights (ignored for the
            lightweight model).

    Returns:
        The constructed PyTorch model.
    """
    # Everything except the lightweight CNN is a backbone-based classifier.
    if model_type != 'lightweight':
        return BirdClassifier(
            num_classes,
            model_type,
            pretrained,
            dropout_rate,
            freeze_backbone,
        )
    return LightweightBirdClassifier(num_classes, dropout_rate)
class ModelEnsemble(nn.Module):
    """
    Ensemble that averages the softmax probabilities of several models.

    The forward pass runs every member under ``torch.no_grad()``, so the
    ensemble is for inference only: its output carries no gradient. The
    ensemble does not change the train/eval mode of its members.
    """

    def __init__(self, models_list: list):
        """
        Initialize model ensemble.

        Args:
            models_list: List of trained models to ensemble. They are
                registered in an ``nn.ModuleList`` under ``self.models``.
        """
        super().__init__()
        self.models = nn.ModuleList(models_list)

    def forward(self, x):
        """
        Average the members' class probabilities and return their logarithm.

        Args:
            x: Input batch passed unchanged to every member model. Each member
                is expected to return logits with classes along dim 1
                (softmax is taken over ``dim=1``).

        Returns:
            ``log(mean_probability + 1e-8)``; the small constant avoids
            ``log(0)`` for classes whose averaged probability underflows.

        Raises:
            RuntimeError: If the ensemble contains no models.
        """
        if len(self.models) == 0:
            # Fail with a clear message instead of an opaque torch.stack error.
            raise RuntimeError(
                "ModelEnsemble requires at least one model to run a forward pass"
            )

        with torch.no_grad():
            member_probs = [F.softmax(member(x), dim=1) for member in self.models]

        # Average predictions across the ensemble members.
        ensemble_pred = torch.stack(member_probs, dim=0).mean(dim=0)
        return torch.log(ensemble_pred + 1e-8)  # Convert back to log probabilities