# architectural-style-classifier: src/models\advanced_pretrained_classifier.py
# Uploaded by fxxkingusername with huggingface_hub (commit 37736f2, verified).
"""
Advanced Pre-trained CNN Classifier for Architectural Style Classification
Uses multiple state-of-the-art architectures with ensemble methods.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import AutoImageProcessor, AutoModel
from typing import Dict, List, Tuple, Optional
import numpy as np
class AdvancedPretrainedClassifier(nn.Module):
    """
    Advanced pre-trained classifier using multiple architectures:
    - EfficientNetV2 (for general features)
    - ConvNeXt (for modern architectural features)
    - Swin Transformer (for hierarchical features)
    - Vision Transformer (for global attention)

    Every backbone yields one pooled embedding per image. The embeddings are
    concatenated and fused by an MLP into a 512-d vector that feeds the main
    classifier; each embedding also feeds its own auxiliary linear head.

    Args:
        num_classes: Number of classes for the main and auxiliary heads.
        dropout_rate: Dropout probability used in the fusion MLP and the
            main classifier.
    """

    def __init__(self, num_classes: int = 25, dropout_rate: float = 0.3):
        super().__init__()

        # Multiple pre-trained backbones. With num_classes=0 and
        # global_pool='avg', calling a timm model returns its pooled
        # (B, num_features) embedding.
        self.efficientnet = timm.create_model(
            'tf_efficientnetv2_m',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )
        self.convnext = timm.create_model(
            'convnext_base',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )
        self.swin = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Vision Transformer from HuggingFace.
        # NOTE(review): the processor is loaded but not used by this class.
        self.vit_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.vit = AutoModel.from_pretrained('google/vit-base-patch16-224')

        # Feature dimensions
        self.efficientnet_dim = self.efficientnet.num_features
        self.convnext_dim = self.convnext.num_features
        self.swin_dim = self.swin.num_features
        # Read from the loaded config instead of hard-coding (768 for ViT-Base).
        self.vit_dim = self.vit.config.hidden_size

        # Print feature dimensions for debugging
        print(f"Feature dimensions:")
        print(f" EfficientNet: {self.efficientnet_dim}")
        print(f" ConvNeXt: {self.convnext_dim}")
        print(f" Swin: {self.swin_dim}")
        print(f" ViT: {self.vit_dim}")

        # Feature fusion layers: concatenated embeddings -> 1024 -> 512
        total_features = self.efficientnet_dim + self.convnext_dim + self.swin_dim + self.vit_dim
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_features, 1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Multi-scale attention (its output is returned, not used for the logits)
        self.attention = MultiScaleAttention(
            efficientnet_dim=self.efficientnet_dim,
            convnext_dim=self.convnext_dim,
            swin_dim=self.swin_dim,
            vit_dim=self.vit_dim
        )

        # Final classifier on the 512-d fused features
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

        # Auxiliary classifiers for each backbone
        self.aux_efficientnet = nn.Linear(self.efficientnet_dim, num_classes)
        self.aux_convnext = nn.Linear(self.convnext_dim, num_classes)
        self.aux_swin = nn.Linear(self.swin_dim, num_classes)
        self.aux_vit = nn.Linear(self.vit_dim, num_classes)

        # Temperature scaling for calibration (learned; divides the main logits)
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run all backbones and heads on a batch of images.

        Args:
            x: Image batch passed unchanged to every timm backbone; presumably
                (B, 3, H, W) preprocessed for them -- TODO confirm against
                the data pipeline.

        Returns:
            Dict with 'logits' (main head divided by the learned temperature),
            'aux_efficientnet', 'aux_convnext', 'aux_swin', 'aux_vit'
            (per-backbone logits), 'features' (512-d fused features) and
            'attended_features' (MultiScaleAttention output).
        """
        # Pooled (B, num_features) embeddings via each model's own head.
        # Pooling forward_features() by hand assumed NCHW output, which timm's
        # Swin does not produce (it returns token / channels-last features),
        # so the pooled width did not match swin_dim. For ConvNeXt the head
        # also applies the pretrained LayerNorm the manual pooling skipped.
        efficientnet_features = self.efficientnet(x)
        convnext_features = self.convnext(x)
        swin_features = self.swin(x)

        # ViT features (need to process differently)
        vit_features = self._extract_vit_features(x)

        # Apply attention mechanism
        attended_features = self.attention(
            efficientnet_features, convnext_features, swin_features, vit_features
        )

        # Concatenate all features and fuse
        combined_features = torch.cat([
            efficientnet_features, convnext_features, swin_features, vit_features
        ], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # Main classifier
        main_logits = self.classifier(fused_features)

        # Auxiliary classifiers
        aux_efficientnet_logits = self.aux_efficientnet(efficientnet_features)
        aux_convnext_logits = self.aux_convnext(convnext_features)
        aux_swin_logits = self.aux_swin(swin_features)
        aux_vit_logits = self.aux_vit(vit_features)

        # Apply temperature scaling (main head only)
        main_logits = main_logits / self.temperature

        return {
            'logits': main_logits,
            'aux_efficientnet': aux_efficientnet_logits,
            'aux_convnext': aux_convnext_logits,
            'aux_swin': aux_swin_logits,
            'aux_vit': aux_vit_logits,
            'features': fused_features,
            'attended_features': attended_features
        }

    def _extract_vit_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ViT [CLS] token embedding for ``x``.

        The ViT runs under ``torch.no_grad()``, so it is effectively frozen:
        gradients from ``aux_vit`` and the fusion MLP stop at its output.
        """
        # NOTE(review): this divides by 255 as if ``x`` held raw [0, 255]
        # pixels, while the same ``x`` goes unchanged into the timm backbones,
        # which normally expect mean/std-normalized input. One of the two paths
        # is probably mis-scaled -- confirm against the data pipeline.
        x_normalized = x / 255.0
        with torch.no_grad():
            outputs = self.vit(pixel_values=x_normalized)
        # Get the CLS token (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]
        return cls_output
class MultiScaleAttention(nn.Module):
    """Gated fusion of the four backbone embeddings.

    Each embedding is linearly projected to a shared width (``common_dim``,
    512), multiplied by a sigmoid gate computed from that projection, and the
    four gated projections are averaged.

    Args:
        efficientnet_dim: Input width of the EfficientNet embedding.
        convnext_dim: Input width of the ConvNeXt embedding.
        swin_dim: Input width of the Swin embedding.
        vit_dim: Input width of the ViT embedding.
    """

    def __init__(self, efficientnet_dim: int, convnext_dim: int, swin_dim: int, vit_dim: int):
        super().__init__()
        self.common_dim = 512
        width = self.common_dim

        # Per-backbone projections into the shared width. Registration order
        # (all projections, then all gates) determines parameters()/state_dict
        # order and seeded initialisation, so it is kept as-is.
        self.efficientnet_projection = nn.Linear(efficientnet_dim, width)
        self.convnext_projection = nn.Linear(convnext_dim, width)
        self.swin_projection = nn.Linear(swin_dim, width)
        self.vit_projection = nn.Linear(vit_dim, width)

        # One scalar gate logit per backbone, computed from its projection.
        self.efficientnet_attention = nn.Linear(width, 1)
        self.convnext_attention = nn.Linear(width, 1)
        self.swin_attention = nn.Linear(width, 1)
        self.vit_attention = nn.Linear(width, 1)

    def forward(self, efficientnet_features: torch.Tensor, convnext_features: torch.Tensor,
                swin_features: torch.Tensor, vit_features: torch.Tensor) -> torch.Tensor:
        """Return the mean of the four gated projections (last dim = common_dim)."""
        branches = (
            (self.efficientnet_projection, self.efficientnet_attention, efficientnet_features),
            (self.convnext_projection, self.convnext_attention, convnext_features),
            (self.swin_projection, self.swin_attention, swin_features),
            (self.vit_projection, self.vit_attention, vit_features),
        )
        fused = None
        for project, gate, embedding in branches:
            projected = project(embedding)
            gated = projected * torch.sigmoid(gate(projected))
            # Accumulate strictly left to right, matching ((e + c) + s) + v.
            fused = gated if fused is None else fused + gated
        return fused / 4.0
class AdvancedLossFunction(nn.Module):
    """Advanced loss function combining multiple loss types.

    total = alpha * CE(main logits)
            + beta * mean_i CE(aux logits_i)
            + gamma * Focal(main logits)
            + center_weight * Center(fused features)

    where CE uses label smoothing 0.1.

    Args:
        num_classes: Number of classes (also the number of learnable centers).
        alpha: Weight of the main cross-entropy loss.
        beta: Weight of the mean auxiliary cross-entropy loss.
        gamma: Weight of the focal loss.
        feat_dim: Width of ``outputs['features']``; must match the model's
            fused feature size (512 for AdvancedPretrainedClassifier).
        center_weight: Weight of the center loss (0.1, as before, by default).
    """

    def __init__(self, num_classes: int = 25, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3,
                 feat_dim: int = 512, center_weight: float = 0.1):
        super().__init__()
        self.alpha = alpha  # Main loss weight
        self.beta = beta  # Auxiliary loss weight
        self.gamma = gamma  # Focal loss weight
        self.center_weight = center_weight  # Center loss weight

        # Loss functions
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=1.0, gamma=2.0)
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=feat_dim)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the weighted total loss and its components.

        Args:
            outputs: Model output dict; reads 'logits', 'features' and the four
                'aux_*' logits entries.
            targets: Class indices for the batch.

        Returns:
            Dict with 'total_loss', 'main_loss', 'aux_loss', 'focal_loss' and
            'center_loss' (all scalar tensors).
        """
        main_logits = outputs['logits']
        features = outputs['features']

        # Main classification loss
        main_loss = self.cross_entropy(main_logits, targets)

        # Auxiliary losses, averaged over the four backbone heads
        aux_keys = ('aux_efficientnet', 'aux_convnext', 'aux_swin', 'aux_vit')
        aux_loss = torch.mean(torch.stack([self.cross_entropy(outputs[key], targets) for key in aux_keys]))

        # Focal loss for hard examples
        focal_loss = self.focal_loss(main_logits, targets)

        # Center loss for feature learning
        center_loss = self.center_loss(features, targets)

        total_loss = (
            self.alpha * main_loss +
            self.beta * aux_loss +
            self.gamma * focal_loss +
            self.center_weight * center_loss
        )

        return {
            'total_loss': total_loss,
            'main_loss': main_loss,
            'aux_loss': aux_loss,
            'focal_loss': focal_loss,
            'center_loss': center_loss
        }
class FocalLoss(nn.Module):
    """Focal loss for handling class imbalance, averaged over the batch.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE``, where
    ``p_t = exp(-CE)`` is the softmax probability of the target class, so
    confidently-correct samples are down-weighted.

    Args:
        alpha: Constant scale applied to every sample.
        gamma: Focusing exponent; 0 gives plain cross-entropy times ``alpha``.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # Probability the model assigned to the true class.
        target_prob = per_sample_ce.neg().exp()
        focusing = (1 - target_prob).pow(self.gamma)
        return (self.alpha * focusing * per_sample_ce).mean()
class CenterLoss(nn.Module):
    """Center Loss for learning discriminative features.

    Keeps one learnable center per class and returns the mean squared error
    between each feature vector and the center of its target class (averaged
    over batch and feature dimensions).

    Args:
        num_classes: Number of classes, i.e. number of centers.
        feat_dim: Dimensionality of the features and of each center.
        device: Device on which the centers are allocated (this argument was
            previously accepted but ignored).
    """

    def __init__(self, num_classes: int, feat_dim: int, device: str = 'cpu'):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Centers are ordinary parameters, updated by whichever optimizer is
        # given this module's parameters.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim, device=device))

    def forward(self, features: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return MSE between ``features`` (B, feat_dim) and the centers of ``targets`` (B,)."""
        centers_batch = self.centers.index_select(0, targets)
        return F.mse_loss(features, centers_batch)
def create_advanced_classifier(num_classes: int = 25, dropout_rate: float = 0.3) -> AdvancedPretrainedClassifier:
    """Factory function to create the advanced classifier.

    Args:
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to the model; the default
            equals the model's own default, so existing calls are unchanged.

    Returns:
        A new AdvancedPretrainedClassifier (its constructor loads pretrained
        backbone weights).
    """
    return AdvancedPretrainedClassifier(num_classes=num_classes, dropout_rate=dropout_rate)
def create_advanced_loss(num_classes: int = 25) -> AdvancedLossFunction:
    """Build the composite training loss for ``num_classes`` classes.

    Every other setting (loss weights, focal and center-loss parameters) uses
    the AdvancedLossFunction defaults.
    """
    loss_fn = AdvancedLossFunction(num_classes=num_classes)
    return loss_fn