Spaces:
Sleeping
Sleeping
| """ | |
| Bird classification model architectures with overfitting prevention. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models | |
| from typing import Optional | |
| # Try to import EfficientNet | |
| try: | |
| from efficientnet_pytorch import EfficientNet | |
| EFFICIENTNET_AVAILABLE = True | |
| except ImportError: | |
| EFFICIENTNET_AVAILABLE = False | |
| print("EfficientNet not available. Install with: pip install efficientnet-pytorch") | |
class BirdClassifier(nn.Module):
    """
    Bird classifier built from a pretrained backbone plus a regularized head.

    A torchvision/EfficientNet backbone has its original classifier stripped
    and replaced with a dropout + batch-norm MLP head sized for `num_classes`.
    """

    def __init__(self, num_classes: int, architecture: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False):
        """
        Build the classifier.

        Args:
            num_classes: Number of bird classes.
            architecture: Backbone name ('resnet50', 'resnet18', 'resnet101',
                'efficientnet_b0', or 'efficientnet_b1'..'b4' when the
                efficientnet_pytorch package is installed).
            pretrained: Load pretrained backbone weights when True.
            dropout_rate: Base dropout rate; the head uses scaled fractions of it.
            freeze_backbone: Disable gradients for all backbone parameters.

        Raises:
            ValueError: If `architecture` is not one of the supported names.
        """
        super().__init__()
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

        # All torchvision ResNets share the same `.fc` head layout,
        # so a single factory table covers them.
        resnet_factories = {
            'resnet50': models.resnet50,
            'resnet18': models.resnet18,
            'resnet101': models.resnet101,
        }
        lukemelas_variants = ('efficientnet_b1', 'efficientnet_b2',
                              'efficientnet_b3', 'efficientnet_b4')

        if architecture in resnet_factories:
            self.backbone = resnet_factories[architecture](pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            # Drop the ImageNet classifier; features flow straight through.
            self.backbone.fc = nn.Identity()
        elif architecture == 'efficientnet_b0':
            self.backbone = models.efficientnet_b0(pretrained=pretrained)
            num_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
        elif architecture in lukemelas_variants and EFFICIENTNET_AVAILABLE:
            # efficientnet_pytorch uses dashed names, e.g. 'efficientnet-b1'.
            model_name = architecture.replace('_', '-')
            if pretrained:
                self.backbone = EfficientNet.from_pretrained(model_name)
            else:
                self.backbone = EfficientNet.from_name(model_name)
            num_features = self.backbone._fc.in_features
            self.backbone._fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported architecture: {architecture}")

        # Optionally freeze the feature extractor for head-only training.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Head: progressive dimension reduction with batch norm, and dropout
        # scaled down layer-by-layer (0.6x, 0.5x, 0.3x of the base rate).
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate * 0.6),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.3),
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize head weights: Xavier (ReLU gain) linears, identity batch norms."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight,
                                        gain=nn.init.calculate_gain('relu'))
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run backbone feature extraction followed by the classifier head."""
        return self.classifier(self.backbone(x))
class LightweightBirdClassifier(nn.Module):
    """
    Compact from-scratch CNN for bird classification.

    Four double-conv stages (32 -> 64 -> 128 -> 256 channels) with batch
    normalization, followed by global average pooling and a small MLP head.
    Dropout2d between stages and Dropout in the head provide regularization.
    """

    def __init__(self, num_classes: int, dropout_rate: float = 0.5):
        """
        Build the network.

        Args:
            num_classes: Number of bird classes.
            dropout_rate: Head dropout rate; conv stages use half of it.
        """
        super().__init__()

        def double_conv(c_in, c_out):
            # Two 3x3 convs (same padding) with BN + ReLU after each.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = []
        # Blocks 1-3: double conv, 2x downsample, spatial dropout.
        for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
            stages += double_conv(c_in, c_out)
            stages.append(nn.MaxPool2d(2, 2))
            stages.append(nn.Dropout2d(p=dropout_rate / 2))
        # Block 4: double conv, then collapse spatial dims to 1x1.
        stages += double_conv(128, 256)
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(128, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convs, identity init for batch norms, Xavier for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                        nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Extract convolutional features and classify them."""
        return self.classifier(self.features(x))
def create_model(num_classes: int, model_type: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False) -> nn.Module:
    """
    Factory for bird classification models.

    Args:
        num_classes: Number of bird classes.
        model_type: 'lightweight' for the from-scratch CNN; any other value is
            passed to BirdClassifier as the backbone architecture name
            (e.g. 'resnet50', 'resnet18', 'efficientnet_b0').
        pretrained: Use pretrained backbone weights (ignored for 'lightweight').
        dropout_rate: Dropout rate for regularization.
        freeze_backbone: Freeze backbone weights (ignored for 'lightweight').

    Returns:
        An initialized nn.Module ready for training.
    """
    if model_type != 'lightweight':
        return BirdClassifier(num_classes, model_type, pretrained,
                              dropout_rate, freeze_backbone)
    return LightweightBirdClassifier(num_classes, dropout_rate)
class ModelEnsemble(nn.Module):
    """
    Inference-time ensemble that averages member softmax outputs.

    NOTE(review): forward runs under no_grad, so the ensemble itself cannot
    be fine-tuned end-to-end — members are assumed already trained.
    """

    def __init__(self, models_list: list):
        """
        Wrap the member models.

        Args:
            models_list: List of trained models to ensemble.
        """
        super().__init__()
        # ModuleList registers members so .to()/.eval() propagate to them.
        self.models = nn.ModuleList(models_list)

    def forward(self, x):
        """Average member class probabilities and return log-probabilities."""
        with torch.no_grad():
            member_probs = [F.softmax(member(x), dim=1) for member in self.models]
        averaged = torch.stack(member_probs, dim=0).mean(dim=0)
        # Epsilon keeps log() finite when a class probability is exactly zero.
        return torch.log(averaged + 1e-8)