FrAnKu34t23 commited on
Commit
40348b6
·
verified ·
1 Parent(s): deb766a

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +246 -246
models.py CHANGED
@@ -1,247 +1,247 @@
1
- """
2
- Bird classification model architectures with overfitting prevention.
3
- """
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from torchvision import models
8
- from typing import Optional
9
-
10
- # Try to import EfficientNet
11
- try:
12
- from efficientnet_pytorch import EfficientNet
13
- EFFICIENTNET_AVAILABLE = True
14
- except ImportError:
15
- EFFICIENTNET_AVAILABLE = False
16
- print("EfficientNet not available. Install with: pip install efficientnet-pytorch")
17
-
18
-
19
- class BirdClassifier(nn.Module):
20
- """
21
- Bird classification model with ResNet backbone and overfitting prevention.
22
- """
23
-
24
- def __init__(self, num_classes: int, architecture: str = 'resnet50',
25
- pretrained: bool = True, dropout_rate: float = 0.5,
26
- freeze_backbone: bool = False):
27
- """
28
- Initialize the bird classifier.
29
-
30
- Args:
31
- num_classes: Number of bird classes
32
- architecture: Backbone architecture ('resnet50', 'resnet18', 'efficientnet_b0')
33
- pretrained: Whether to use pretrained weights
34
- dropout_rate: Dropout rate for regularization
35
- freeze_backbone: Whether to freeze backbone weights
36
- """
37
- super(BirdClassifier, self).__init__()
38
-
39
- self.num_classes = num_classes
40
- self.dropout_rate = dropout_rate
41
-
42
- # Choose backbone architecture
43
- if architecture == 'resnet50':
44
- self.backbone = models.resnet50(pretrained=pretrained)
45
- num_features = self.backbone.fc.in_features
46
- self.backbone.fc = nn.Identity() # Remove original classifier
47
- elif architecture == 'resnet18':
48
- self.backbone = models.resnet18(pretrained=pretrained)
49
- num_features = self.backbone.fc.in_features
50
- self.backbone.fc = nn.Identity()
51
- elif architecture == 'resnet101':
52
- self.backbone = models.resnet101(pretrained=pretrained)
53
- num_features = self.backbone.fc.in_features
54
- self.backbone.fc = nn.Identity()
55
- elif architecture == 'efficientnet_b0':
56
- self.backbone = models.efficientnet_b0(pretrained=pretrained)
57
- num_features = self.backbone.classifier[1].in_features
58
- self.backbone.classifier = nn.Identity()
59
- elif architecture in ['efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4'] and EFFICIENTNET_AVAILABLE:
60
- model_name = architecture.replace('_', '-')
61
- if pretrained:
62
- self.backbone = EfficientNet.from_pretrained(model_name)
63
- else:
64
- self.backbone = EfficientNet.from_name(model_name)
65
- num_features = self.backbone._fc.in_features
66
- self.backbone._fc = nn.Identity()
67
- else:
68
- raise ValueError(f"Unsupported architecture: {architecture}")
69
-
70
- # Freeze backbone if requested
71
- if freeze_backbone:
72
- for param in self.backbone.parameters():
73
- param.requires_grad = False
74
-
75
- # Enhanced classifier head with batch normalization and progressive dimension reduction
76
- # Optimized regularization for Stage 2 performance (76.74% accuracy)
77
- self.classifier = nn.Sequential(
78
- nn.Dropout(p=dropout_rate * 0.6), # Stage 2 optimization: 0.3 * 0.6 = 0.18
79
- nn.Linear(num_features, 512), # Optimized size
80
- nn.BatchNorm1d(512),
81
- nn.ReLU(inplace=True),
82
- nn.Dropout(p=dropout_rate * 0.5), # Stage 2 optimization: 0.3 * 0.5 = 0.15
83
- nn.Linear(512, 256),
84
- nn.BatchNorm1d(256),
85
- nn.ReLU(inplace=True),
86
- nn.Dropout(p=dropout_rate * 0.3), # Stage 2 optimization: 0.3 * 0.3 = 0.09
87
- nn.Linear(256, num_classes)
88
- )
89
-
90
- # Initialize weights
91
- self._initialize_weights()
92
-
93
- def _initialize_weights(self):
94
- """Initialize classifier weights with better initialization."""
95
- for m in self.classifier.modules():
96
- if isinstance(m, nn.Linear):
97
- nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
98
- if m.bias is not None:
99
- nn.init.constant_(m.bias, 0)
100
- elif isinstance(m, nn.BatchNorm1d):
101
- nn.init.constant_(m.weight, 1)
102
- nn.init.constant_(m.bias, 0)
103
-
104
- def forward(self, x):
105
- """Forward pass."""
106
- features = self.backbone(x)
107
- output = self.classifier(features)
108
- return output
109
-
110
-
111
- class LightweightBirdClassifier(nn.Module):
112
- """
113
- Lightweight CNN model for bird classification with batch normalization.
114
- """
115
-
116
- def __init__(self, num_classes: int, dropout_rate: float = 0.5):
117
- """
118
- Initialize lightweight classifier.
119
-
120
- Args:
121
- num_classes: Number of bird classes
122
- dropout_rate: Dropout rate for regularization
123
- """
124
- super(LightweightBirdClassifier, self).__init__()
125
-
126
- self.features = nn.Sequential(
127
- # Block 1
128
- nn.Conv2d(3, 32, kernel_size=3, padding=1),
129
- nn.BatchNorm2d(32),
130
- nn.ReLU(inplace=True),
131
- nn.Conv2d(32, 32, kernel_size=3, padding=1),
132
- nn.BatchNorm2d(32),
133
- nn.ReLU(inplace=True),
134
- nn.MaxPool2d(2, 2),
135
- nn.Dropout2d(p=dropout_rate/2),
136
-
137
- # Block 2
138
- nn.Conv2d(32, 64, kernel_size=3, padding=1),
139
- nn.BatchNorm2d(64),
140
- nn.ReLU(inplace=True),
141
- nn.Conv2d(64, 64, kernel_size=3, padding=1),
142
- nn.BatchNorm2d(64),
143
- nn.ReLU(inplace=True),
144
- nn.MaxPool2d(2, 2),
145
- nn.Dropout2d(p=dropout_rate/2),
146
-
147
- # Block 3
148
- nn.Conv2d(64, 128, kernel_size=3, padding=1),
149
- nn.BatchNorm2d(128),
150
- nn.ReLU(inplace=True),
151
- nn.Conv2d(128, 128, kernel_size=3, padding=1),
152
- nn.BatchNorm2d(128),
153
- nn.ReLU(inplace=True),
154
- nn.MaxPool2d(2, 2),
155
- nn.Dropout2d(p=dropout_rate/2),
156
-
157
- # Block 4
158
- nn.Conv2d(128, 256, kernel_size=3, padding=1),
159
- nn.BatchNorm2d(256),
160
- nn.ReLU(inplace=True),
161
- nn.Conv2d(256, 256, kernel_size=3, padding=1),
162
- nn.BatchNorm2d(256),
163
- nn.ReLU(inplace=True),
164
- nn.AdaptiveAvgPool2d((1, 1)),
165
- )
166
-
167
- self.classifier = nn.Sequential(
168
- nn.Flatten(),
169
- nn.Dropout(p=dropout_rate),
170
- nn.Linear(256, 128),
171
- nn.ReLU(inplace=True),
172
- nn.Dropout(p=dropout_rate),
173
- nn.Linear(128, num_classes)
174
- )
175
-
176
- self._initialize_weights()
177
-
178
- def _initialize_weights(self):
179
- """Initialize model weights."""
180
- for m in self.modules():
181
- if isinstance(m, nn.Conv2d):
182
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
183
- if m.bias is not None:
184
- nn.init.constant_(m.bias, 0)
185
- elif isinstance(m, nn.BatchNorm2d):
186
- nn.init.constant_(m.weight, 1)
187
- nn.init.constant_(m.bias, 0)
188
- elif isinstance(m, nn.Linear):
189
- nn.init.xavier_uniform_(m.weight)
190
- nn.init.constant_(m.bias, 0)
191
-
192
- def forward(self, x):
193
- """Forward pass."""
194
- x = self.features(x)
195
- x = self.classifier(x)
196
- return x
197
-
198
-
199
- def create_model(num_classes: int, model_type: str = 'resnet50',
200
- pretrained: bool = True, dropout_rate: float = 0.5,
201
- freeze_backbone: bool = False) -> nn.Module:
202
- """
203
- Create a bird classification model.
204
-
205
- Args:
206
- num_classes: Number of bird classes
207
- model_type: Type of model ('resnet50', 'resnet18', 'efficientnet_b0', 'lightweight')
208
- pretrained: Whether to use pretrained weights
209
- dropout_rate: Dropout rate for regularization
210
- freeze_backbone: Whether to freeze backbone weights (ignored for lightweight model)
211
-
212
- Returns:
213
- PyTorch model
214
- """
215
- if model_type == 'lightweight':
216
- return LightweightBirdClassifier(num_classes, dropout_rate)
217
- else:
218
- return BirdClassifier(num_classes, model_type, pretrained,
219
- dropout_rate, freeze_backbone)
220
-
221
-
222
- class ModelEnsemble(nn.Module):
223
- """
224
- Ensemble of multiple models for improved performance.
225
- """
226
-
227
- def __init__(self, models_list: list):
228
- """
229
- Initialize model ensemble.
230
-
231
- Args:
232
- models_list: List of trained models to ensemble
233
- """
234
- super(ModelEnsemble, self).__init__()
235
- self.models = nn.ModuleList(models_list)
236
-
237
- def forward(self, x):
238
- """Forward pass through all models and average predictions."""
239
- predictions = []
240
- for model in self.models:
241
- with torch.no_grad():
242
- pred = F.softmax(model(x), dim=1)
243
- predictions.append(pred)
244
-
245
- # Average predictions
246
- ensemble_pred = torch.stack(predictions, dim=0).mean(dim=0)
247
  return torch.log(ensemble_pred + 1e-8) # Convert back to log probabilities
 
