"""
Simple but Powerful Advanced Pre-trained CNN Classifier

Uses EfficientNetV2 with advanced training techniques for architectural style classification.
"""
from typing import Dict, List, Optional, Tuple

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleAdvancedClassifier(nn.Module):
    """
    Powerful yet compact classifier built on a pre-trained EfficientNetV2 backbone.

    Pipeline: EfficientNetV2 feature maps -> concatenated global average/max
    pooling -> an MLP that compresses the pooled vector -> a scalar sigmoid
    attention gate -> a small classification head. Logits are divided by a
    learnable temperature before being returned.
    """

    def __init__(self, num_classes: int = 25, dropout_rate: float = 0.3):
        """
        Args:
            num_classes: Number of target categories for the final head.
            dropout_rate: Dropout probability used throughout the MLP stages.
        """
        super().__init__()

        # Pre-trained EfficientNetV2-M with head and pooling stripped, so
        # forward_features() yields raw spatial feature maps.
        self.backbone = timm.create_model(
            'tf_efficientnetv2_m',
            pretrained=True,
            num_classes=0,
            global_pool=''
        )

        self.feature_dim = self.backbone.num_features
        print(f"EfficientNetV2 feature dimension: {self.feature_dim}")

        # Two complementary global poolings; concatenating them doubles the width.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        full = self.feature_dim
        half = full // 2
        quarter = full // 4

        # Compress the 2*full pooled vector down to half width with dropout.
        self.feature_enhancement = nn.Sequential(
            nn.Linear(full * 2, full),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(full, half),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Produces one sigmoid scalar per sample, used to gate the features.
        self.attention = nn.Sequential(
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Linear(quarter, 1),
            nn.Sigmoid()
        )

        # Final classification head.
        self.classifier = nn.Sequential(
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(quarter, num_classes)
        )

        # Learnable temperature for calibrating logit sharpness (init 1.5).
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run a batch of images through the full pipeline.

        Args:
            x: Input image batch accepted by the EfficientNetV2 backbone.

        Returns:
            Dict with temperature-scaled 'logits', the gated feature vector
            under 'features', and the per-sample 'attention_weights'.
        """
        fmap = self.backbone.forward_features(x)

        # Fuse the average- and max-pooled descriptors into one wide vector.
        fused = torch.cat(
            [self.global_pool(fmap).flatten(1), self.max_pool(fmap).flatten(1)],
            dim=1
        )

        enhanced = self.feature_enhancement(fused)

        # Gate the enhanced vector by a learned scalar in (0, 1).
        gate = self.attention(enhanced)
        gated = enhanced * gate

        scaled_logits = self.classifier(gated) / self.temperature

        return {
            'logits': scaled_logits,
            'features': gated,
            'attention_weights': gate
        }
class AdvancedLossFunction(nn.Module):
    """
    Combined loss: label-smoothed cross-entropy blended with focal loss.

    The two terms are mixed with configurable weights; the defaults (0.7 CE,
    0.3 focal) reproduce the original hard-coded blend, so existing callers
    see identical behavior.
    """

    def __init__(self, num_classes: int = 25, alpha: float = 1.0, gamma: float = 2.0,
                 ce_weight: float = 0.7, focal_weight: float = 0.3):
        """
        Args:
            num_classes: Kept for interface compatibility; not used directly
                (both loss terms infer the class count from the logits).
            alpha: Focal-loss scaling factor, forwarded to FocalLoss.
            gamma: Focal-loss focusing exponent, forwarded to FocalLoss.
            ce_weight: Weight of the cross-entropy term in the blend.
            focal_weight: Weight of the focal term in the blend.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce_weight = ce_weight
        self.focal_weight = focal_weight

        # Label smoothing (0.1) softens the CE targets for regularization.
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute the blended loss.

        Args:
            outputs: Model output dict; only the 'logits' entry is used.
            targets: Integer class labels of shape (batch,).

        Returns:
            Dict with 'total_loss' plus the individual 'ce_loss' and
            'focal_loss' components for logging.
        """
        logits = outputs['logits']

        ce_loss = self.cross_entropy(logits, targets)
        focal_loss = self.focal_loss(logits, targets)

        # Weighted blend; defaults reproduce the original 0.7/0.3 mix.
        total_loss = self.ce_weight * ce_loss + self.focal_weight * focal_loss

        return {
            'total_loss': total_loss,
            'ce_loss': ce_loss,
            'focal_loss': focal_loss
        }
class FocalLoss(nn.Module):
    """
    Focal Loss (Lin et al., 2017) for handling class imbalance.

    Down-weights well-classified examples by the factor (1 - p_t) ** gamma so
    that training focuses on hard examples. With gamma == 0 and alpha == 1 it
    reduces to plain cross-entropy.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        """
        Args:
            alpha: Global scaling factor applied to every per-sample loss.
            gamma: Focusing exponent; larger values suppress easy examples more.
            reduction: 'mean' (default, matches the original behavior), 'sum',
                or 'none' for per-sample losses.

        Raises:
            ValueError: If reduction is not one of 'mean', 'sum', 'none'.
        """
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs: Raw logits of shape (batch, num_classes).
            targets: Integer class labels of shape (batch,).

        Returns:
            Scalar loss under 'mean'/'sum' reduction, or a (batch,) tensor of
            per-sample losses under 'none'.
        """
        # Per-sample CE; pt = exp(-ce) is the predicted probability of the true class.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal
def create_simple_advanced_classifier(num_classes: int = 25,
                                      dropout_rate: float = 0.3) -> SimpleAdvancedClassifier:
    """
    Factory function to create the simple advanced classifier.

    Args:
        num_classes: Number of output categories.
        dropout_rate: Dropout probability forwarded to the model. The default
            matches SimpleAdvancedClassifier's own default, so existing
            callers are unaffected.

    Returns:
        A freshly constructed SimpleAdvancedClassifier.
    """
    return SimpleAdvancedClassifier(num_classes=num_classes, dropout_rate=dropout_rate)
def create_advanced_loss(num_classes: int = 25, alpha: float = 1.0,
                         gamma: float = 2.0) -> AdvancedLossFunction:
    """
    Factory function to create the advanced loss function.

    Args:
        num_classes: Forwarded to AdvancedLossFunction (kept for interface
            compatibility there).
        alpha: Focal-loss scaling factor; default matches the loss's default.
        gamma: Focal-loss focusing exponent; default matches the loss's default.

    Returns:
        A freshly constructed AdvancedLossFunction.
    """
    return AdvancedLossFunction(num_classes=num_classes, alpha=alpha, gamma=gamma)