fxxkingusername committed on
Commit
37736f2
·
verified ·
1 Parent(s): d12ab12

Upload src/models\advanced_pretrained_classifier.py with huggingface_hub

Browse files
src/models//advanced_pretrained_classifier.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Pre-trained CNN Classifier for Architectural Style Classification
3
+ Uses multiple state-of-the-art architectures with ensemble methods.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import timm
10
+ from transformers import AutoImageProcessor, AutoModel
11
+ from typing import Dict, List, Tuple, Optional
12
+ import numpy as np
13
+
14
+
15
class AdvancedPretrainedClassifier(nn.Module):
    """
    Advanced pre-trained classifier using multiple architectures:
    - EfficientNetV2 (for general features)
    - ConvNeXt (for modern architectural features)
    - Swin Transformer (for hierarchical features)
    - Vision Transformer (for global attention)

    forward() returns a dict containing the temperature-scaled main logits,
    one auxiliary logit tensor per backbone, the fused feature vector, and
    the attention-weighted feature vector.
    """

    def __init__(self, num_classes: int = 25, dropout_rate: float = 0.3):
        super().__init__()

        # Multiple pre-trained backbones. num_classes=0 strips each model's
        # stock classification head so it acts as a feature extractor.
        self.efficientnet = timm.create_model(
            'tf_efficientnetv2_m',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        self.convnext = timm.create_model(
            'convnext_base',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        self.swin = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Vision Transformer from HuggingFace.
        # NOTE(review): vit_processor is loaded here but never used in this
        # class — confirm whether image preprocessing happens elsewhere.
        self.vit_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.vit = AutoModel.from_pretrained('google/vit-base-patch16-224')

        # Feature dimensions reported by each timm backbone.
        self.efficientnet_dim = self.efficientnet.num_features
        self.convnext_dim = self.convnext.num_features
        self.swin_dim = self.swin.num_features
        self.vit_dim = 768  # ViT base hidden size

        # Print feature dimensions for debugging
        print(f"Feature dimensions:")
        print(f" EfficientNet: {self.efficientnet_dim}")
        print(f" ConvNeXt: {self.convnext_dim}")
        print(f" Swin: {self.swin_dim}")
        print(f" ViT: {self.vit_dim}")

        # Feature fusion layers: concatenation of all four pooled feature
        # vectors projected down to a 512-d representation.
        total_features = self.efficientnet_dim + self.convnext_dim + self.swin_dim + self.vit_dim

        self.feature_fusion = nn.Sequential(
            nn.Linear(total_features, 1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Multi-scale attention over the per-backbone feature vectors.
        self.attention = MultiScaleAttention(
            efficientnet_dim=self.efficientnet_dim,
            convnext_dim=self.convnext_dim,
            swin_dim=self.swin_dim,
            vit_dim=self.vit_dim
        )

        # Final classifier head on the fused 512-d features.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

        # Auxiliary classifiers for each backbone (deep supervision heads;
        # consumed by AdvancedLossFunction).
        self.aux_efficientnet = nn.Linear(self.efficientnet_dim, num_classes)
        self.aux_convnext = nn.Linear(self.convnext_dim, num_classes)
        self.aux_swin = nn.Linear(self.swin_dim, num_classes)
        self.aux_vit = nn.Linear(self.vit_dim, num_classes)

        # Temperature scaling for calibration. Learnable; applied to the
        # main logits only — auxiliary logits are returned unscaled.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run all backbones on ``x`` and return logits and features.

        NOTE(review): the pooling below assumes forward_features returns
        NCHW feature maps; recent timm Swin/ConvNeXt versions can return
        channel-last or differently shaped tensors — verify against the
        pinned timm version.
        """
        # Extract features from each backbone
        efficientnet_features = self.efficientnet.forward_features(x)
        if isinstance(efficientnet_features, tuple):
            efficientnet_features = efficientnet_features[0]
        efficientnet_features = F.adaptive_avg_pool2d(efficientnet_features, 1).flatten(1)

        convnext_features = self.convnext.forward_features(x)
        if isinstance(convnext_features, tuple):
            convnext_features = convnext_features[0]
        convnext_features = F.adaptive_avg_pool2d(convnext_features, 1).flatten(1)

        swin_features = self.swin.forward_features(x)
        if isinstance(swin_features, tuple):
            swin_features = swin_features[0]
        swin_features = F.adaptive_avg_pool2d(swin_features, 1).flatten(1)

        # ViT features (need to process differently)
        vit_features = self._extract_vit_features(x)

        # Apply attention mechanism. The result is returned for callers but
        # does NOT feed the main classifier below — only the raw
        # concatenation does.
        attended_features = self.attention(
            efficientnet_features, convnext_features, swin_features, vit_features
        )

        # Concatenate all features
        combined_features = torch.cat([
            efficientnet_features, convnext_features, swin_features, vit_features
        ], dim=1)

        # Feature fusion
        fused_features = self.feature_fusion(combined_features)

        # Main classifier
        main_logits = self.classifier(fused_features)

        # Auxiliary classifiers
        aux_efficientnet_logits = self.aux_efficientnet(efficientnet_features)
        aux_convnext_logits = self.aux_convnext(convnext_features)
        aux_swin_logits = self.aux_swin(swin_features)
        aux_vit_logits = self.aux_vit(vit_features)

        # Apply temperature scaling
        main_logits = main_logits / self.temperature

        return {
            'logits': main_logits,
            'aux_efficientnet': aux_efficientnet_logits,
            'aux_convnext': aux_convnext_logits,
            'aux_swin': aux_swin_logits,
            'aux_vit': aux_vit_logits,
            'features': fused_features,
            'attended_features': attended_features
        }

    def _extract_vit_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features from Vision Transformer."""
        # Convert to PIL-like format for ViT
        # ViT expects normalized images in [0, 1] range
        # NOTE(review): this assumes `x` arrives in [0, 255]; if the
        # dataloader already normalizes (e.g. to [0, 1] or ImageNet
        # mean/std) this division double-scales the input — confirm
        # against the data pipeline.
        x_normalized = x / 255.0

        # Use the CLS token output as features
        # NOTE: no_grad means the ViT branch is frozen — it receives no
        # gradient updates during training.
        with torch.no_grad():
            outputs = self.vit(pixel_values=x_normalized)
            # Get the CLS token (first token)
            cls_output = outputs.last_hidden_state[:, 0, :]

        return cls_output
171
+
172
+
173
class MultiScaleAttention(nn.Module):
    """Gated fusion of the four backbone feature vectors.

    Each backbone's features are projected to a shared 512-d space, scaled
    by a per-sample sigmoid gate computed from the projection, and the four
    gated vectors are averaged.
    """

    def __init__(self, efficientnet_dim: int, convnext_dim: int, swin_dim: int, vit_dim: int):
        super().__init__()

        # Shared width every backbone is projected into.
        self.common_dim = 512

        # Per-backbone projections to the common dimension.
        self.efficientnet_projection = nn.Linear(efficientnet_dim, self.common_dim)
        self.convnext_projection = nn.Linear(convnext_dim, self.common_dim)
        self.swin_projection = nn.Linear(swin_dim, self.common_dim)
        self.vit_projection = nn.Linear(vit_dim, self.common_dim)

        # One scalar gate per backbone, computed from its projection.
        self.efficientnet_attention = nn.Linear(self.common_dim, 1)
        self.convnext_attention = nn.Linear(self.common_dim, 1)
        self.swin_attention = nn.Linear(self.common_dim, 1)
        self.vit_attention = nn.Linear(self.common_dim, 1)

    def forward(self, efficientnet_features: torch.Tensor, convnext_features: torch.Tensor,
                swin_features: torch.Tensor, vit_features: torch.Tensor) -> torch.Tensor:
        """Return the mean of the four sigmoid-gated projections, shape (B, 512)."""
        branches = (
            (self.efficientnet_projection, self.efficientnet_attention, efficientnet_features),
            (self.convnext_projection, self.convnext_attention, convnext_features),
            (self.swin_projection, self.swin_attention, swin_features),
            (self.vit_projection, self.vit_attention, vit_features),
        )

        gated = []
        for project, gate, feats in branches:
            projected = project(feats)
            # Sigmoid gate in (0, 1) modulates the projected features.
            gated.append(projected * torch.sigmoid(gate(projected)))

        # Plain average of the four gated branches.
        return (gated[0] + gated[1] + gated[2] + gated[3]) / 4.0
221
+
222
+
223
class AdvancedLossFunction(nn.Module):
    """Weighted combination of main CE, auxiliary CE, focal, and center losses.

    Expects the output dict produced by AdvancedPretrainedClassifier.forward().
    """

    def __init__(self, num_classes: int = 25, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3):
        super().__init__()
        self.alpha = alpha  # weight of the main cross-entropy term
        self.beta = beta    # weight of the averaged auxiliary term
        self.gamma = gamma  # weight of the focal term

        # Component losses.
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=1.0, gamma=2.0)
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=512)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the weighted total loss plus each unweighted component."""
        main_logits = outputs['logits']

        # Main classification loss on the fused head.
        main_loss = self.cross_entropy(main_logits, targets)

        # Auxiliary losses: one CE per backbone head, averaged.
        aux_keys = ('aux_efficientnet', 'aux_convnext', 'aux_swin', 'aux_vit')
        aux_loss = torch.mean(torch.stack(
            [self.cross_entropy(outputs[key], targets) for key in aux_keys]
        ))

        # Focal loss emphasises hard examples on the main logits.
        focal_loss = self.focal_loss(main_logits, targets)

        # Center loss pulls fused features toward per-class centers
        # (fixed 0.1 weight).
        center_loss = self.center_loss(outputs['features'], targets)

        total_loss = (
            self.alpha * main_loss +
            self.beta * aux_loss +
            self.gamma * focal_loss +
            0.1 * center_loss
        )

        return {
            'total_loss': total_loss,
            'main_loss': main_loss,
            'aux_loss': aux_loss,
            'focal_loss': focal_loss,
            'center_loss': center_loss
        }
278
+
279
+
280
class FocalLoss(nn.Module):
    """Focal loss: down-weights well-classified examples via (1 - p_t)^gamma."""

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha  # global scaling factor
        self.gamma = gamma  # focusing exponent

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean focal loss over the batch."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # Recover p_t (softmax probability of the true class) from the CE.
        prob_true = torch.exp(-per_sample_ce)
        modulated = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce
        return modulated.mean()
293
+
294
+
295
class CenterLoss(nn.Module):
    """Center Loss for learning discriminative features.

    Maintains one learnable center per class and penalises the mean squared
    distance between each sample's feature vector and its class center.

    Args:
        num_classes: number of class centers.
        feat_dim: dimensionality of the feature vectors (and centers).
        device: device to allocate the centers on.
    """

    def __init__(self, num_classes: int, feat_dim: int, device: str = 'cpu'):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Fix: honour the `device` argument — it was previously accepted but
        # ignored, so centers were always allocated on the default device.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim, device=device))

    def forward(self, features: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the MSE between `features` and the centers of `targets`."""
        # Select each sample's class center, then average squared distance.
        centers_batch = self.centers.index_select(0, targets)
        return F.mse_loss(features, centers_batch)
307
+
308
+
309
def create_advanced_classifier(num_classes: int = 25) -> AdvancedPretrainedClassifier:
    """Build an AdvancedPretrainedClassifier with ``num_classes`` output classes."""
    model = AdvancedPretrainedClassifier(num_classes=num_classes)
    return model
312
+
313
+
314
def create_advanced_loss(num_classes: int = 25) -> AdvancedLossFunction:
    """Build an AdvancedLossFunction configured for ``num_classes`` classes."""
    criterion = AdvancedLossFunction(num_classes=num_classes)
    return criterion