Alamgirapi committed on
Commit
b7c34ef
·
verified ·
1 Parent(s): 9bed5ce

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +405 -0
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import requests
10
+ import io
11
+ from timm import create_model
12
+
13
# Configure the Streamlit page: browser-tab title, tab icon and wide layout.
st.set_page_config(
    page_title="Sports Ball Classifier",
    # The icon was stored as mis-decoded UTF-8 ("πŸ€"); this is the intended
    # basketball emoji.
    page_icon="🏀",
    layout="wide",
)
19
+
20
+ # Custom ConvNeXt model definition (in case the saved model uses a different architecture)
21
class ConvNeXtBlock(nn.Module):
    """Residual block: 7x7 depthwise conv -> LayerNorm -> 2-layer MLP -> scale.

    The input is added back to the transformed tensor, so the output has the
    same shape as the input.

    Args:
        dim: Number of input (and output) channels.
        drop_path: Stochastic-depth probability in ``[0, 1]``. While the module
            is in training mode the residual branch of each sample is zeroed
            with this probability (survivors are rescaled by ``1 / keep``).
            The default ``0.`` disables it; it is always inactive in eval mode.
        layer_scale_init_value: Initial value of the learnable per-channel
            scale ``gamma``; a value ``<= 0`` disables the scale entirely.

    Raises:
        ValueError: If ``drop_path`` is outside ``[0, 1]``.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        if not 0. <= drop_path <= 1.:
            raise ValueError(f"drop_path must be in [0, 1], got {drop_path}")

        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        # Plain float (not a parameter/buffer), so the state dict is unchanged.
        self.drop_prob = float(drop_path)

    def _drop_path(self, x):
        """Apply per-sample stochastic depth to the residual branch ``x``."""
        if self.drop_prob <= 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.:
            mask.div_(keep_prob)
        return x * mask

    def forward(self, x):
        """Return ``x + branch(x)`` for an ``(N, C, H, W)`` input."""
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self._drop_path(x)
45
+
46
class CustomConvNeXt(nn.Module):
    """ConvNeXt-style image classifier assembled from ``ConvNeXtBlock`` stages.

    Layout: a 4x4 stride-4 convolutional stem, four stages of widths
    96/192/384/768 with 3/3/9/3 blocks, a LayerNorm + 2x2 stride-2 convolution
    between stages, then global average pooling, LayerNorm and a linear head.

    The LayerNorm layers in the stem and the downsamplers are created with
    fixed ``[C, H, W]`` shapes (56, 28 and 14 pixels per side), so the stem
    output has to be 56x56 -- e.g. a 224x224 input image.

    Args:
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, num_classes=15):
        super().__init__()

        def make_stage(width, depth):
            # `depth` residual blocks operating at a constant channel width.
            return nn.Sequential(*(ConvNeXtBlock(width) for _ in range(depth)))

        def make_downsample(in_width, out_width, side):
            # Normalize the full (C, side, side) map, then halve the resolution.
            return nn.Sequential(
                nn.LayerNorm([in_width, side, side], eps=1e-6),
                nn.Conv2d(in_width, out_width, kernel_size=2, stride=2),
            )

        # Sub-modules are created in execution order; attribute names define
        # the state-dict keys and must not change.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=4, stride=4),
            nn.LayerNorm([96, 56, 56], eps=1e-6),
        )
        self.stage1 = make_stage(96, 3)
        self.downsample1 = make_downsample(96, 192, 56)
        self.stage2 = make_stage(192, 3)
        self.downsample2 = make_downsample(192, 384, 28)
        self.stage3 = make_stage(384, 9)
        self.downsample3 = make_downsample(384, 768, 14)
        self.stage4 = make_stage(768, 3)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.norm = nn.LayerNorm(768, eps=1e-6)
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        """Map an image batch to ``(N, num_classes)`` logits."""
        feature_pipeline = (
            self.stem,
            self.stage1,
            self.downsample1,
            self.stage2,
            self.downsample2,
            self.stage3,
            self.downsample3,
            self.stage4,
            self.avgpool,
        )
        for layer in feature_pipeline:
            x = layer(x)
        # Collapse the pooled (N, 768, 1, 1) map into (N, 768) feature vectors.
        features = x.view(x.shape[0], -1)
        return self.head(self.norm(features))
