"""Streamlit app that classifies sports-ball images with a pre-trained model.

The app downloads a checkpoint from Hugging Face, probes a list of ``timm``
architectures until one accepts the weights, and then serves top-k
predictions for images uploaded through the browser.
"""

import io

import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm import create_model
from torchvision import transforms

# Set page config (must run before any other Streamlit command)
st.set_page_config(
    page_title="Sports Ball Classifier",
    page_icon="🏀",
    layout="wide"
)

MODEL_URL = (
    "https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier"
    "/resolve/main/model.pth"
)

# (connect, read) timeouts in seconds so a stalled download cannot hang the app.
DOWNLOAD_TIMEOUT = (10, 120)

# Class names, in the index order the classifier head is expected to use.
LABEL_NAMES = [
    'american_football', 'baseball', 'basketball', 'billiard_ball',
    'bowling_ball', 'cricket_ball', 'football', 'golf_ball',
    'hockey_ball', 'hockey_puck', 'rugby_ball', 'shuttlecock',
    'table_tennis_ball', 'tennis_ball', 'volleyball',
]
NUM_CLASSES = len(LABEL_NAMES)

# Architectures probed with a strict ``load_state_dict``, in order. Each entry
# is (message shown before the group is tried, timm model names).
_CANDIDATE_GROUPS = (
    (
        "Trying Vision Transformer (ViT) models...",
        (
            "vit_base_patch16_224",
            "vit_small_patch16_224",
            "vit_tiny_patch16_224",
            "vit_large_patch16_224",
            "vit_base_patch32_224",
        ),
    ),
    (
        "Trying ConvNeXt models as fallback...",
        ("convnext_tiny", "convnext_small", "convnext_base"),
    ),
    (
        "Trying other model architectures...",
        ("resnet50", "efficientnet_b0", "mobilenetv3_large_100"),
    ),
)

# Architecture used for the last-resort ``strict=False`` load.
_NON_STRICT_ARCH = "vit_base_patch16_224"


# Custom ConvNeXt model definition (in case the saved model uses a different
# architecture). Neither class below is referenced by the loading code.
class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise 7x7 conv -> LayerNorm -> MLP -> layer scale.

    Args:
        dim: Number of input and output channels.
        drop_path: Accepted for signature compatibility; this implementation
            does not apply stochastic depth.
        layer_scale_init_value: Initial value of the per-channel scale
            ``gamma``. When not positive, no scale parameter is created.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        """Return ``x`` plus the block's transformation of ``x``."""
        residual = x
        x = self.dwconv(x)
        # LayerNorm and Linear act on the last axis, so go channels-last.
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + x


