shwethd commited on
Commit
33e3bf9
·
verified ·
1 Parent(s): 26d3537

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -72
app.py CHANGED
@@ -1,7 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- HuggingFace Spaces App for ImageNet ResNet50 Classifier
4
- Trained from scratch to 78%+ Top-1 accuracy
5
  """
6
 
7
  import gradio as gr
@@ -83,48 +82,34 @@ class ResNet50(nn.Module):
83
 
84
 
85
  # ============================================================================
86
- # MODEL LOADING
87
  # ============================================================================
88
 
89
  def load_model():
90
- """Load the trained model (CPU-optimized for HuggingFace)"""
91
  model = ResNet50(num_classes=1000)
92
-
93
  try:
94
- # Try to load checkpoint
95
- checkpoint_path = "best_model_final.pth" # Will be uploaded separately
96
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
97
-
98
- # Handle different checkpoint formats
99
  if isinstance(checkpoint, dict):
100
- if 'model' in checkpoint:
101
- state_dict = checkpoint['model']
102
- elif 'state_dict' in checkpoint:
103
- state_dict = checkpoint['state_dict']
104
- else:
105
- state_dict = checkpoint
106
  else:
107
  state_dict = checkpoint
108
 
109
- # Remove 'module.' prefix if present (from DataParallel)
110
  new_state_dict = {}
111
  for k, v in state_dict.items():
112
  name = k.replace('module.', '') if k.startswith('module.') else k
113
  new_state_dict[name] = v
114
 
115
  model.load_state_dict(new_state_dict)
116
- print(f"✅ Model loaded successfully from {checkpoint_path}")
117
-
118
  except Exception as e:
119
  print(f"⚠️ Could not load checkpoint: {e}")
120
- print("Using randomly initialized model for demo purposes")
121
 
122
  model.eval()
123
  return model
124
 
125
 
126
  # ============================================================================
127
- # IMAGE PREPROCESSING
128
  # ============================================================================
129
 
130
  transform = transforms.Compose([
@@ -136,31 +121,17 @@ transform = transforms.Compose([
136
 
137
 
138
  # ============================================================================
139
- # IMAGENET CLASS LABELS
140
  # ============================================================================
141
 
142
- # Top 20 most common ImageNet classes for demo
143
- IMAGENET_CLASSES = {
144
- 0: "tench", 1: "goldfish", 2: "great white shark", 3: "tiger shark",
145
- 4: "hammerhead", 5: "electric ray", 6: "stingray", 7: "cock",
146
- 8: "hen", 9: "ostrich", 10: "brambling", 11: "goldfinch",
147
- 12: "house finch", 13: "junco", 14: "indigo bunting", 15: "robin",
148
- 151: "Chihuahua", 207: "golden retriever", 281: "tabby cat",
149
- 282: "tiger cat", 283: "Persian cat", 285: "Egyptian cat",
150
- 291: "lion", 292: "tiger", 293: "jaguar", 294: "leopard",
151
- 404: "airliner", 407: "container ship", 468: "cab",
152
- 511: "convertible", 609: "jeep", 627: "limousine",
153
- 817: "sports car", 751: "racer", 779: "school bus",
154
- 555: "fire engine", 569: "garbage truck", 717: "pickup",
155
- # Add more as needed
156
- }
157
-
158
- # Load full class names if available
159
  try:
160
  with open('imagenet_classes.json', 'r') as f:
161
  IMAGENET_CLASSES = json.load(f)
162
  except:
163
- pass # Use default subset
 
 
164
 
165
 
166
  # ============================================================================
@@ -168,21 +139,21 @@ except:
168
  # ============================================================================
169
 
170
  def predict(image):
171
- """
172
- Predict ImageNet class for input image
173
 
174
- Args:
175
- image: PIL Image
176
-
177
- Returns:
178
- dict: Top-5 predictions with confidence scores
179
- """
180
  if image is None:
181
- return {"error": "Please upload an image"}
 
 
 
 
 
 
 
182
 
183
  try:
184
  # Preprocess
185
- img_tensor = transform(image).unsqueeze(0) # Add batch dimension
186
 
187
  # Inference
188
  with torch.no_grad():
@@ -192,67 +163,76 @@ def predict(image):
192
  # Get top 5 predictions
193
  top5_prob, top5_indices = torch.topk(probabilities, 5)
194
 
195
- # Format results
196
  results = {}
197
  for i in range(5):
198
  idx = top5_indices[i].item()
199
  prob = top5_prob[i].item()
200
- class_name = IMAGENET_CLASSES.get(str(idx), f"Class {idx}")
201
- results[f"{class_name}"] = float(prob)
 
 
 
 
202
 
203
  return results
204
 
205
  except Exception as e:
206
- return {"error": f"Prediction failed: {str(e)}"}
 
 
 
 
 
 
 
 
207
 
208
 
209
  # ============================================================================
210
  # GRADIO INTERFACE
211
  # ============================================================================
212
 
213
- # Load model globally
214
  print("Loading model...")
215
  model = load_model()
216
- print("Model loaded successfully!")
217
 
218
- # Create Gradio interface
219
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
220
  gr.Markdown("""
221
  # 🔥 ImageNet ResNet50 Classifier
222
 
223
- Upload any image and get top-5 predictions with confidence scores.
 
 
224
  """)
225
 
226
  with gr.Row():
227
  with gr.Column():
228
  image_input = gr.Image(type="pil", label="Upload Image")
229
- predict_btn = gr.Button("Classify Image", variant="primary")
 
230
 
231
 
232
  with gr.Column():
233
  output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
234
 
235
  gr.Markdown("""
236
- ### 🎯 Model Info:
237
  - **Architecture:** ResNet50 (25.5M params)
