"""
Simple but Powerful Advanced Pre-trained CNN Classifier
Uses EfficientNetV2 with advanced training techniques for architectural style classification.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from typing import Dict, List, Tuple, Optional
import numpy as np
class SimpleAdvancedClassifier(nn.Module):
    """EfficientNetV2-M backbone topped with a lightweight custom head.

    The head fuses average- and max-pooled backbone features, refines the
    fused vector through a small MLP, gates it with a learned per-sample
    scalar attention weight, and emits temperature-scaled logits for
    confidence calibration.
    """
    def __init__(self, num_classes: int = 25, dropout_rate: float = 0.3):
        """
        Args:
            num_classes: Number of output categories.
            dropout_rate: Dropout probability used throughout the head.
        """
        super().__init__()
        # ImageNet-pretrained EfficientNetV2-M, stripped of its own head
        # (num_classes=0) and pooling (global_pool='') so forward_features()
        # yields raw spatial feature maps.
        self.backbone = timm.create_model(
            'tf_efficientnetv2_m',
            pretrained=True,
            num_classes=0,
            global_pool=''
        )
        self.feature_dim = self.backbone.num_features
        print(f"EfficientNetV2 feature dimension: {self.feature_dim}")
        # Two complementary pooling views of the same feature map.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        dim = self.feature_dim
        half, quarter = dim // 2, dim // 4
        # MLP that fuses the concatenated (avg + max) pooled vector;
        # input width is dim * 2 because the two pooled views are stacked.
        self.feature_enhancement = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, half),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )
        # Per-sample scalar sigmoid gate over the enhanced features.
        self.attention = nn.Sequential(
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Linear(quarter, 1),
            nn.Sigmoid(),
        )
        # Final classification head.
        self.classifier = nn.Sequential(
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(quarter, num_classes),
        )
        # Learnable temperature for logit calibration (init 1.5 = softer).
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the model; returns logits, gated features and the gate value."""
        fmap = self.backbone.forward_features(x)
        # Fuse the average- and max-pooled views of the feature map.
        pooled = torch.cat(
            [self.global_pool(fmap).flatten(1), self.max_pool(fmap).flatten(1)],
            dim=1,
        )
        refined = self.feature_enhancement(pooled)
        # Scale the refined vector by its scalar attention weight.
        gate = self.attention(refined)
        gated = refined * gate
        # Temperature scaling softens/sharpens the predicted distribution.
        scaled_logits = self.classifier(gated) / self.temperature
        return {
            'logits': scaled_logits,
            'features': gated,
            'attention_weights': gate,
        }
class AdvancedLossFunction(nn.Module):
    """Weighted blend of label-smoothed cross entropy and focal loss.

    Total loss = 0.7 * CrossEntropy(label_smoothing=0.1) + 0.3 * FocalLoss.
    """
    def __init__(self, num_classes: int = 25, alpha: float = 1.0, gamma: float = 2.0):
        """
        Args:
            num_classes: Kept for interface compatibility; not used directly.
            alpha: Focal-loss scaling factor.
            gamma: Focal-loss focusing exponent.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Label smoothing (0.1) regularizes over-confident predictions.
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the blended loss together with its two components."""
        preds = outputs['logits']
        smooth_ce = self.cross_entropy(preds, targets)
        hard_example_loss = self.focal_loss(preds, targets)
        # 70/30 blend favours the smoothed cross entropy term.
        combined = 0.7 * smooth_ce + 0.3 * hard_example_loss
        return {
            'total_loss': combined,
            'ce_loss': smooth_ce,
            'focal_loss': hard_example_loss,
        }
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) for class-imbalanced training.

    Per sample: alpha * (1 - p_t)^gamma * CE, where p_t is the predicted
    probability of the true class; reduced by mean over the batch.
    """
    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        """
        Args:
            alpha: Global scaling factor.
            gamma: Focusing exponent; larger values emphasise hard examples.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the mean focal loss over the batch."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the predicted probability of the true class.
        true_class_prob = torch.exp(-per_sample_ce)
        modulator = (1.0 - true_class_prob) ** self.gamma
        return (self.alpha * modulator * per_sample_ce).mean()
def create_simple_advanced_classifier(num_classes: int = 25) -> SimpleAdvancedClassifier:
    """Build a SimpleAdvancedClassifier configured for `num_classes` outputs."""
    model = SimpleAdvancedClassifier(num_classes=num_classes)
    return model
def create_advanced_loss(num_classes: int = 25) -> AdvancedLossFunction:
    """Build the combined (smoothed cross-entropy + focal) training loss."""
    loss_fn = AdvancedLossFunction(num_classes=num_classes)
    return loss_fn