FrAnKu34t23 commited on
Commit
abd02d5
·
verified ·
1 Parent(s): d66cb01

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +159 -0
  2. best_model.pth +3 -0
  3. class_names.json +52 -0
  4. models.py +247 -0
  5. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Bird Classification - Hugging Face Deployment
3
+ Enhanced model with 76.74% accuracy from Stage 2 training.
4
+ """
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ import json
10
+ import numpy as np
11
+ from torchvision import transforms
12
+ import os
13
+
14
+ # Import our model architecture
15
+ from models import create_model
16
+
17
+ # Configuration
18
+ MODEL_PATH = "best_model.pth"
19
+ CLASS_NAMES_PATH = "class_names.json"
20
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ # Load class names
23
+ with open(CLASS_NAMES_PATH, 'r') as f:
24
+ class_names = json.load(f)
25
+
26
+ NUM_CLASSES = len(class_names)
27
+
28
+ # Load model
29
+ print("Loading model...")
30
+ model = create_model(
31
+ num_classes=NUM_CLASSES,
32
+ model_type='efficientnet_b2', # Stage 2 architecture
33
+ pretrained=False, # We're loading trained weights
34
+ dropout_rate=0.3 # Stage 2 dropout rate
35
+ )
36
+
37
+ # Load trained weights
38
+ if os.path.exists(MODEL_PATH):
39
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
40
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
41
+ model.load_state_dict(checkpoint['model_state_dict'])
42
+ else:
43
+ model.load_state_dict(checkpoint)
44
+ print("✅ Model loaded successfully!")
45
+ else:
46
+ print("⚠️ Model file not found. Please ensure best_model.pth is in the repository.")
47
+
48
+ model.to(DEVICE)
49
+ model.eval()
50
+
51
+ # Image preprocessing (Stage 2 configuration)
52
+ transform = transforms.Compose([
53
+ transforms.Resize((320, 320)), # Stage 2 image size
54
+ transforms.ToTensor(),
55
+ transforms.Normalize(
56
+ mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225]
58
+ )
59
+ ])
60
+
61
+ def predict_bird(image):
62
+ """
63
+ Predict bird species from uploaded image.
64
+ """
65
+ try:
66
+ # Preprocess image
67
+ if isinstance(image, np.ndarray):
68
+ image = Image.fromarray(image.astype('uint8'))
69
+
70
+ # Convert to RGB if needed
71
+ if image.mode != 'RGB':
72
+ image = image.convert('RGB')
73
+
74
+ # Apply transformations
75
+ input_tensor = transform(image).unsqueeze(0).to(DEVICE)
76
+
77
+ # Prediction
78
+ with torch.no_grad():
79
+ outputs = model(input_tensor)
80
+ probabilities = F.softmax(outputs, dim=1)
81
+ confidence, predicted = torch.max(probabilities, 1)
82
+
83
+ # Get top 5 predictions
84
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
85
+
86
+ # Format results
87
+ results = {}
88
+ for i in range(5):
89
+ class_idx = top5_indices[0][i].item()
90
+ prob = top5_prob[0][i].item()
91
+ class_name = class_names[class_idx].replace('_', ' ')
92
+ results[class_name] = float(prob)
93
+
94
+ return results
95
+
96
+ except Exception as e:
97
+ return {"Error": f"Prediction failed: {str(e)}"}
98
+
99
+ # Create Gradio interface
100
+ title = "🐦 Bird Species Classifier"
101
+ description = """
102
+ ## Advanced Bird Classification Model (76.74% Accuracy)
103
+
104
+ This model can classify **200 different bird species** using advanced deep learning techniques:
105
+
106
+ ### Model Details:
107
+ - **Architecture**: EfficientNet-B2 with enhanced regularization
108
+ - **Training Strategy**: Progressive training with MixUp augmentation
109
+ - **Performance**: 76.74% test accuracy (Stage 2 results)
110
+ - **Dataset**: CUB-200-2011 (200 bird species)
111
+
112
+ ### How to use:
113
+ 1. Upload a clear image of a bird
114
+ 2. The model will predict the top 5 most likely species
115
+ 3. Confidence scores show the model's certainty
116
+
117
+ ### Best Results Tips:
118
+ - Use high-quality, well-lit images
119
+ - Ensure the bird is clearly visible
120
+ - Close-up shots work better than distant ones
121
+ - Natural lighting produces better results
122
+
123
+ **Note**: This model was trained on the CUB-200-2011 dataset and works best with North American bird species.
124
+ """
125
+
126
+ article = """
127
+ ### Technical Implementation:
128
+ - **Framework**: PyTorch with EfficientNet-B2 backbone
129
+ - **Training**: Progressive training with MixUp data augmentation
130
+ - **Regularization**: Optimized dropout rates (0.3) and advanced augmentation
131
+ - **Image Size**: 320x320 pixels for optimal detail capture
132
+
133
+ ### About the Model:
134
+ This bird classifier was developed using advanced machine learning techniques including:
135
+ - Transfer learning from ImageNet-pretrained EfficientNet
136
+ - Progressive training strategy across multiple stages
137
+ - MixUp augmentation for improved generalization
138
+ - Comprehensive evaluation on 200 bird species
139
+
140
+ For more details about the training process and methodology, please refer to the repository documentation.
141
+ """
142
+
143
+ # Create the interface
144
+ iface = gr.Interface(
145
+ fn=predict_bird,
146
+ inputs=gr.Image(type="pil", label="Upload Bird Image"),
147
+ outputs=gr.Label(num_top_classes=5, label="Predictions"),
148
+ title=title,
149
+ description=description,
150
+ article=article,
151
+ examples=[
152
+ # You can add example images here if you have them
153
+ ],
154
+ allow_flagging="never",
155
+ theme=gr.themes.Soft()
156
+ )
157
+
158
+ if __name__ == "__main__":
159
+ iface.launch(debug=True)
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:429b3f9a74d67c440661705de6d86fd40355a5a781cb7b2e5ed4b20e79887d20
3
+ size 47237725
class_names.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Black_footed_Albatross", "Laysan_Albatross", "Sooty_Albatross", "Groove_billed_Ani",
3
+ "Crested_Auklet", "Least_Auklet", "Parakeet_Auklet", "Rhinoceros_Auklet",
4
+ "Brewer_Blackbird", "Red_winged_Blackbird", "Rusty_Blackbird", "Yellow_headed_Blackbird",
5
+ "Bobolink", "Indigo_Bunting", "Lazuli_Bunting", "Painted_Bunting",
6
+ "Cardinal", "Spotted_Catbird", "Gray_Catbird", "Yellow_breasted_Chat",
7
+ "Eastern_Towhee", "Chuck_will_Widow", "Brandt_Cormorant", "Red_faced_Cormorant",
8
+ "Pelagic_Cormorant", "Bronzed_Cowbird", "Shiny_Cowbird", "Brown_Creeper",
9
+ "American_Crow", "Fish_Crow", "Black_billed_Cuckoo", "Mangrove_Cuckoo",
10
+ "Yellow_billed_Cuckoo", "Gray_crowned_Rosy_Finch", "Purple_Finch", "Northern_Flicker",
11
+ "Acadian_Flycatcher", "Great_Crested_Flycatcher", "Least_Flycatcher", "Olive_sided_Flycatcher",
12
+ "Scissor_tailed_Flycatcher", "Vermilion_Flycatcher", "Yellow_bellied_Flycatcher", "Frigatebird",
13
+ "Northern_Fulmar", "Gadwall", "American_Goldfinch", "European_Goldfinch",
14
+ "Boat_tailed_Grackle", "Eared_Grebe", "Horned_Grebe", "Pied_billed_Grebe",
15
+ "Western_Grebe", "Blue_Grosbeak", "Evening_Grosbeak", "Pine_Grosbeak",
16
+ "Rose_breasted_Grosbeak", "Pigeon_Guillemot", "California_Gull", "Glaucous_winged_Gull",
17
+ "Heermann_Gull", "Herring_Gull", "Ivory_Gull", "Ring_billed_Gull",
18
+ "Slaty_backed_Gull", "Western_Gull", "Anna_Hummingbird", "Ruby_throated_Hummingbird",
19
+ "Rufous_Hummingbird", "Green_Violetear", "Long_tailed_Jaeger", "Pomarine_Jaeger",
20
+ "Blue_Jay", "Florida_Jay", "Green_Jay", "Dark_eyed_Junco",
21
+ "Tropical_Kingbird", "Gray_Kingbird", "Belted_Kingfisher", "Green_Kingfisher",
22
+ "Pied_Kingfisher", "Ringed_Kingfisher", "White_breasted_Kingfisher", "Red_legged_Kittiwake",
23
+ "Horned_Lark", "Pacific_Lark", "Mallard", "Western_Meadowlark",
24
+ "Hooded_Merganser", "Red_breasted_Merganser", "Mockingbird", "Nighthawk",
25
+ "Clark_Nutcracker", "White_breasted_Nuthatch", "Baltimore_Oriole", "Hooded_Oriole",
26
+ "Orchard_Oriole", "Scott_Oriole", "Ovenbird", "Brown_Pelican",
27
+ "White_Pelican", "Western_Wood_Pewee", "Sayornis", "American_Pipit",
28
+ "Whip_poor_Will", "Horned_Puffin", "Common_Raven", "White_necked_Raven",
29
+ "American_Redstart", "Geococcyx", "Loggerhead_Shrike", "Great_Grey_Shrike",
30
+ "Baird_Sparrow", "Black_throated_Sparrow", "Brewer_Sparrow", "Chipping_Sparrow",
31
+ "Clay_colored_Sparrow", "House_Sparrow", "Field_Sparrow", "Fox_Sparrow",
32
+ "Grasshopper_Sparrow", "Harris_Sparrow", "Henslow_Sparrow", "Le_Conte_Sparrow",
33
+ "Lincoln_Sparrow", "Nelson_Sharp_tailed_Sparrow", "Savannah_Sparrow", "Seaside_Sparrow",
34
+ "Song_Sparrow", "Tree_Sparrow", "Vesper_Sparrow", "White_crowned_Sparrow",
35
+ "White_throated_Sparrow", "Cape_Glossy_Starling", "Bank_Swallow", "Barn_Swallow",
36
+ "Cliff_Swallow", "Tree_Swallow", "Scarlet_Tanager", "Summer_Tanager",
37
+ "Artic_Tern", "Black_Tern", "Caspian_Tern", "Common_Tern",
38
+ "Elegant_Tern", "Forsters_Tern", "Least_Tern", "Green_tailed_Towhee",
39
+ "Brown_Thrasher", "Sage_Thrasher", "Black_capped_Vireo", "Blue_headed_Vireo",
40
+ "Philadelphia_Vireo", "Red_eyed_Vireo", "Warbling_Vireo", "White_eyed_Vireo",
41
+ "Yellow_throated_Vireo", "Bay_breasted_Warbler", "Black_and_white_Warbler", "Black_throated_Blue_Warbler",
42
+ "Blue_winged_Warbler", "Canada_Warbler", "Cape_May_Warbler", "Cerulean_Warbler",
43
+ "Chestnut_sided_Warbler", "Golden_winged_Warbler", "Hooded_Warbler", "Kentucky_Warbler",
44
+ "Magnolia_Warbler", "Mourning_Warbler", "Myrtle_Warbler", "Nashville_Warbler",
45
+ "Orange_crowned_Warbler", "Palm_Warbler", "Pine_Warbler", "Prairie_Warbler",
46
+ "Prothonotary_Warbler", "Tennessee_Warbler", "Wilson_Warbler", "Worm_eating_Warbler",
47
+ "Yellow_Warbler", "Northern_Waterthrush", "Louisiana_Waterthrush", "Bohemian_Waxwing",
48
+ "Cedar_Waxwing", "American_Three_toed_Woodpecker", "Pileated_Woodpecker", "Red_bellied_Woodpecker",
49
+ "Red_cockaded_Woodpecker", "Red_headed_Woodpecker", "Downy_Woodpecker", "Bewick_Wren",
50
+ "Cactus_Wren", "Carolina_Wren", "House_Wren", "Marsh_Wren",
51
+ "Rock_Wren", "Winter_Wren", "Common_Yellowthroat"
52
+ ]
models.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bird classification model architectures with overfitting prevention.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision import models
8
+ from typing import Optional
9
+
10
+ # Try to import EfficientNet
11
+ try:
12
+ from efficientnet_pytorch import EfficientNet
13
+ EFFICIENTNET_AVAILABLE = True
14
+ except ImportError:
15
+ EFFICIENTNET_AVAILABLE = False
16
+ print("EfficientNet not available. Install with: pip install efficientnet-pytorch")
17
+
18
+
19
+ class BirdClassifier(nn.Module):
20
+ """
21
+ Bird classification model with ResNet backbone and overfitting prevention.
22
+ """
23
+
24
+ def __init__(self, num_classes: int, architecture: str = 'resnet50',
25
+ pretrained: bool = True, dropout_rate: float = 0.5,
26
+ freeze_backbone: bool = False):
27
+ """
28
+ Initialize the bird classifier.
29
+
30
+ Args:
31
+ num_classes: Number of bird classes
32
+ architecture: Backbone architecture ('resnet50', 'resnet18', 'efficientnet_b0')
33
+ pretrained: Whether to use pretrained weights
34
+ dropout_rate: Dropout rate for regularization
35
+ freeze_backbone: Whether to freeze backbone weights
36
+ """
37
+ super(BirdClassifier, self).__init__()
38
+
39
+ self.num_classes = num_classes
40
+ self.dropout_rate = dropout_rate
41
+
42
+ # Choose backbone architecture
43
+ if architecture == 'resnet50':
44
+ self.backbone = models.resnet50(pretrained=pretrained)
45
+ num_features = self.backbone.fc.in_features
46
+ self.backbone.fc = nn.Identity() # Remove original classifier
47
+ elif architecture == 'resnet18':
48
+ self.backbone = models.resnet18(pretrained=pretrained)
49
+ num_features = self.backbone.fc.in_features
50
+ self.backbone.fc = nn.Identity()
51
+ elif architecture == 'resnet101':
52
+ self.backbone = models.resnet101(pretrained=pretrained)
53
+ num_features = self.backbone.fc.in_features
54
+ self.backbone.fc = nn.Identity()
55
+ elif architecture == 'efficientnet_b0':
56
+ self.backbone = models.efficientnet_b0(pretrained=pretrained)
57
+ num_features = self.backbone.classifier[1].in_features
58
+ self.backbone.classifier = nn.Identity()
59
+ elif architecture in ['efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4'] and EFFICIENTNET_AVAILABLE:
60
+ model_name = architecture.replace('_', '-')
61
+ if pretrained:
62
+ self.backbone = EfficientNet.from_pretrained(model_name)
63
+ else:
64
+ self.backbone = EfficientNet.from_name(model_name)
65
+ num_features = self.backbone._fc.in_features
66
+ self.backbone._fc = nn.Identity()
67
+ else:
68
+ raise ValueError(f"Unsupported architecture: {architecture}")
69
+
70
+ # Freeze backbone if requested
71
+ if freeze_backbone:
72
+ for param in self.backbone.parameters():
73
+ param.requires_grad = False
74
+
75
+ # Enhanced classifier head with batch normalization and progressive dimension reduction
76
+ # Optimized regularization for Stage 2 performance (76.74% accuracy)
77
+ self.classifier = nn.Sequential(
78
+ nn.Dropout(p=dropout_rate * 0.6), # Stage 2 optimization: 0.3 * 0.6 = 0.18
79
+ nn.Linear(num_features, 512), # Optimized size
80
+ nn.BatchNorm1d(512),
81
+ nn.ReLU(inplace=True),
82
+ nn.Dropout(p=dropout_rate * 0.5), # Stage 2 optimization: 0.3 * 0.5 = 0.15
83
+ nn.Linear(512, 256),
84
+ nn.BatchNorm1d(256),
85
+ nn.ReLU(inplace=True),
86
+ nn.Dropout(p=dropout_rate * 0.3), # Stage 2 optimization: 0.3 * 0.3 = 0.09
87
+ nn.Linear(256, num_classes)
88
+ )
89
+
90
+ # Initialize weights
91
+ self._initialize_weights()
92
+
93
+ def _initialize_weights(self):
94
+ """Initialize classifier weights with better initialization."""
95
+ for m in self.classifier.modules():
96
+ if isinstance(m, nn.Linear):
97
+ nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
98
+ if m.bias is not None:
99
+ nn.init.constant_(m.bias, 0)
100
+ elif isinstance(m, nn.BatchNorm1d):
101
+ nn.init.constant_(m.weight, 1)
102
+ nn.init.constant_(m.bias, 0)
103
+
104
+ def forward(self, x):
105
+ """Forward pass."""
106
+ features = self.backbone(x)
107
+ output = self.classifier(features)
108
+ return output
109
+
110
+
111
+ class LightweightBirdClassifier(nn.Module):
112
+ """
113
+ Lightweight CNN model for bird classification with batch normalization.
114
+ """
115
+
116
+ def __init__(self, num_classes: int, dropout_rate: float = 0.5):
117
+ """
118
+ Initialize lightweight classifier.
119
+
120
+ Args:
121
+ num_classes: Number of bird classes
122
+ dropout_rate: Dropout rate for regularization
123
+ """
124
+ super(LightweightBirdClassifier, self).__init__()
125
+
126
+ self.features = nn.Sequential(
127
+ # Block 1
128
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
129
+ nn.BatchNorm2d(32),
130
+ nn.ReLU(inplace=True),
131
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
132
+ nn.BatchNorm2d(32),
133
+ nn.ReLU(inplace=True),
134
+ nn.MaxPool2d(2, 2),
135
+ nn.Dropout2d(p=dropout_rate/2),
136
+
137
+ # Block 2
138
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
139
+ nn.BatchNorm2d(64),
140
+ nn.ReLU(inplace=True),
141
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
142
+ nn.BatchNorm2d(64),
143
+ nn.ReLU(inplace=True),
144
+ nn.MaxPool2d(2, 2),
145
+ nn.Dropout2d(p=dropout_rate/2),
146
+
147
+ # Block 3
148
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
149
+ nn.BatchNorm2d(128),
150
+ nn.ReLU(inplace=True),
151
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
152
+ nn.BatchNorm2d(128),
153
+ nn.ReLU(inplace=True),
154
+ nn.MaxPool2d(2, 2),
155
+ nn.Dropout2d(p=dropout_rate/2),
156
+
157
+ # Block 4
158
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
159
+ nn.BatchNorm2d(256),
160
+ nn.ReLU(inplace=True),
161
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
162
+ nn.BatchNorm2d(256),
163
+ nn.ReLU(inplace=True),
164
+ nn.AdaptiveAvgPool2d((1, 1)),
165
+ )
166
+
167
+ self.classifier = nn.Sequential(
168
+ nn.Flatten(),
169
+ nn.Dropout(p=dropout_rate),
170
+ nn.Linear(256, 128),
171
+ nn.ReLU(inplace=True),
172
+ nn.Dropout(p=dropout_rate),
173
+ nn.Linear(128, num_classes)
174
+ )
175
+
176
+ self._initialize_weights()
177
+
178
+ def _initialize_weights(self):
179
+ """Initialize model weights."""
180
+ for m in self.modules():
181
+ if isinstance(m, nn.Conv2d):
182
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
183
+ if m.bias is not None:
184
+ nn.init.constant_(m.bias, 0)
185
+ elif isinstance(m, nn.BatchNorm2d):
186
+ nn.init.constant_(m.weight, 1)
187
+ nn.init.constant_(m.bias, 0)
188
+ elif isinstance(m, nn.Linear):
189
+ nn.init.xavier_uniform_(m.weight)
190
+ nn.init.constant_(m.bias, 0)
191
+
192
+ def forward(self, x):
193
+ """Forward pass."""
194
+ x = self.features(x)
195
+ x = self.classifier(x)
196
+ return x
197
+
198
+
199
+ def create_model(num_classes: int, model_type: str = 'resnet50',
200
+ pretrained: bool = True, dropout_rate: float = 0.5,
201
+ freeze_backbone: bool = False) -> nn.Module:
202
+ """
203
+ Create a bird classification model.
204
+
205
+ Args:
206
+ num_classes: Number of bird classes
207
+ model_type: Type of model ('resnet50', 'resnet18', 'efficientnet_b0', 'lightweight')
208
+ pretrained: Whether to use pretrained weights
209
+ dropout_rate: Dropout rate for regularization
210
+ freeze_backbone: Whether to freeze backbone weights (ignored for lightweight model)
211
+
212
+ Returns:
213
+ PyTorch model
214
+ """
215
+ if model_type == 'lightweight':
216
+ return LightweightBirdClassifier(num_classes, dropout_rate)
217
+ else:
218
+ return BirdClassifier(num_classes, model_type, pretrained,
219
+ dropout_rate, freeze_backbone)
220
+
221
+
222
+ class ModelEnsemble(nn.Module):
223
+ """
224
+ Ensemble of multiple models for improved performance.
225
+ """
226
+
227
+ def __init__(self, models_list: list):
228
+ """
229
+ Initialize model ensemble.
230
+
231
+ Args:
232
+ models_list: List of trained models to ensemble
233
+ """
234
+ super(ModelEnsemble, self).__init__()
235
+ self.models = nn.ModuleList(models_list)
236
+
237
+ def forward(self, x):
238
+ """Forward pass through all models and average predictions."""
239
+ predictions = []
240
+ for model in self.models:
241
+ with torch.no_grad():
242
+ pred = F.softmax(model(x), dim=1)
243
+ predictions.append(pred)
244
+
245
+ # Average predictions
246
+ ensemble_pred = torch.stack(predictions, dim=0).mean(dim=0)
247
+ return torch.log(ensemble_pred + 1e-8) # Convert back to log probabilities
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.12.0
2
+ torchvision>=0.13.0
3
+ numpy>=1.21.0
4
+ Pillow>=8.3.0
5
+ matplotlib>=3.5.0
6
+ scikit-learn>=1.1.0
7
+ tqdm>=4.64.0
8
+ pandas>=1.4.0
9
+ seaborn>=0.11.0
10
+ efficientnet-pytorch>=0.7.1
11
+ gradio>=3.40.0