103
+
104
def _load_first_compatible(architectures, model_state, device, num_classes=15):
    """Return the first timm architecture that strictly accepts ``model_state``.

    Every name in ``architectures`` is built with ``create_model`` and loaded
    with a strict ``load_state_dict``. A failing candidate is reported through
    ``st.warning`` and the next one is tried.

    Returns:
        The loaded model (eval mode, moved to ``device``), or ``None`` when no
        candidate accepted the weights.
    """
    for model_name in architectures:
        try:
            model = create_model(model_name, pretrained=False, num_classes=num_classes)
            model.load_state_dict(model_state)
            model.eval()
            model.to(device)
            st.success(f"✅ Successfully loaded model using: {model_name}")
            return model
        except Exception as e:
            st.warning(f"❌ Failed to load with {model_name}: {str(e)[:100]}...")
    return None


# Cache the model loading to avoid reloading on every interaction
@st.cache_resource
def load_model():
    """Download the classifier weights and load them into a matching model.

    The state dict is fetched from the Hugging Face Hub and tried, in order,
    against a list of ViT, ConvNeXt and other timm architectures using strict
    loading. If none matches, ``vit_base_patch16_224`` is loaded with
    ``strict=False`` as a last resort and the mismatch is reported in the UI.

    NOTE(review): the ``(None, device)`` failure result is an ordinary return
    value, so ``st.cache_resource`` presumably caches it too; a transient
    download failure would then persist until the cache is cleared -- confirm.

    Returns:
        ``(model, device)``; ``model`` is ``None`` when downloading or loading
        failed (the error has already been shown via Streamlit).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        # Download model weights from Hugging Face
        model_url = "https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier/resolve/main/model.pth"
        # A timeout keeps a stalled connection from hanging the app forever.
        response = requests.get(model_url, timeout=60)
        if response.status_code != 200:
            raise Exception(f"Failed to download model: HTTP {response.status_code}")

        # SECURITY: torch.load unpickles the downloaded payload, which can
        # execute arbitrary code if the remote file is ever replaced. Consider
        # weights_only=True (where the installed torch supports it) or a
        # safetensors checkpoint.
        model_state = torch.load(io.BytesIO(response.content), map_location=device)

        # Candidate groups are tried in this order; each entry is the
        # announcement shown to the user plus the timm model names to try.
        candidate_groups = [
            (
                "Trying Vision Transformer (ViT) models...",
                [
                    "vit_base_patch16_224",
                    "vit_small_patch16_224",
                    "vit_tiny_patch16_224",
                    "vit_large_patch16_224",
                    "vit_base_patch32_224",
                ],
            ),
            (
                "Trying ConvNeXt models as fallback...",
                ["convnext_tiny", "convnext_small", "convnext_base"],
            ),
            (
                "Trying other model architectures...",
                ["resnet50", "efficientnet_b0", "mobilenetv3_large_100"],
            ),
        ]

        for announcement, architectures in candidate_groups:
            st.info(announcement)
            model = _load_first_compatible(architectures, model_state, device)
            if model is not None:
                return model, device

        # If all fail, try loading with strict=False and show detailed info
        st.info("Attempting to load with strict=False...")
        try:
            # Try with the most common ViT model first
            model = create_model('vit_base_patch16_224', pretrained=False, num_classes=15)
            missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

            if missing_keys:
                st.warning(f"⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
            if unexpected_keys:
                st.warning(f"⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")

            model.eval()
            model.to(device)

            if missing_keys or unexpected_keys:
                st.error("⚠️ Model loaded with mismatched weights - predictions will likely be unreliable!")
                st.info("💡 The saved model might have been trained with a different architecture.")
            else:
                st.success("✅ Model loaded successfully with strict=False")

            return model, device

        except Exception as e:
            st.error(f"❌ Failed to load model with all methods. Error: {str(e)}")
            st.info("💡 Try checking the model file or re-training with a compatible architecture.")
            return None, device

    except Exception as e:
        st.error(f"❌ Error downloading or loading model: {str(e)}")
        return None, device
215
+
216
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Returns:
        A ``transforms.Compose`` that resizes to 224x224, converts the image
        to a tensor and normalizes each channel with the mean/std below.
    """
    # Per-channel statistics (these match the commonly used ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
