FrAnKu34t23 commited on
Commit
7d3d1f4
·
verified ·
1 Parent(s): abd02d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -158
app.py CHANGED
@@ -1,159 +1,207 @@
1
- """
2
- Gradio App for Bird Classification - Hugging Face Deployment
3
- Enhanced model with 76.74% accuracy from Stage 2 training.
4
- """
5
- import gradio as gr
6
- import torch
7
- import torch.nn.functional as F
8
- from PIL import Image
9
- import json
10
- import numpy as np
11
- from torchvision import transforms
12
- import os
13
-
14
- # Import our model architecture
15
- from models import create_model
16
-
17
- # Configuration
18
- MODEL_PATH = "best_model.pth"
19
- CLASS_NAMES_PATH = "class_names.json"
20
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
-
22
- # Load class names
23
- with open(CLASS_NAMES_PATH, 'r') as f:
24
- class_names = json.load(f)
25
-
26
- NUM_CLASSES = len(class_names)
27
-
28
- # Load model
29
- print("Loading model...")
30
- model = create_model(
31
- num_classes=NUM_CLASSES,
32
- model_type='efficientnet_b2', # Stage 2 architecture
33
- pretrained=False, # We're loading trained weights
34
- dropout_rate=0.3 # Stage 2 dropout rate
35
- )
36
-
37
- # Load trained weights
38
- if os.path.exists(MODEL_PATH):
39
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
40
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
41
- model.load_state_dict(checkpoint['model_state_dict'])
42
- else:
43
- model.load_state_dict(checkpoint)
44
- print("✅ Model loaded successfully!")
45
- else:
46
- print("⚠️ Model file not found. Please ensure best_model.pth is in the repository.")
47
-
48
- model.to(DEVICE)
49
- model.eval()
50
-
51
- # Image preprocessing (Stage 2 configuration)
52
- transform = transforms.Compose([
53
- transforms.Resize((320, 320)), # Stage 2 image size
54
- transforms.ToTensor(),
55
- transforms.Normalize(
56
- mean=[0.485, 0.456, 0.406],
57
- std=[0.229, 0.224, 0.225]
58
- )
59
- ])
60
-
61
- def predict_bird(image):
62
- """
63
- Predict bird species from uploaded image.
64
- """
65
- try:
66
- # Preprocess image
67
- if isinstance(image, np.ndarray):
68
- image = Image.fromarray(image.astype('uint8'))
69
-
70
- # Convert to RGB if needed
71
- if image.mode != 'RGB':
72
- image = image.convert('RGB')
73
-
74
- # Apply transformations
75
- input_tensor = transform(image).unsqueeze(0).to(DEVICE)
76
-
77
- # Prediction
78
- with torch.no_grad():
79
- outputs = model(input_tensor)
80
- probabilities = F.softmax(outputs, dim=1)
81
- confidence, predicted = torch.max(probabilities, 1)
82
-
83
- # Get top 5 predictions
84
- top5_prob, top5_indices = torch.topk(probabilities, 5)
85
-
86
- # Format results
87
- results = {}
88
- for i in range(5):
89
- class_idx = top5_indices[0][i].item()
90
- prob = top5_prob[0][i].item()
91
- class_name = class_names[class_idx].replace('_', ' ')
92
- results[class_name] = float(prob)
93
-
94
- return results
95
-
96
- except Exception as e:
97
- return {"Error": f"Prediction failed: {str(e)}"}
98
-
99
- # Create Gradio interface
100
- title = "🐦 Bird Species Classifier"
101
- description = """
102
- ## Advanced Bird Classification Model (76.74% Accuracy)
103
-
104
- This model can classify **200 different bird species** using advanced deep learning techniques:
105
-
106
- ### Model Details:
107
- - **Architecture**: EfficientNet-B2 with enhanced regularization
108
- - **Training Strategy**: Progressive training with MixUp augmentation
109
- - **Performance**: 76.74% test accuracy (Stage 2 results)
110
- - **Dataset**: CUB-200-2011 (200 bird species)
111
-
112
- ### How to use:
113
- 1. Upload a clear image of a bird
114
- 2. The model will predict the top 5 most likely species
115
- 3. Confidence scores show the model's certainty
116
-
117
- ### Best Results Tips:
118
- - Use high-quality, well-lit images
119
- - Ensure the bird is clearly visible
120
- - Close-up shots work better than distant ones
121
- - Natural lighting produces better results
122
-
123
- **Note**: This model was trained on the CUB-200-2011 dataset and works best with North American bird species.
124
- """
125
-
126
- article = """
127
- ### Technical Implementation:
128
- - **Framework**: PyTorch with EfficientNet-B2 backbone
129
- - **Training**: Progressive training with MixUp data augmentation
130
- - **Regularization**: Optimized dropout rates (0.3) and advanced augmentation
131
- - **Image Size**: 320x320 pixels for optimal detail capture
132
-
133
- ### About the Model:
134
- This bird classifier was developed using advanced machine learning techniques including:
135
- - Transfer learning from ImageNet-pretrained EfficientNet
136
- - Progressive training strategy across multiple stages
137
- - MixUp augmentation for improved generalization
138
- - Comprehensive evaluation on 200 bird species
139
-
140
- For more details about the training process and methodology, please refer to the repository documentation.
141
- """
142
-
143
- # Create the interface
144
- iface = gr.Interface(
145
- fn=predict_bird,
146
- inputs=gr.Image(type="pil", label="Upload Bird Image"),
147
- outputs=gr.Label(num_top_classes=5, label="Predictions"),
148
- title=title,
149
- description=description,
150
- article=article,
151
- examples=[
152
- # You can add example images here if you have them
153
- ],
154
- allow_flagging="never",
155
- theme=gr.themes.Soft()
156
- )
157
-
158
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  iface.launch(debug=True)
 