class CustomConvNeXt(nn.Module):
    """Four-stage ConvNeXt-style classifier with (3, 3, 9, 3) blocks.

    The stem and downsample ``LayerNorm`` layers are built with fixed spatial
    shapes (56, 28 and 14), so ``forward`` only works for inputs whose stem
    output is 56x56, i.e. 224x224 images.

    Args:
        num_classes: Size of the classification head output.
    """

    def __init__(self, num_classes=15):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=4, stride=4),
            nn.LayerNorm([96, 56, 56], eps=1e-6)
        )

        # Stage 1
        self.stage1 = nn.Sequential(*[ConvNeXtBlock(96) for _ in range(3)])

        # Downsample 1
        self.downsample1 = nn.Sequential(
            nn.LayerNorm([96, 56, 56], eps=1e-6),
            nn.Conv2d(96, 192, kernel_size=2, stride=2)
        )

        # Stage 2
        self.stage2 = nn.Sequential(*[ConvNeXtBlock(192) for _ in range(3)])

        # Downsample 2
        self.downsample2 = nn.Sequential(
            nn.LayerNorm([192, 28, 28], eps=1e-6),
            nn.Conv2d(192, 384, kernel_size=2, stride=2)
        )

        # Stage 3
        self.stage3 = nn.Sequential(*[ConvNeXtBlock(384) for _ in range(9)])

        # Downsample 3
        self.downsample3 = nn.Sequential(
            nn.LayerNorm([384, 14, 14], eps=1e-6),
            nn.Conv2d(384, 768, kernel_size=2, stride=2)
        )

        # Stage 4
        self.stage4 = nn.Sequential(*[ConvNeXtBlock(768) for _ in range(3)])

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.norm = nn.LayerNorm(768, eps=1e-6)
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``."""
        x = self.stem(x)
        x = self.stage1(x)
        x = self.downsample1(x)
        x = self.stage2(x)
        x = self.downsample2(x)
        x = self.stage3(x)
        x = self.downsample3(x)
        x = self.stage4(x)
        x = self.avgpool(x)
        x = x.flatten(1)  # (N, 768, 1, 1) -> (N, 768)
        x = self.norm(x)
        return self.head(x)


class _ReportedLoadError(RuntimeError):
    """Loading failed and the error has already been shown in the UI."""


def _select_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _download_state_dict(device):
    """Download the checkpoint at ``MODEL_URL`` and return its state dict.

    Raises:
        RuntimeError: If the server does not answer with HTTP 200.
        TypeError: If the checkpoint does not deserialize to a dict.
        requests.RequestException: On connection errors or timeouts.
    """
    response = requests.get(MODEL_URL, timeout=DOWNLOAD_TIMEOUT)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download model: HTTP {response.status_code}")

    # NOTE(security): torch.load unpickles the downloaded bytes, which can
    # execute arbitrary code if the remote file is malicious. Consider
    # ``weights_only=True`` once the minimum supported torch version has it.
    state_dict = torch.load(io.BytesIO(response.content), map_location=device)
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Expected a state dict in the checkpoint, got {type(state_dict).__name__}"
        )
    return state_dict


def _try_strict_load(arch_names, state_dict, device):
    """Return the first architecture that strictly accepts ``state_dict``.

    Each failure is reported with ``st.warning`` and the next name is tried.
    Returns ``None`` when no architecture in ``arch_names`` matches.
    """
    for arch_name in arch_names:
        # Probing is best-effort: any failure just means "not this one".
        try:
            model = create_model(arch_name, pretrained=False, num_classes=NUM_CLASSES)
            model.load_state_dict(state_dict)
            model.eval()
            model.to(device)
        except Exception as e:
            st.warning(f"❌ Failed to load with {arch_name}: {str(e)[:100]}...")
            continue
        st.success(f"✅ Successfully loaded model using: {arch_name}")
        return model
    return None


def _load_non_strict(state_dict, device):
    """Load ``state_dict`` into ``_NON_STRICT_ARCH`` with ``strict=False``.

    Missing and unexpected keys are reported in the UI. The model is returned
    even when keys are mismatched, after an explicit warning.

    Raises:
        _ReportedLoadError: If the load fails; the error is shown first.
    """
    try:
        model = create_model(_NON_STRICT_ARCH, pretrained=False, num_classes=NUM_CLASSES)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            st.warning(f"⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
        if unexpected_keys:
            st.warning(f"⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")

        model.eval()
        model.to(device)
    except Exception as e:
        st.error(f"❌ Failed to load model with all methods. Error: {str(e)}")
        st.info("💡 Try checking the model file or re-training with a compatible architecture.")
        raise _ReportedLoadError(str(e)) from e

    if missing_keys or unexpected_keys:
        st.error("⚠️ Model loaded with mismatched weights - predictions will likely be unreliable!")
        st.info("💡 The saved model might have been trained with a different architecture.")
    else:
        st.success("✅ Model loaded successfully with strict=False")
    return model


# Cache the model loading to avoid reloading on every interaction. This
# function raises on failure instead of returning None so that a failed
# attempt is not stored in the cache and the next rerun retries.
@st.cache_resource
def _load_model_cached():
    """Download the checkpoint and return ``(model, device)``; raise on failure."""
    device = _select_device()
    state_dict = _download_state_dict(device)

    for announcement, arch_names in _CANDIDATE_GROUPS:
        st.info(announcement)
        model = _try_strict_load(arch_names, state_dict, device)
        if model is not None:
            return model, device

    # If all fail, try loading with strict=False and show detailed info
    st.info("Attempting to load with strict=False...")
    return _load_non_strict(state_dict, device), device


def load_model():
    """Load the pre-trained sports-ball classifier.

    The checkpoint is downloaded once and the resulting model is cached
    across Streamlit reruns. Failures are shown in the UI and are not
    cached, so a later rerun tries again.

    Returns:
        A ``(model, device)`` tuple. ``model`` is ``None`` when loading failed.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # _ReportedLoadError has already been shown to the user.
        if not isinstance(e, _ReportedLoadError):
            st.error(f"❌ Error downloading or loading model: {str(e)}")
        return None, _select_device()


# Keep the cache-clearing hook that ``st.cache_resource`` functions expose.
load_model.clear = _load_model_cached.clear


