shwethd commited on
Commit
7406b0f
·
verified ·
1 Parent(s): 33e3bf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -127,11 +127,22 @@ transform = transforms.Compose([
127
  IMAGENET_CLASSES = {}
128
  try:
129
  with open('imagenet_classes.json', 'r') as f:
130
- IMAGENET_CLASSES = json.load(f)
131
- except:
 
 
 
 
 
 
 
 
 
 
 
132
  # Fallback - create basic class mapping
133
- IMAGENET_CLASSES = {str(i): f"Class {i}" for i in range(1000)}
134
- print("⚠️ Using default class indices")
135
 
136
 
137
  # ============================================================================
@@ -211,7 +222,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
211
  image_input = gr.Image(type="pil", label="Upload Image")
212
  predict_btn = gr.Button("Classify Image", variant="primary", size="lg")
213
 
214
-
 
 
 
 
 
215
 
216
  with gr.Column():
217
  output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
@@ -230,7 +246,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
230
  ---
231
  **Links:** [GitHub Code](https://github.com/Shwethaamrutha/TSAI-S8) | [Training Details](https://github.com/Shwethaamrutha/TSAI-S8/blob/main/README.md)
232
 
233
- Built with PyTorch •
234
  """)
235
 
236
  if __name__ == "__main__":
 
127
  IMAGENET_CLASSES = {}
128
  try:
129
  with open('imagenet_classes.json', 'r') as f:
130
+ data = json.load(f)
131
+
132
+ # Handle both dict and list formats
133
+ if isinstance(data, dict):
134
+ IMAGENET_CLASSES = data
135
+ elif isinstance(data, list):
136
+ # Convert list to dict with string indices
137
+ IMAGENET_CLASSES = {str(i): data[i] for i in range(len(data))}
138
+ else:
139
+ raise ValueError("Unexpected JSON format")
140
+
141
+ print(f"✅ Loaded {len(IMAGENET_CLASSES)} ImageNet classes")
142
+ except Exception as e:
143
  # Fallback - create basic class mapping
144
+ IMAGENET_CLASSES = {str(i): f"Class_{i}" for i in range(1000)}
145
+ print(f"⚠️ Using default class indices: {e}")
146
 
147
 
148
  # ============================================================================
 
222
  image_input = gr.Image(type="pil", label="Upload Image")
223
  predict_btn = gr.Button("Classify Image", variant="primary", size="lg")
224
 
225
+ gr.Markdown("""
226
+ ### 💡 Tips:
227
+ - Works best with clear, centered objects
228
+ - Supports 1000 ImageNet classes
229
+ - Try different images!
230
+ """)
231
 
232
  with gr.Column():
233
  output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
 
246
  ---
247
  **Links:** [GitHub Code](https://github.com/Shwethaamrutha/TSAI-S8) | [Training Details](https://github.com/Shwethaamrutha/TSAI-S8/blob/main/README.md)
248
 
249
+ Built with PyTorch • Trained on AWS p4d.24xlarge • Top 10% from-scratch result
250
  """)
251
 
252
  if __name__ == "__main__":