223
+
224
def predict_image(image, model, device, transform, label_names, topk=5):
    """Classify one image and return its top-k class probabilities.

    Args:
        image: Input accepted by ``transform``.
        model: Callable mapping a batched tensor to logits; softmax is taken
            over dimension 1.
        device: Torch device the input tensor is moved to.
        transform: Callable turning ``image`` into a tensor without a batch
            dimension (one is added here).
        label_names: Class names. Not read by this function; kept so existing
            callers keep working.
        topk: Maximum number of predictions to return. Values larger than the
            number of classes are clamped instead of making ``torch.topk``
            raise.

    Returns:
        ``(top_probs, top_idxs)`` as NumPy arrays for the single input image,
        ordered from most to least probable.
    """
    # Transform image and add the batch dimension
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # Never ask for more entries than there are classes.
        k = min(topk, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k=k)

    # Convert to CPU for display
    return top_probs[0].cpu().numpy(), top_idxs[0].cpu().numpy()
240
+
241
def _confidence_marker(confidence):
    """Return the traffic-light emoji for a confidence given in percent."""
    if confidence > 70:
        return "🟢"
    if confidence > 40:
        return "🟡"
    return "🔴"


def _bar_color(confidence):
    """Return the chart bar colour for a confidence given in percent."""
    if confidence > 70:
        return '#4CAF50'  # Green
    if confidence > 40:
        return '#FF9800'  # Orange
    return '#F44336'  # Red


def _display_name(label):
    """Turn a class key such as ``golf_ball`` into ``Golf Ball``."""
    return label.replace('_', ' ').title()


def _render_ranked_predictions(label_names, top_probs, top_idxs, start, stop):
    """Write one numbered line plus a progress bar for ranks ``start..stop-1``."""
    for rank in range(start, stop):
        confidence = float(top_probs[rank] * 100)
        label = _display_name(label_names[top_idxs[rank]])
        st.write(f"{rank+1}. {_confidence_marker(confidence)} **{label}**: {confidence:.2f}%")
        # st.progress only accepts floats in [0.0, 1.0]; guard against rounding.
        st.progress(min(confidence / 100, 1.0))


