shwethd commited on
Commit
106a88d
Β·
verified Β·
1 Parent(s): 97b0fb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -41
app.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- HuggingFace App for ImageNet ResNet50 Classifier - 77.09% Accuracy
4
  """
5
 
6
  import gradio as gr
@@ -86,14 +86,25 @@ class ResNet50(nn.Module):
86
  # ============================================================================
87
 
88
  def load_model():
 
 
 
 
89
  model = ResNet50(num_classes=1000)
 
90
  try:
91
- checkpoint = torch.load("best_model_final.pth", map_location='cpu')
 
 
 
92
  if isinstance(checkpoint, dict):
93
  state_dict = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
94
  else:
95
  state_dict = checkpoint
96
 
 
 
 
97
  new_state_dict = {}
98
  for k, v in state_dict.items():
99
  name = k.replace('module.', '') if k.startswith('module.') else k
@@ -101,10 +112,21 @@ def load_model():
101
 
102
  model.load_state_dict(new_state_dict)
103
  print("βœ… Model loaded successfully")
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
- print(f"⚠️ Could not load checkpoint: {e}")
 
 
106
 
107
  model.eval()
 
108
  return model
109
 
110
 
@@ -129,16 +151,21 @@ try:
129
  with open('imagenet_classes.json', 'r') as f:
130
  data = json.load(f)
131
 
 
 
132
  # Handle both dict and list formats
133
  if isinstance(data, dict):
134
  IMAGENET_CLASSES = data
 
135
  elif isinstance(data, list):
136
  # Convert list to dict with string indices
137
  IMAGENET_CLASSES = {str(i): data[i] for i in range(len(data))}
 
138
  else:
139
- raise ValueError("Unexpected JSON format")
 
 
140
 
141
- print(f"βœ… Loaded {len(IMAGENET_CLASSES)} ImageNet classes")
142
  except Exception as e:
143
  # Fallback - create basic class mapping
144
  IMAGENET_CLASSES = {str(i): f"Class_{i}" for i in range(1000)}
@@ -150,12 +177,11 @@ except Exception as e:
150
  # ============================================================================
151
 
152
  def predict(image):
153
- """Predict ImageNet class for input image"""
154
 
155
  if image is None:
156
- # Return dummy predictions for error case
157
  return {
158
- "Error - No Image": 1.0,
159
  "Please upload an image": 0.0,
160
  "": 0.0,
161
  " ": 0.0,
@@ -163,38 +189,58 @@ def predict(image):
163
  }
164
 
165
  try:
 
 
 
 
 
 
 
166
  # Preprocess
167
  img_tensor = transform(image).unsqueeze(0)
 
 
168
 
169
  # Inference
170
  with torch.no_grad():
171
  outputs = model(img_tensor)
 
 
 
172
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
 
173
 
174
  # Get top 5 predictions
175
  top5_prob, top5_indices = torch.topk(probabilities, 5)
176
 
177
- # Format results - dict with string keys and float values
 
 
 
 
 
 
 
 
 
178
  results = {}
179
  for i in range(5):
180
  idx = top5_indices[i].item()
181
  prob = top5_prob[i].item()
182
-
183
- # CRITICAL: Convert idx to string for JSON lookup
184
  class_name = IMAGENET_CLASSES.get(str(idx), f"Class_{idx}")
185
-
186
- # Ensure float probability
187
  results[class_name] = float(prob)
188
 
189
  return results
190
 
191
  except Exception as e:
192
- # Return error in valid format
193
- error_msg = str(e)[:80]
 
 
194
  return {
195
- f"Error: {error_msg}": 0.5,
196
- "Please try another image": 0.3,
197
- "Check console for details": 0.2,
198
  "": 0.0,
199
  " ": 0.0
200
  }
@@ -210,43 +256,27 @@ print("Model ready!")
210
 
211
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
212
  gr.Markdown("""
213
- # πŸ”₯ ImageNet ResNet50 Classifier
214
 
215
- **77.09% Top-1 Accuracy** - Trained from scratch on ImageNet (1.2M images, 1000 classes)
216
 