1
+ """
2
+ Gradio App for Bird Classification - Hugging Face Deployment
3
+ Enhanced model with architecture auto-detection and error handling.
4
+ """
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ import json
10
+ import numpy as np
11
+ from torchvision import transforms
12
+ import os
13
+
14
+ # Import our model architecture
15
+ from models import create_model
16
+
17
+ # Configuration
18
+ MODEL_PATH = "best_model.pth"
19
+ CLASS_NAMES_PATH = "class_names.json"
20
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ # Load class names
23
+ with open(CLASS_NAMES_PATH, 'r') as f:
24
+ class_names = json.load(f)
25
+
26
+ NUM_CLASSES = len(class_names)
27
+
28
+ # Load model - detect architecture from checkpoint
29
+ print("Loading model...")
30
+
31
+ # First, try to detect the correct architecture from the model file
32
+ if os.path.exists(MODEL_PATH):
33
+ checkpoint = torch.load(MODEL_PATH, map_location='cpu')
34
+
35
+ # Detect EfficientNet variant based on feature dimensions
36
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
37
+ state_dict = checkpoint['model_state_dict']
38
+ else:
39
+ state_dict = checkpoint
40
+
41
+ # Check backbone head feature size to determine EfficientNet variant
42
+ if 'backbone._conv_head.weight' in state_dict:
43
+ conv_head_shape = state_dict['backbone._conv_head.weight'].shape
44
+ if conv_head_shape[0] == 1536: # EfficientNet-B3
45
+ model_type = 'efficientnet_b3'
46
+ elif conv_head_shape[0] == 1408: # EfficientNet-B2
47
+ model_type = 'efficientnet_b2'
48
+ elif conv_head_shape[0] == 1280: # EfficientNet-B0/B1
49
+ model_type = 'efficientnet_b1'
50
+ else:
51
+ model_type = 'efficientnet_b2' # Default fallback
52
+ else:
53
+ model_type = 'efficientnet_b2' # Default fallback
54
+
55
+ # Check actual number of classes from classifier
56
+ if 'classifier.9.weight' in state_dict:
57
+ actual_classes = state_dict['classifier.9.weight'].shape[0]
58
+ else:
59
+ actual_classes = NUM_CLASSES
60
+
61
+ print("Detected model: {} with {} classes".format(model_type, actual_classes))
62
+
63
+ else:
64
+ model_type = 'efficientnet_b2'
65
+ actual_classes = NUM_CLASSES
66
+ print("Model file not found, using default: {}".format(model_type))
67
+
68
+ model = create_model(
69
+ num_classes=actual_classes,
70
+ model_type=model_type,
71
+ pretrained=False, # We're loading trained weights
72
+ dropout_rate=0.3
73
+ )
74
+
75
+ # Load trained weights
76
+ if os.path.exists(MODEL_PATH):
77
+ try:
78
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
79
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
80
+ model.load_state_dict(checkpoint['model_state_dict'])
81
+ print("✅ Model loaded successfully! ({}, {} classes)".format(model_type, actual_classes))
82
+ else:
83
+ model.load_state_dict(checkpoint)
84
+ print("✅ Model loaded successfully! ({}, {} classes)".format(model_type, actual_classes))
85
+ except Exception as e:
86
+ print("❌ Error loading model: {}".format(str(e)))
87
+ print("Please ensure the model architecture matches the saved weights.")
88
+ else:
89
+ print("⚠️ Model file not found. Please ensure best_model.pth is in the repository.")
90
+
91
+ model.to(DEVICE)
92
+ model.eval()
93
+
94
+ def predict_bird(image):
95
+ """
96
+ Predict bird species from uploaded image.
97
+ """
98
+ try:
99
+ # Preprocess image
100
+ if isinstance(image, np.ndarray):
101
+ image = Image.fromarray(image.astype('uint8'))
102
+
103
+ # Convert to RGB if needed
104
+ if image.mode != 'RGB':
105
+ image = image.convert('RGB')
106
+
107
+ # Define preprocessing step by step to avoid namespace issues
108
+ resize = transforms.Resize((320, 320))
109
+ to_tensor = transforms.ToTensor()
110
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
111
+
112
+ # Apply transformations step by step
113
+ resized_image = resize(image)
114
+ tensor_image = to_tensor(resized_image)
115
+ normalized_tensor = normalize(tensor_image)
116
+ input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE)
117
+
118
+ # Prediction
119
+ with torch.no_grad():
120
+ outputs = model(input_tensor)
121
+ probabilities = F.softmax(outputs, dim=1)
122
+ confidence, predicted = torch.max(probabilities, 1)
123
+
124
+ # Get top 5 predictions
125
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
126
+
127
+ # Format results
128
+ results = {}
129
+ for i in range(5):
130
+ class_idx = top5_indices[0][i].item()
131
+ prob = top5_prob[0][i].item()
132
+ # Handle potential class index mismatch
133
+ if class_idx < len(class_names):
134
+ class_name = class_names[class_idx].replace('_', ' ')
135
+ else:
136
+ class_name = "Class_" + str(class_idx)
137
+ results[class_name] = float(prob)
138
+
139
+ return results
140
+
141
+ except Exception as e:
142
+ return {"Error": "Prediction failed: " + str(e)}
143
+
144
+ # Create Gradio interface
145
+ title = "🐦 Bird Species Classifier"
146
+ description = """
147
+ ## Advanced Bird Classification Model
148
+
149
+ This model can classify **199 different bird species** using advanced deep learning techniques:
150
+
151
+ ### Model Details:
152
+ - **Architecture**: Auto-detected EfficientNet (B2/B3) with enhanced regularization
153
+ - **Training Strategy**: Progressive training with advanced augmentation
154
+ - **Performance**: Optimized for accuracy and reliability
155
+ - **Dataset**: CUB-200-2011 (199 bird species)
156
+
157
+ ### How to use:
158
+ 1. Upload a clear image of a bird
159
+ 2. The model will predict the top 5 most likely species
160
+ 3. Confidence scores show the model's certainty
161
+
162
+ ### Best Results Tips:
163
+ - Use high-quality, well-lit images
164
+ - Ensure the bird is clearly visible
165
+ - Close-up shots work better than distant ones
166
+ - Natural lighting produces better results
167
+
168
+ **Note**: This model was trained on the CUB-200-2011 dataset and works best with North American bird species.
169
+ """
170
+
171
+ article = """
172
+ ### Technical Implementation:
173
+ - **Framework**: PyTorch with auto-detected EfficientNet backbone
174
+ - **Training**: Progressive training with advanced augmentation strategies
175
+ - **Regularization**: Optimized dropout rates and comprehensive validation
176
+ - **Image Size**: 320x320 pixels for optimal detail capture
177
+
178
+ ### About the Model:
179
+ This bird classifier was developed using advanced machine learning techniques including:
180
+ - Transfer learning from ImageNet-pretrained EfficientNet
181
+ - Progressive training strategy across multiple stages
182
+ - Advanced data augmentation for improved generalization
183
+ - Comprehensive evaluation and optimization
184
+
185
+ The model automatically detects the correct architecture (EfficientNet-B2 or B3) from the saved weights,
186
+ ensuring compatibility and optimal performance.
187
+
188
+ For more details about the training process and methodology, please refer to the repository documentation.
189
+ """
190
+
191
+ # Create the interface
192
+ iface = gr.Interface(
193
+ fn=predict_bird,
194
+ inputs=gr.Image(type="pil", label="Upload Bird Image"),
195
+ outputs=gr.Label(num_top_classes=5, label="Predictions"),
196
+ title=title,
197
+ description=description,
198
+ article=article,
199
+ examples=[
200
+ # You can add example images here if you have them
201
+ ],
202
+ allow_flagging="never",
203
+ theme=gr.themes.Soft()
204
+ )
205
+
206
+ if __name__ == "__main__":
207
  iface.launch(debug=True)