File size: 5,443 Bytes
3517f21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""

Simple but Powerful Advanced Pre-trained CNN Classifier

Uses EfficientNetV2 with advanced training techniques for architectural style classification.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from typing import Dict, List, Tuple, Optional
import numpy as np


class SimpleAdvancedClassifier(nn.Module):
    """Simple but powerful classifier built on a pre-trained EfficientNetV2.

    Combines several standard techniques:

    - A pre-trained ``timm`` backbone used purely as a feature extractor
    - Multi-scale pooling (global average + global max, concatenated)
    - A small feed-forward "feature enhancement" head with dropout
    - A scalar attention gate over the enhanced feature vector
    - Temperature scaling of the logits for probability calibration
    """

    def __init__(
        self,
        num_classes: int = 25,
        dropout_rate: float = 0.3,
        backbone_name: str = 'tf_efficientnetv2_m',
    ):
        """
        Args:
            num_classes: Number of output classes.
            dropout_rate: Dropout probability used in the enhancement and
                classifier heads.
            backbone_name: Any ``timm`` model name. Defaults to the
                previously hard-coded EfficientNetV2-M so existing callers
                are unaffected.
        """
        super().__init__()

        # Pre-trained backbone with its classifier head and pooling removed
        # (num_classes=0, global_pool='') so forward_features() yields a
        # (B, C, H, W) feature map we can pool ourselves.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            global_pool=''
        )

        # Channel count C of the backbone's final feature map.
        self.feature_dim = self.backbone.num_features
        print(f"EfficientNetV2 feature dimension: {self.feature_dim}")

        # Multi-scale pooling: average captures global context, max captures
        # the strongest local activations.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Feature enhancement MLP; input is 2*C because avg- and max-pooled
        # vectors are concatenated.
        self.feature_enhancement = nn.Sequential(
            nn.Linear(self.feature_dim * 2, self.feature_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Attention: one sigmoid-activated scalar per sample that gates the
        # enhanced feature vector.
        self.attention = nn.Sequential(
            nn.Linear(self.feature_dim // 2, self.feature_dim // 4),
            nn.ReLU(),
            nn.Linear(self.feature_dim // 4, 1),
            nn.Sigmoid()
        )

        # Final classification head.
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim // 2, self.feature_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim // 4, num_classes)
        )

        # Learnable temperature for calibration; logits are divided by it,
        # so values > 1 soften the output distribution.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the full pipeline on an image batch.

        Args:
            x: Input batch; assumed to be (B, 3, H, W) images matching the
               backbone's expected input — TODO confirm against callers.

        Returns:
            Dict with 'logits' (B, num_classes) temperature-scaled scores,
            'features' (B, feature_dim // 2) attention-gated embeddings, and
            'attention_weights' (B, 1) sigmoid gate values.
        """
        # Backbone feature map, shape (B, C, H', W').
        features = self.backbone.forward_features(x)

        # Multi-scale pooling -> two (B, C) vectors.
        avg_pooled = self.global_pool(features).flatten(1)
        max_pooled = self.max_pool(features).flatten(1)

        # Concatenate to (B, 2C) and enhance down to (B, C // 2).
        pooled_features = torch.cat([avg_pooled, max_pooled], dim=1)
        enhanced_features = self.feature_enhancement(pooled_features)

        # Scalar gate per sample, broadcast over the feature dimension.
        attention_weights = self.attention(enhanced_features)
        attended_features = enhanced_features * attention_weights

        # Classify, then apply temperature scaling for calibration.
        logits = self.classifier(attended_features)
        logits = logits / self.temperature

        return {
            'logits': logits,
            'features': attended_features,
            'attention_weights': attention_weights
        }


class AdvancedLossFunction(nn.Module):
    """Combined classification loss: label-smoothed cross entropy + focal loss.

    The total loss is a fixed-by-default convex-style combination
    ``ce_weight * CE + focal_weight * FL`` where CE uses label smoothing 0.1
    and FL down-weights easy examples.
    """

    def __init__(
        self,
        num_classes: int = 25,
        alpha: float = 1.0,
        gamma: float = 2.0,
        ce_weight: float = 0.7,
        focal_weight: float = 0.3,
    ):
        """
        Args:
            num_classes: Unused; kept for interface compatibility with
                existing callers and factory functions.
            alpha: Scaling factor passed to :class:`FocalLoss`.
            gamma: Focusing exponent passed to :class:`FocalLoss`.
            ce_weight: Weight of the cross-entropy term (default 0.7,
                matching the previously hard-coded value).
            focal_weight: Weight of the focal term (default 0.3, matching
                the previously hard-coded value).
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce_weight = ce_weight
        self.focal_weight = focal_weight

        # Component losses: smoothed CE for regularization, focal for
        # emphasizing hard examples.
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the combined loss.

        Args:
            outputs: Model output dict; only the 'logits' entry is read.
            targets: (B,) class-index targets.

        Returns:
            Dict with 'total_loss' (weighted sum), 'ce_loss', 'focal_loss'.
        """
        logits = outputs['logits']

        # Label-smoothed cross entropy over the whole batch.
        ce_loss = self.cross_entropy(logits, targets)

        # Focal loss emphasizes hard (low-confidence) examples.
        focal_loss = self.focal_loss(logits, targets)

        # Weighted combination (defaults reproduce the original 0.7 / 0.3).
        total_loss = self.ce_weight * ce_loss + self.focal_weight * focal_loss

        return {
            'total_loss': total_loss,
            'ce_loss': ce_loss,
            'focal_loss': focal_loss
        }


class FocalLoss(nn.Module):
    """Focal Loss (Lin et al., 2017) for class-imbalanced classification.

    Down-weights well-classified examples so training focuses on hard ones:
    ``FL = alpha * (1 - pt) ** gamma * CE`` where ``pt`` is the predicted
    probability of the true class.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        """
        Args:
            alpha: Global scaling factor for the loss.
            gamma: Focusing exponent; gamma=0 recovers plain cross entropy.
            reduction: 'mean' (default, matching the original behavior),
                'sum', or 'none' for per-sample losses.

        Raises:
            ValueError: If ``reduction`` is not one of the supported modes.
        """
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute focal loss for (B, C) logits and (B,) class targets."""
        # Per-sample CE; pt = exp(-CE) is the probability of the true class.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal


def create_simple_advanced_classifier(num_classes: int = 25) -> SimpleAdvancedClassifier:
    """Build and return a :class:`SimpleAdvancedClassifier` for *num_classes* classes."""
    model = SimpleAdvancedClassifier(num_classes=num_classes)
    return model


def create_advanced_loss(num_classes: int = 25) -> AdvancedLossFunction:
    """Build and return an :class:`AdvancedLossFunction` for *num_classes* classes."""
    loss_fn = AdvancedLossFunction(num_classes=num_classes)
    return loss_fn