def _render_confidence_chart(label_names, top_probs, top_idxs, topk):
    """Draw the horizontal bar chart of the returned predictions."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        labels = [_display_name(label_names[idx]) for idx in top_idxs]
        probabilities = [float(prob * 100) for prob in top_probs]  # Python floats

        # barh draws the first entry at the bottom, so reverse the order to
        # put the most likely class at the top of the chart.
        ordered_probs = probabilities[::-1]
        bars = ax.barh(labels[::-1], ordered_probs)
        ax.set_xlabel('Confidence (%)')
        ax.set_title(f'Top {topk} Predictions')
        ax.set_xlim(0, 100)

        for bar, prob in zip(bars, ordered_probs):
            bar.set_color(_bar_color(prob))
            # Percentage label just right of the bar end.
            ax.text(prob + 1, bar.get_y() + bar.get_height() / 2,
                    f'{prob:.1f}%', va='center')

        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would leak on each Streamlit rerun.
        plt.close(fig)


def _show_prediction_results(image, model, device, transform, label_names, topk):
    """Run the classifier on ``image`` and render all result widgets."""
    top_probs, top_idxs = predict_image(
        image, model, device, transform, label_names, topk
    )

    # Show original top prediction prominently
    top_confidence = float(top_probs[0] * 100)
    top_label = _display_name(label_names[top_idxs[0]])
    st.success(
        f"{_confidence_marker(top_confidence)} **Primary Prediction: {top_label}** ({top_confidence:.2f}%)"
    )
    st.progress(min(top_confidence / 100, 1.0))

    # Show top 3 high confidence predictions
    st.subheader("Top 3 Predictions:")
    _render_ranked_predictions(label_names, top_probs, top_idxs, 0, min(3, len(top_probs)))

    # Show all predictions if user wants more
    if topk > 3:
        with st.expander(f"See all {topk} predictions"):
            _render_ranked_predictions(label_names, top_probs, top_idxs, 3, len(top_probs))

    # Show detailed results in expandable section
    with st.expander("Detailed Results"):
        _render_confidence_chart(label_names, top_probs, top_idxs, topk)


def main():
    """Render the Streamlit UI: upload an image and display its predictions.

    Loads the model through ``load_model`` (stopping early if that fails),
    shows the uploader and the top-k slider in the left column, the prediction
    results in the right column, and the static information sections below.
    """
    st.title("🏀 Sports Ball Classifier")
    st.markdown("Upload an image of a sports ball and get AI-powered predictions!")

    # Define label names (index i is the model's class i)
    label_names = [
        'american_football', 'baseball', 'basketball', 'billiard_ball',
        'bowling_ball', 'cricket_ball', 'football', 'golf_ball',
        'hockey_ball', 'hockey_puck', 'rugby_ball', 'shuttlecock',
        'table_tennis_ball', 'tennis_ball', 'volleyball',
    ]

    # Load model
    with st.spinner("Loading model..."):
        model, device = load_model()

    if model is None:
        st.error("Failed to load model. Please try again later.")
        return

    st.success(f"Model loaded successfully! Using device: {device}")

    # Get image transform
    transform = get_transform()

    # Create two columns
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg'],
            help="Upload a clear image of a sports ball for best results",
        )

        # Number of top predictions to show
        topk = st.slider("Number of predictions to show:", 1, 10, 5)

    with col2:
        st.subheader("Predictions")

        if uploaded_file is None:
            st.info("👆 Please upload an image to get started!")
        else:
            try:
                image = Image.open(uploaded_file).convert("RGB")
            except (OSError, ValueError) as e:
                # Corrupt or non-image uploads (PIL's UnidentifiedImageError is
                # an OSError) get a readable message instead of a traceback.
                image = None
                st.error(f"Could not read the uploaded image: {e}")

            if image is not None:
                # Display uploaded image
                st.image(image, caption="Uploaded Image", use_container_width=True)

                # Make predictions
                with st.spinner("Analyzing image..."):
                    try:
                        _show_prediction_results(
                            image, model, device, transform, label_names, topk
                        )
                    except Exception as e:
                        st.error(f"Error during prediction: {str(e)}")

    # Additional information
    st.markdown("---")
    st.subheader("Supported Sports Balls")

    # Display supported categories in a nice grid
    cols = st.columns(5)
    for i, label in enumerate(label_names):
        with cols[i % 5]:
            st.write(f"• {_display_name(label)}")

    st.markdown("---")
    st.markdown("""
    **About this model:**
    - Built using ConvNeXt architecture
    - Trained to classify 15 different types of sports balls
    - Model weights from: [Alamgirapi/sports-ball-convnext-classifier](https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier)

    **Tips for best results:**
    - Use clear, well-lit images
    - Ensure the ball is the main subject
    - Avoid cluttered backgrounds when possible
    """)


if __name__ == "__main__":
    main()