1
+ """
2
+ Bird classification model architectures with overfitting prevention.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision import models
8
+ from typing import Optional
9
+
10
+ # Try to import EfficientNet
11
+ try:
12
+ from efficientnet_pytorch import EfficientNet
13
+ EFFICIENTNET_AVAILABLE = True
14
+ except ImportError:
15
+ EFFICIENTNET_AVAILABLE = False
16
+ print("EfficientNet not available. Install with: pip install efficientnet-pytorch")
17
+
18
+
19
class BirdClassifier(nn.Module):
    """
    Bird classification model with a configurable CNN backbone and a
    regularized classifier head for overfitting prevention.

    The backbone's original classifier layer is replaced with an identity
    so its pooled features feed a custom head (dropout + batch norm +
    progressively narrower linear layers).
    """

    def __init__(self, num_classes: int, architecture: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False):
        """
        Initialize the bird classifier.

        Args:
            num_classes: Number of bird classes
            architecture: Backbone architecture. One of 'resnet50',
                'resnet18', 'resnet101', 'efficientnet_b0', or — when the
                optional efficientnet-pytorch package is installed —
                'efficientnet_b1' .. 'efficientnet_b4'.
            pretrained: Whether to use pretrained weights
            dropout_rate: Base dropout rate; the head scales it per layer
            freeze_backbone: Whether to freeze backbone weights

        Raises:
            ValueError: If the architecture is unknown, or an
                efficientnet_b1-b4 variant is requested without the
                optional efficientnet-pytorch package installed.
        """
        super(BirdClassifier, self).__init__()

        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

        # Build the backbone and strip its classifier so it emits features.
        # NOTE(review): torchvision's `pretrained=` flag is deprecated in
        # favor of `weights=` on newer releases — kept here for compatibility
        # with the version this project uses; confirm before upgrading.
        if architecture == 'resnet50':
            self.backbone = models.resnet50(pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()  # Remove original classifier
        elif architecture == 'resnet18':
            self.backbone = models.resnet18(pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        elif architecture == 'resnet101':
            self.backbone = models.resnet101(pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        elif architecture == 'efficientnet_b0':
            self.backbone = models.efficientnet_b0(pretrained=pretrained)
            num_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
        elif architecture in ('efficientnet_b1', 'efficientnet_b2',
                              'efficientnet_b3', 'efficientnet_b4'):
            # Bug fix: previously, a missing efficientnet-pytorch package made
            # these variants fall through to the generic "Unsupported
            # architecture" error, hiding the real cause.
            if not EFFICIENTNET_AVAILABLE:
                raise ValueError(
                    f"Architecture '{architecture}' requires the optional "
                    "efficientnet-pytorch package. Install with: "
                    "pip install efficientnet-pytorch"
                )
            model_name = architecture.replace('_', '-')  # e.g. 'efficientnet-b1'
            if pretrained:
                self.backbone = EfficientNet.from_pretrained(model_name)
            else:
                self.backbone = EfficientNet.from_name(model_name)
            num_features = self.backbone._fc.in_features
            self.backbone._fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported architecture: {architecture}")

        # Freeze backbone if requested; the classifier head stays trainable.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Enhanced classifier head with batch normalization and progressive
        # dimension reduction. Dropout shrinks with depth (0.6x, 0.5x, 0.3x
        # of dropout_rate); e.g. with dropout_rate=0.3 the effective rates
        # are 0.18 / 0.15 / 0.09 ("Stage 2" tuning, 76.74% accuracy).
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate * 0.6),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate * 0.3),
            nn.Linear(256, num_classes)
        )

        # Initialize head weights (the backbone keeps its own weights).
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize classifier-head weights: Xavier (ReLU gain) for linear
        layers, identity scaling for batch norms."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Extract backbone features, then classify. Returns raw logits."""
        features = self.backbone(x)
        output = self.classifier(features)
        return output