def get_transform():
    """Return the preprocessing pipeline: resize to 224x224, tensor, normalize."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


def predict_image(image, model, device, transform, label_names, topk=5):
    """Return the top-k class probabilities and indices for ``image``.

    Args:
        image: Image accepted by ``transform``.
        model: Classifier returning logits of shape ``(1, num_classes)``.
        device: Device the model lives on.
        transform: Callable turning ``image`` into a ``(C, H, W)`` tensor.
        label_names: Unused here; kept for interface compatibility.
        topk: Number of predictions wanted. Clamped to the number of classes
            the model outputs.

    Returns:
        ``(top_probs, top_idxs)`` as 1-D NumPy arrays, highest probability
        first.
    """
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        # torch.topk raises if k exceeds the number of classes.
        k = min(int(topk), probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k=k)

    # Convert to CPU for display
    return top_probs[0].cpu().numpy(), top_idxs[0].cpu().numpy()


def _confidence_emoji(confidence):
    """Map a confidence percentage to a traffic-light emoji."""
    if confidence > 70:
        return "🟢"
    if confidence > 40:
        return "🟡"
    return "🔴"


def _confidence_bar_color(confidence):
    """Map a confidence percentage to a bar colour (green/orange/red)."""
    if confidence > 70:
        return '#4CAF50'  # Green
    if confidence > 40:
        return '#FF9800'  # Orange
    return '#F44336'  # Red


def _render_prediction_row(rank, label, confidence):
    """Write one ranked prediction line followed by its progress bar."""
    st.write(f"{rank}. {_confidence_emoji(confidence)} **{label}**: {confidence:.2f}%")
    # st.progress only accepts values in [0.0, 1.0].
    st.progress(min(max(confidence / 100, 0.0), 1.0))


def _build_confidence_chart(labels, percentages, topk):
    """Return a horizontal bar chart of the predictions, best at the top."""
    fig, ax = plt.subplots(figsize=(10, 6))
    # barh draws bottom-up, so reverse to put the best prediction on top.
    ordered_percentages = percentages[::-1]
    bars = ax.barh(labels[::-1], ordered_percentages)
    ax.set_xlabel('Confidence (%)')
    ax.set_title(f'Top {topk} Predictions')
    ax.set_xlim(0, 100)

    for bar, pct in zip(bars, ordered_percentages):
        bar.set_color(_confidence_bar_color(pct))
        ax.text(pct + 1, bar.get_y() + bar.get_height() / 2,
                f'{pct:.1f}%', va='center')

    fig.tight_layout()
    return fig


def _render_predictions(image, model, device, transform, label_names, topk):
    """Run the model on ``image`` and render every prediction view."""
    top_probs, top_idxs = predict_image(
        image, model, device, transform, label_names, topk
    )
    labels = [label_names[int(idx)].replace('_', ' ').title() for idx in top_idxs]
    percentages = [float(prob * 100) for prob in top_probs]  # Python floats

    # Show original top prediction prominently
    top_label, top_confidence = labels[0], percentages[0]
    st.success(
        f"{_confidence_emoji(top_confidence)} "
        f"**Primary Prediction: {top_label}** ({top_confidence:.2f}%)"
    )
    st.progress(min(max(top_confidence / 100, 0.0), 1.0))

    # Show top 3 high confidence predictions
    st.subheader("Top 3 Predictions:")
    for rank, (label, confidence) in enumerate(zip(labels[:3], percentages[:3]), start=1):
        _render_prediction_row(rank, label, confidence)

    # Show all predictions if user wants more
    if topk > 3:
        with st.expander(f"See all {topk} predictions"):
            for rank, (label, confidence) in enumerate(
                zip(labels[3:], percentages[3:]), start=4
            ):
                _render_prediction_row(rank, label, confidence)

    # Show detailed results in expandable section
    with st.expander("Detailed Results"):
        fig = _build_confidence_chart(labels, percentages, topk)
        try:
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until closed; without this the
            # figures pile up across Streamlit reruns.
            plt.close(fig)


def main():
    """Render the app: model loading, image upload, predictions, and help."""
    st.title("🏀 Sports Ball Classifier")
    st.markdown("Upload an image of a sports ball and get AI-powered predictions!")

    label_names = LABEL_NAMES

    # Load model
    with st.spinner("Loading model..."):
        model, device = load_model()

    if model is None:
        st.error("Failed to load model. Please try again later.")
        return

    st.success(f"Model loaded successfully! Using device: {device}")

    # Get image transform
    transform = get_transform()

    # Create two columns
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg'],
            help="Upload a clear image of a sports ball for best results"
        )

        # Number of top predictions to show
        topk = st.slider("Number of predictions to show:", 1, 10, 5)

    with col2:
        st.subheader("Predictions")

        if uploaded_file is None:
            st.info("👆 Please upload an image to get started!")
        else:
            image = None
            try:
                # Pillow signals unreadable or truncated files with OSError.
                image = Image.open(uploaded_file).convert("RGB")
            except OSError as e:
                st.error(f"Could not read the uploaded image: {str(e)}")

            if image is not None:
                # Display uploaded image
                st.image(image, caption="Uploaded Image", use_container_width=True)

                # Make predictions
                with st.spinner("Analyzing image..."):
                    try:
                        _render_predictions(
                            image, model, device, transform, label_names, topk
                        )
                    except Exception as e:
                        st.error(f"Error during prediction: {str(e)}")

    # Additional information
    st.markdown("---")
    st.subheader("Supported Sports Balls")

    # Display supported categories in a nice grid
    cols = st.columns(5)
    for i, label in enumerate(label_names):
        with cols[i % 5]:
            st.write(f"• {label.replace('_', ' ').title()}")

    st.markdown("---")
    st.markdown("""
    **About this model:**
    - Built using ConvNeXt architecture
    - Trained to classify 15 different types of sports balls
    - Model weights from: [Alamgirapi/sports-ball-convnext-classifier](https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier)

    **Tips for best results:**
    - Use clear, well-lit images
    - Ensure the ball is the main subject
    - Avoid cluttered backgrounds when possible
    """)


if __name__ == "__main__":
    main()