238
- - **Training:** From scratch (no pretrained weights)
239
- - **Dataset:** ImageNet (1.2M images, 1000 classes)
240
- - **Accuracy:** 77.09% Top-1 validation
241
-
242
- ### 🔗 Links:
243
- - [GitHub Repository](https://github.com/Shwethaamrutha/TSAI-S8)
244
  """)
245
-
246
 
247
- # Connect button
248
  predict_btn.click(fn=predict, inputs=image_input, outputs=output)
249
 
 
 
 
250
 
251
-
252
-
253
-
254
 
255
- # Launch
256
  if __name__ == "__main__":
257
  demo.launch()
258
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ HuggingFace App for ImageNet ResNet50 Classifier - 77.09% Accuracy
 
4
  """
5
 
6
  import gradio as gr
 
82
 
83
 
84
  # ============================================================================
85
+ # LOAD MODEL
86
  # ============================================================================
87
 
88
  def load_model():
 
89
  model = ResNet50(num_classes=1000)
 
90
  try:
91
+ checkpoint = torch.load("best_model_final.pth", map_location='cpu')
 
 
 
 
92
  if isinstance(checkpoint, dict):
93
+ state_dict = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
 
 
 
 
 
94
  else:
95
  state_dict = checkpoint
96
 
 
97
  new_state_dict = {}
98
  for k, v in state_dict.items():
99
  name = k.replace('module.', '') if k.startswith('module.') else k
100
  new_state_dict[name] = v
101
 
102
  model.load_state_dict(new_state_dict)
103
+ print("✅ Model loaded successfully")
 
104
  except Exception as e:
105
  print(f"⚠️ Could not load checkpoint: {e}")
 
106
 
107
  model.eval()
108
  return model
109
 
110
 
111
  # ============================================================================
112
+ # PREPROCESSING
113
  # ============================================================================
114
 
115
  transform = transforms.Compose([
 
121
 
122
 
123
  # ============================================================================
124
+ # IMAGENET CLASSES
125
  # ============================================================================
126
 
127
+ IMAGENET_CLASSES = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  try:
129
  with open('imagenet_classes.json', 'r') as f:
130
  IMAGENET_CLASSES = json.load(f)
131
  except:
132
+ # Fallback - create basic class mapping
133
+ IMAGENET_CLASSES = {str(i): f"Class {i}" for i in range(1000)}
134
+ print("⚠️ Using default class indices")
135
 
136
 
137
  # ============================================================================
 
139
  # ============================================================================
140
 
141
  def predict(image):
142
+ """Predict ImageNet class for input image"""
 
143
 
 
 
 
 
 
 
144
  if image is None:
145
+ # Return dummy predictions for error case
146
+ return {
147
+ "Error - No Image": 1.0,
148
+ "Please upload an image": 0.0,
149
+ "": 0.0,
150
+ " ": 0.0,
151
+ " ": 0.0
152
+ }
153
 
154
  try:
155
  # Preprocess
156
+ img_tensor = transform(image).unsqueeze(0)
157
 
158
  # Inference
159
  with torch.no_grad():
 
163
  # Get top 5 predictions
164
  top5_prob, top5_indices = torch.topk(probabilities, 5)
165
 
166
+ # Format results - dict with string keys and float values
167
  results = {}
168
  for i in range(5):
169
  idx = top5_indices[i].item()
170
  prob = top5_prob[i].item()
171
+
172
+ # CRITICAL: Convert idx to string for JSON lookup
173
+ class_name = IMAGENET_CLASSES.get(str(idx), f"Class_{idx}")
174
+
175
+ # Ensure float probability
176
+ results[class_name] = float(prob)
177
 
178
  return results
179
 
180
  except Exception as e:
181
+ # Return error in valid format
182
+ error_msg = str(e)[:80]
183
+ return {
184
+ f"Error: {error_msg}": 0.5,
185
+ "Please try another image": 0.3,
186
+ "Check console for details": 0.2,
187
+ "": 0.0,
188
+ " ": 0.0
189
+ }
190
 
191
 
192
  # ============================================================================
193
  # GRADIO INTERFACE
194
  # ============================================================================
195
 
 
196
  print("Loading model...")
197
  model = load_model()
198
+ print("Model ready!")
199
 
 
200
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
201
  gr.Markdown("""
202
  # 🔥 ImageNet ResNet50 Classifier
203
 
204
+ **77.09% Top-1 Accuracy** - Trained from scratch on ImageNet (1.2M images, 1000 classes)
205
+
206
+ Upload an image to get top-5 predictions with confidence scores.
207
  """)
208
 
209
  with gr.Row():
210
  with gr.Column():
211
  image_input = gr.Image(type="pil", label="Upload Image")
212
+ predict_btn = gr.Button("Classify Image", variant="primary", size="lg")
213
+
214
 
215
 
216
  with gr.Column():
217
  output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
218
 
219
  gr.Markdown("""
220
+ ### 📊 Model Info:
221
  - **Architecture:** ResNet50 (25.5M params)
222
+ - **Training:** From scratch (no pretrained)
223
+ - **Accuracy:** 77.09% Top-1
224
+ - **Hardware:** A100 GPUs
 
 
 
225
  """)
 
226
 
 
227
  predict_btn.click(fn=predict, inputs=image_input, outputs=output)
228
 
229
+ gr.Markdown("""
230
+ ---
231
+ **Links:** [GitHub Code](https://github.com/Shwethaamrutha/TSAI-S8) | [Training Details](https://github.com/Shwethaamrutha/TSAI-S8/blob/main/README.md)
232
 
233
+ Built with PyTorch •
234
+ """)
 
235
 
 
236
  if __name__ == "__main__":
237
  demo.launch()
238