109
+
110
+
111
class LightweightBirdClassifier(nn.Module):
    """
    Lightweight CNN for bird classification: four double-conv blocks with
    batch normalization, spatial dropout between blocks, global average
    pooling, and a small fully connected head.
    """

    def __init__(self, num_classes: int, dropout_rate: float = 0.5):
        """
        Initialize lightweight classifier.

        Args:
            num_classes: Number of bird classes
            dropout_rate: Dropout rate for regularization
        """
        super(LightweightBirdClassifier, self).__init__()

        def double_conv(in_ch, out_ch):
            # Two 3x3 convolutions, each followed by batch norm + ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Blocks 1-3: double conv, downsample, spatial dropout.
        layers = []
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            layers.extend(double_conv(in_ch, out_ch))
            layers.append(nn.MaxPool2d(2, 2))
            layers.append(nn.Dropout2d(p=dropout_rate / 2))

        # Block 4: double conv followed by global average pooling.
        layers.extend(double_conv(128, 256))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)

        # Classification head over the 256 pooled channels.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(128, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convs, Xavier for linears, identity for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run the batch through the conv stack, then the classifier head."""
        return self.classifier(self.features(x))
197
+
198
+
199
def create_model(num_classes: int, model_type: str = 'resnet50',
                 pretrained: bool = True, dropout_rate: float = 0.5,
                 freeze_backbone: bool = False) -> nn.Module:
    """
    Factory for bird classification models.

    Args:
        num_classes: Number of bird classes
        model_type: Type of model ('resnet50', 'resnet18', 'efficientnet_b0',
            'lightweight')
        pretrained: Whether to use pretrained weights
        dropout_rate: Dropout rate for regularization
        freeze_backbone: Whether to freeze backbone weights (ignored for
            lightweight model)

    Returns:
        PyTorch model
    """
    if model_type == 'lightweight':
        # The small CNN has no pretrained weights and no backbone to freeze.
        return LightweightBirdClassifier(num_classes, dropout_rate)
    # Everything else is treated as a backbone name for BirdClassifier.
    return BirdClassifier(num_classes, model_type, pretrained,
                          dropout_rate, freeze_backbone)
220
+
221
+
222
class ModelEnsemble(nn.Module):
    """
    Inference-time ensemble: averages the softmax outputs of several
    trained models and returns log-probabilities.
    """

    def __init__(self, models_list: list):
        """
        Initialize model ensemble.

        Args:
            models_list: List of trained models to ensemble
        """
        super(ModelEnsemble, self).__init__()
        # ModuleList registers the members so .to()/.eval() propagate.
        self.models = nn.ModuleList(models_list)

    def forward(self, x):
        """Average member softmax predictions; return log-probabilities."""
        # Members are treated as frozen, so no gradients are tracked.
        with torch.no_grad():
            member_probs = [F.softmax(member(x), dim=1) for member in self.models]
        ensemble_pred = torch.stack(member_probs, dim=0).mean(dim=0)
        # Epsilon keeps log() finite when a class probability rounds to 0.
        return torch.log(ensemble_pred + 1e-8)