217
- Upload an image to get top-5 predictions with confidence scores.
218
  """)
219
 
220
  with gr.Row():
221
  with gr.Column():
222
  image_input = gr.Image(type="pil", label="Upload Image")
223
- predict_btn = gr.Button("Classify Image", variant="primary", size="lg")
224
-
225
- gr.Markdown("""
226
- ### πŸ’‘ Tips:
227
- - Works best with clear, centered objects
228
- - Supports 1000 ImageNet classes
229
- - Try different images!
230
- """)
231
 
232
  with gr.Column():
233
- output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
234
-
235
- gr.Markdown("""
236
- ### πŸ“Š Model Info:
237
- - **Architecture:** ResNet50 (25.5M params)
238
- - **Training:** From scratch (no pretrained)
239
- - **Accuracy:** 77.09% Top-1
240
- - **Hardware:** 8Γ— A100 GPUs
241
- """)
242
 
243
  predict_btn.click(fn=predict, inputs=image_input, outputs=output)
244
 
245
  gr.Markdown("""
246
- ---
247
- **Links:** [GitHub Code](https://github.com/Shwethaamrutha/TSAI-S8) | [Training Details](https://github.com/Shwethaamrutha/TSAI-S8/blob/main/README.md)
248
 
249
- Built with PyTorch β€’ Trained on AWS p4d.24xlarge β€’ Top 10% from-scratch result
250
  """)
251
 
252
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ DEBUG VERSION - HuggingFace App for ImageNet ResNet50
4
  """
5
 
6
  import gradio as gr
 
86
  # ============================================================================
87
 
88
  def load_model():
89
+ print("="*70)
90
+ print("LOADING MODEL")
91
+ print("="*70)
92
+
93
  model = ResNet50(num_classes=1000)
94
+
95
  try:
96
+ checkpoint = torch.load("best_model_final.pth", map_location='cpu', weights_only=False)
97
+ print(f"Checkpoint type: {type(checkpoint)}")
98
+ print(f"Checkpoint keys: {list(checkpoint.keys())[:5] if isinstance(checkpoint, dict) else 'Not a dict'}")
99
+
100
  if isinstance(checkpoint, dict):
101
  state_dict = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
102
  else:
103
  state_dict = checkpoint
104
 
105
+ print(f"State dict type: {type(state_dict)}")
106
+ print(f"State dict keys (first 5): {list(state_dict.keys())[:5]}")
107
+
108
  new_state_dict = {}
109
  for k, v in state_dict.items():
110
  name = k.replace('module.', '') if k.startswith('module.') else k
 
112
 
113
  model.load_state_dict(new_state_dict)
114
  print("βœ… Model loaded successfully")
115
+
116
+ # Test forward pass
117
+ test_input = torch.randn(1, 3, 224, 224)
118
+ with torch.no_grad():
119
+ test_output = model(test_input)
120
+ print(f"βœ… Model output shape: {test_output.shape}")
121
+ print(f"βœ… Model output range: [{test_output.min():.2f}, {test_output.max():.2f}]")
122
+
123
  except Exception as e:
124
+ print(f"❌ Error loading checkpoint: {e}")
125
+ import traceback
126
+ traceback.print_exc()
127
 
128
  model.eval()
129
+ print("="*70)
130
  return model
131
 
132
 
 
151
  with open('imagenet_classes.json', 'r') as f:
152
  data = json.load(f)
153
 
154
+ print(f"JSON data type: {type(data)}")
155
+
156
  # Handle both dict and list formats
157
  if isinstance(data, dict):
158
  IMAGENET_CLASSES = data
159
+ print(f"βœ… Loaded as dict with {len(IMAGENET_CLASSES)} classes")
160
  elif isinstance(data, list):
161
  # Convert list to dict with string indices
162
  IMAGENET_CLASSES = {str(i): data[i] for i in range(len(data))}
163
+ print(f"βœ… Converted list to dict with {len(IMAGENET_CLASSES)} classes")
164
  else:
165
+ raise ValueError(f"Unexpected JSON format: {type(data)}")
166
+
167
+ print(f"Sample classes: {list(IMAGENET_CLASSES.items())[:3]}")
168
 
 
169
  except Exception as e:
170
  # Fallback - create basic class mapping
171
  IMAGENET_CLASSES = {str(i): f"Class_{i}" for i in range(1000)}
 
177
  # ============================================================================
178
 
179
  def predict(image):
180
+ """Predict ImageNet class for input image - WITH DEBUG INFO"""
181
 
182
  if image is None:
 
183
  return {
184
+ "No Image Uploaded": 1.0,
185
  "Please upload an image": 0.0,
186
  "": 0.0,
187
  " ": 0.0,
 
189
  }
190
 
191
  try:
192
+ print(f"\n{'='*70}")
193
+ print(f"PREDICTION DEBUG")
194
+ print(f"{'='*70}")
195
+ print(f"Image type: {type(image)}")
196
+ print(f"Image size: {image.size}")
197
+ print(f"Image mode: {image.mode}")
198
+
199
  # Preprocess
200
  img_tensor = transform(image).unsqueeze(0)
201
+ print(f"Tensor shape: {img_tensor.shape}")
202
+ print(f"Tensor range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
203
 
204
  # Inference
205
  with torch.no_grad():
206
  outputs = model(img_tensor)
207
+ print(f"Raw outputs shape: {outputs.shape}")
208
+ print(f"Raw outputs range: [{outputs.min():.2f}, {outputs.max():.2f}]")
209
+
210
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
211
+ print(f"Probabilities sum: {probabilities.sum():.4f}")
212
 
213
  # Get top 5 predictions
214
  top5_prob, top5_indices = torch.topk(probabilities, 5)
215
 
216
+ print(f"\nTop-5 Predictions:")
217
+ for i in range(5):
218
+ idx = top5_indices[i].item()
219
+ prob = top5_prob[i].item()
220
+ class_name = IMAGENET_CLASSES.get(str(idx), f"Class_{idx}")
221
+ print(f" {idx}: {class_name} = {prob:.4f}")
222
+
223
+ print(f"{'='*70}\n")
224
+
225
+ # Format results
226
  results = {}
227
  for i in range(5):
228
  idx = top5_indices[i].item()
229
  prob = top5_prob[i].item()
 
 
230
  class_name = IMAGENET_CLASSES.get(str(idx), f"Class_{idx}")
 
 
231
  results[class_name] = float(prob)
232
 
233
  return results
234
 
235
  except Exception as e:
236
+ print(f"❌ Prediction error: {e}")
237
+ import traceback
238
+ traceback.print_exc()
239
+
240
  return {
241
+ f"Error {str(e)[:50]}": 0.5,
242
+ "Check logs": 0.3,
243
+ "Try another image": 0.2,
244
  "": 0.0,
245
  " ": 0.0
246
  }
 
256
 
257
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
258
  gr.Markdown("""
259
+ # πŸ”₯ ImageNet ResNet50 Classifier (DEBUG VERSION)
260
 
261
+ **77.09% Top-1 Accuracy** - From scratch training
262
 
263
+ Upload an image to test. Check console for debug output.
264
  """)
265
 
266
  with gr.Row():
267
  with gr.Column():
268
  image_input = gr.Image(type="pil", label="Upload Image")
269
+ predict_btn = gr.Button("Classify", variant="primary")
 
 
 
 
 
 
 
270
 
271
  with gr.Column():
272
+ output = gr.Label(num_top_classes=5, label="Predictions")
 
 
 
 
 
 
 
 
273
 
274
  predict_btn.click(fn=predict, inputs=image_input, outputs=output)
275
 
276
  gr.Markdown("""
277
+ **Model:** ResNet50 (25.5M params) | **Accuracy:** 77.09%
 
278
 
279
+ [GitHub](https://github.com/Shwethaamrutha/TSAI-S8)
280
  """)
281
 
282
  if __name__ == "__main__":