"""
Advanced Pre-trained CNN Classifier for Architectural Style Classification

Uses multiple state-of-the-art architectures with ensemble methods.
"""
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import timm
|
| | from transformers import AutoImageProcessor, AutoModel
|
| | from typing import Dict, List, Tuple, Optional
|
| | import numpy as np
|
| |
|
| |
|
class AdvancedPretrainedClassifier(nn.Module):
    """
    Advanced pre-trained classifier using multiple architectures:
    - EfficientNetV2 (for general features)
    - ConvNeXt (for modern architectural features)
    - Swin Transformer (for hierarchical features)
    - Vision Transformer (for global attention)
    """

    def __init__(self, num_classes: int = 25, dropout_rate: float = 0.3):
        """Build the four backbones plus fusion, attention, and heads.

        Args:
            num_classes: Number of architectural-style classes.
            dropout_rate: Dropout probability used in the fusion MLP and
                the main classification head.
        """
        super().__init__()

        # Backbone 1: EfficientNetV2-M. num_classes=0 strips the timm
        # classification head so the model yields pooled feature vectors.
        self.efficientnet = timm.create_model(
            'tf_efficientnetv2_m',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Backbone 2: ConvNeXt-Base (headless, avg-pooled features).
        self.convnext = timm.create_model(
            'convnext_base',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Backbone 3: Swin-Base at 224 px (headless, avg-pooled features).
        self.swin = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Backbone 4: ViT-Base/16 from Hugging Face.
        # NOTE(review): vit_processor is loaded but never used in forward();
        # ViT preprocessing is done manually in _extract_vit_features.
        self.vit_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.vit = AutoModel.from_pretrained('google/vit-base-patch16-224')

        # Per-backbone feature widths, queried from timm; ViT-Base is
        # hard-coded to its 768-d hidden size.
        self.efficientnet_dim = self.efficientnet.num_features
        self.convnext_dim = self.convnext.num_features
        self.swin_dim = self.swin.num_features
        self.vit_dim = 768

        print(f"Feature dimensions:")
        print(f" EfficientNet: {self.efficientnet_dim}")
        print(f" ConvNeXt: {self.convnext_dim}")
        print(f" Swin: {self.swin_dim}")
        print(f" ViT: {self.vit_dim}")

        # Width of all four feature vectors concatenated.
        total_features = self.efficientnet_dim + self.convnext_dim + self.swin_dim + self.vit_dim

        # Fusion MLP: concat(all backbones) -> 1024 -> 512.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_features, 1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Gated combination of the projected per-backbone features. Its
        # output is returned from forward() for inspection but is NOT fed
        # into the classifier (only the concatenation path is classified).
        self.attention = MultiScaleAttention(
            efficientnet_dim=self.efficientnet_dim,
            convnext_dim=self.convnext_dim,
            swin_dim=self.swin_dim,
            vit_dim=self.vit_dim
        )

        # Main classification head on the fused 512-d representation.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

        # Auxiliary heads (one per backbone) for deep supervision.
        self.aux_efficientnet = nn.Linear(self.efficientnet_dim, num_classes)
        self.aux_convnext = nn.Linear(self.convnext_dim, num_classes)
        self.aux_swin = nn.Linear(self.swin_dim, num_classes)
        self.aux_vit = nn.Linear(self.vit_dim, num_classes)

        # Learnable temperature for calibrating the main logits (init 1.5).
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run all backbones, fuse their features, and classify.

        Args:
            x: Input image batch; assumed (B, 3, 224, 224) given the
                224-px backbones — TODO confirm upstream resizing.

        Returns:
            Dict with temperature-scaled 'logits', per-backbone 'aux_*'
            logits, the fused 512-d 'features', and 'attended_features'
            from the attention module.
        """
        # EfficientNet: spatial feature map -> global average pool -> (B, C).
        efficientnet_features = self.efficientnet.forward_features(x)
        if isinstance(efficientnet_features, tuple):
            efficientnet_features = efficientnet_features[0]
        efficientnet_features = F.adaptive_avg_pool2d(efficientnet_features, 1).flatten(1)

        # ConvNeXt: same pool-and-flatten treatment.
        convnext_features = self.convnext.forward_features(x)
        if isinstance(convnext_features, tuple):
            convnext_features = convnext_features[0]
        convnext_features = F.adaptive_avg_pool2d(convnext_features, 1).flatten(1)

        # NOTE(review): recent timm Swin models return forward_features in
        # NHWC (B, H, W, C) layout while adaptive_avg_pool2d assumes NCHW,
        # so this pooling may mix axes — verify against the installed timm.
        swin_features = self.swin.forward_features(x)
        if isinstance(swin_features, tuple):
            swin_features = swin_features[0]
        swin_features = F.adaptive_avg_pool2d(swin_features, 1).flatten(1)

        # ViT branch (frozen; see _extract_vit_features).
        vit_features = self._extract_vit_features(x)

        # Gated fusion output (returned for inspection only).
        attended_features = self.attention(
            efficientnet_features, convnext_features, swin_features, vit_features
        )

        # Plain concatenation is what actually feeds the classifier.
        combined_features = torch.cat([
            efficientnet_features, convnext_features, swin_features, vit_features
        ], dim=1)

        fused_features = self.feature_fusion(combined_features)

        main_logits = self.classifier(fused_features)

        # Deep-supervision logits on the raw per-backbone features.
        aux_efficientnet_logits = self.aux_efficientnet(efficientnet_features)
        aux_convnext_logits = self.aux_convnext(convnext_features)
        aux_swin_logits = self.aux_swin(swin_features)
        aux_vit_logits = self.aux_vit(vit_features)

        # Learnable temperature scaling softens/sharpens the main logits.
        main_logits = main_logits / self.temperature

        return {
            'logits': main_logits,
            'aux_efficientnet': aux_efficientnet_logits,
            'aux_convnext': aux_convnext_logits,
            'aux_swin': aux_swin_logits,
            'aux_vit': aux_vit_logits,
            'features': fused_features,
            'attended_features': attended_features
        }

    def _extract_vit_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features from Vision Transformer."""
        # NOTE(review): dividing by 255 assumes x carries raw 0-255 pixel
        # values, yet the timm backbones above receive x unscaled — one of
        # the two branches is likely mis-normalized; confirm the pipeline.
        # The loaded vit_processor's mean/std normalization is also skipped.
        x_normalized = x / 255.0

        # no_grad keeps the ViT frozen: this branch is never fine-tuned.
        with torch.no_grad():
            outputs = self.vit(pixel_values=x_normalized)

        # [CLS] token embedding as the global image representation.
        cls_output = outputs.last_hidden_state[:, 0, :]

        return cls_output
|
| |
|
class MultiScaleAttention(nn.Module):
    """Multi-scale attention mechanism for feature fusion.

    Each backbone's feature vector is projected into a shared 512-d space,
    gated by a per-backbone scalar sigmoid attention weight, and the four
    gated projections are averaged into a single fused vector.
    """

    def __init__(self, efficientnet_dim: int, convnext_dim: int, swin_dim: int, vit_dim: int):
        super().__init__()

        # Shared width that every backbone is projected into.
        self.common_dim = 512

        # Per-backbone projections (creation order kept stable so parameter
        # initialization is reproducible under a fixed RNG seed).
        self.efficientnet_projection = nn.Linear(efficientnet_dim, self.common_dim)
        self.convnext_projection = nn.Linear(convnext_dim, self.common_dim)
        self.swin_projection = nn.Linear(swin_dim, self.common_dim)
        self.vit_projection = nn.Linear(vit_dim, self.common_dim)

        # Scalar gate per backbone, computed from its projected features.
        self.efficientnet_attention = nn.Linear(self.common_dim, 1)
        self.convnext_attention = nn.Linear(self.common_dim, 1)
        self.swin_attention = nn.Linear(self.common_dim, 1)
        self.vit_attention = nn.Linear(self.common_dim, 1)

    def forward(self, efficientnet_features: torch.Tensor, convnext_features: torch.Tensor,
                swin_features: torch.Tensor, vit_features: torch.Tensor) -> torch.Tensor:
        """Project, gate, and average the four backbone feature vectors."""
        branches = (
            (self.efficientnet_projection, self.efficientnet_attention, efficientnet_features),
            (self.convnext_projection, self.convnext_attention, convnext_features),
            (self.swin_projection, self.swin_attention, swin_features),
            (self.vit_projection, self.vit_attention, vit_features),
        )

        gated = []
        for project, gate, feats in branches:
            projected = project(feats)
            # Sigmoid gate in (0, 1) scales the whole projected vector.
            gated.append(projected * torch.sigmoid(gate(projected)))

        # Average of the four gated projections (same summation order as
        # a left-to-right chain, so results are bit-identical).
        fused = gated[0]
        for tensor in gated[1:]:
            fused = fused + tensor
        return fused / 4.0
|
| |
|
| |
|
class AdvancedLossFunction(nn.Module):
    """Advanced loss function combining multiple loss types.

    Blends label-smoothed cross-entropy on the fused logits, the mean of
    the four auxiliary-head cross-entropies, a focal term, and a
    fixed-weight (0.1) center-loss term on the fused features.
    """

    def __init__(self, num_classes: int = 25, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3):
        super().__init__()
        # Weights for the main, auxiliary, and focal terms respectively.
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        # Component losses (creation order preserved for reproducible init
        # of the center-loss parameters).
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=1.0, gamma=2.0)
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=512)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the weighted combination and return every component.

        Args:
            outputs: Model output dict with 'logits', the four 'aux_*'
                logit tensors, and 512-d 'features'.
            targets: Ground-truth class indices, shape (batch,).

        Returns:
            Dict with 'total_loss' plus each unweighted component term.
        """
        main_logits = outputs['logits']

        # Label-smoothed CE on the fused prediction.
        main_loss = self.cross_entropy(main_logits, targets)

        # Mean CE across the four per-backbone auxiliary heads.
        aux_keys = ('aux_efficientnet', 'aux_convnext', 'aux_swin', 'aux_vit')
        aux_loss = torch.mean(torch.stack(
            [self.cross_entropy(outputs[key], targets) for key in aux_keys]
        ))

        focal_loss = self.focal_loss(main_logits, targets)
        center_loss = self.center_loss(outputs['features'], targets)

        total_loss = (
            self.alpha * main_loss +
            self.beta * aux_loss +
            self.gamma * focal_loss +
            0.1 * center_loss
        )

        return {
            'total_loss': total_loss,
            'main_loss': main_loss,
            'aux_loss': aux_loss,
            'focal_loss': focal_loss,
            'center_loss': center_loss
        }
|
| |
|
| |
|
class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance.

    Down-weights well-classified examples so training focuses on hard ones:
        FL = alpha * (1 - p_t)**gamma * CE

    Args:
        alpha: Global scaling factor applied to every sample's loss.
        gamma: Focusing exponent; 0 reduces to alpha-scaled cross-entropy.
        reduction: 'mean' (default, matches the previous fixed behavior),
            'sum', or 'none' to obtain the per-sample loss vector.

    Raises:
        ValueError: If `reduction` is not one of the accepted strings.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            inputs: Raw logits of shape (batch, num_classes).
            targets: Class indices of shape (batch,).

        Returns:
            Scalar loss for 'mean'/'sum', per-sample tensor for 'none'.
        """
        # Per-sample CE; pt = exp(-CE) is the model's probability for the
        # true class, so (1 - pt)**gamma vanishes for confident predictions.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal
|
| |
|
| |
|
class CenterLoss(nn.Module):
    """Center Loss for learning discriminative features.

    Maintains one learnable center per class and penalizes the mean squared
    difference between each feature vector and its class center.

    Args:
        num_classes: Number of class centers to learn.
        feat_dim: Dimensionality of the feature vectors (and centers).
        device: Device on which to allocate the centers. Bug fix: this
            argument was previously accepted but ignored, so centers were
            always created on CPU regardless of what the caller passed.
    """

    def __init__(self, num_classes: int, feat_dim: int, device: str = 'cpu'):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Honor the requested device (previously unused).
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim, device=device))

    def forward(self, features: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean squared distance between features and their class centers.

        Args:
            features: Feature batch of shape (batch, feat_dim).
            targets: Class indices of shape (batch,).

        Returns:
            Scalar MSE between each feature and its target's center.
        """
        # Gather the center corresponding to each sample's target class.
        centers_batch = self.centers.index_select(0, targets)
        return F.mse_loss(features, centers_batch)
|
| |
|
| |
|
def create_advanced_classifier(num_classes: int = 25, dropout_rate: float = 0.3) -> AdvancedPretrainedClassifier:
    """Factory function to create the advanced classifier.

    Args:
        num_classes: Number of output classes.
        dropout_rate: Dropout probability for the fusion/classifier MLPs.
            Previously not exposed; the default matches the model's own
            default, so existing callers are unaffected.

    Returns:
        A freshly constructed AdvancedPretrainedClassifier.
    """
    return AdvancedPretrainedClassifier(num_classes=num_classes, dropout_rate=dropout_rate)
|
| |
|
| |
|
def create_advanced_loss(num_classes: int = 25, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3) -> AdvancedLossFunction:
    """Factory function to create the advanced loss function.

    Args:
        num_classes: Number of classes (sizes the center-loss centers).
        alpha: Weight of the main cross-entropy term.
        beta: Weight of the averaged auxiliary term.
        gamma: Weight of the focal term.
            The weight defaults match AdvancedLossFunction's own defaults,
            so existing callers are unaffected.

    Returns:
        A freshly constructed AdvancedLossFunction.
    """
    return AdvancedLossFunction(num_classes=num_classes, alpha=alpha, beta=beta, gamma=gamma)
|
| |
|