""" CIFAR-10 Image Classifier - Hugging Face Space =============================================== Advanced CNN with Residual Connections achieving 87.88% test accuracy Architecture: CIFARNet - Depthwise Separable Convolutions - Residual Connections in C2 and C3 layers - Spatial Dropout for regularization - 174,762 parameters (0.67 MB model size) """ import torch import gradio as gr from PIL import Image from pathlib import Path import numpy as np # Import model architecture and preprocessing from model_cifar import CIFARNet from preprocess import CIFAR10_MEAN, CIFAR10_STD from torchvision import transforms # CIFAR-10 class names CIFAR10_CLASSES = [ "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck" ] # Class descriptions for better UX CLASS_DESCRIPTIONS = { "airplane": "âœˆī¸ Commercial or military aircraft", "automobile": "🚗 Cars, sedans, and vehicles", "bird": "đŸĻ Various bird species", "cat": "🐱 Domestic and wild cats", "deer": "đŸĻŒ Deer and similar animals", "dog": "🐕 Domestic dogs of various breeds", "frog": "🐸 Frogs and similar amphibians", "horse": "🐴 Horses and equines", "ship": "đŸšĸ Ships, boats, and vessels", "truck": "🚚 Trucks and large vehicles" } # Device configuration device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load model @torch.no_grad() def load_model(checkpoint_path: str = None): """Load the trained CIFARNet model.""" model = CIFARNet(num_classes=10).to(device) # Try to load checkpoint if checkpoint_path and Path(checkpoint_path).exists(): try: checkpoint = torch.load(checkpoint_path, map_location=device) if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}") else: model.load_state_dict(checkpoint) print(f"✅ Loaded model weights from {checkpoint_path}") except Exception as e: print(f"âš ī¸ Could not load checkpoint: {e}") print("Using randomly initialized model") else: print("â„šī¸ No checkpoint provided, using randomly initialized model") model.eval() return model # Initialize model print(f"Device: {device}") model = load_model("./snapshots_complete/cifar_epoch_249.pth") # Preprocessing pipeline (matches training) preprocess = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD) ]) def predict(image: Image.Image) -> tuple: """ Predict the class of an input image. Args: image: PIL Image Returns: Tuple of (predictions_dict, confidence_html) """ if image is None: return {}, "

Please upload an image first!

" try: # Preprocess image img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device) # Inference with torch.no_grad(): outputs = model(img_tensor) probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy() # Get all predictions sorted by probability sorted_indices = np.argsort(probabilities)[::-1] # Create results dictionary for top 3 top3_results = { CIFAR10_CLASSES[i]: float(probabilities[i]) for i in sorted_indices[:3] } # Create detailed HTML output predicted_class = CIFAR10_CLASSES[sorted_indices[0]] confidence = probabilities[sorted_indices[0]] html_output = f"""

đŸŽ¯ Prediction Result

{predicted_class.upper()}
{CLASS_DESCRIPTIONS[predicted_class]}
Confidence: {confidence*100:.2f}%

📊 Top 5 Predictions:

""" for i, idx in enumerate(sorted_indices[:5], 1): class_name = CIFAR10_CLASSES[idx] prob = probabilities[idx] bar_width = int(prob * 100) # Color coding based on rank if i == 1: color = "#28a745" # Green for top prediction elif i == 2: color = "#17a2b8" # Blue for second else: color = "#6c757d" # Gray for others html_output += f"""
{i}. {class_name} {prob*100:.2f}%
""" html_output += """
""" return top3_results, html_output except Exception as e: error_html = f"

Error during prediction: {str(e)}

" return {}, error_html # Model information for display model_description = """ ## 🚀 About This Model **CIFARNet** is an advanced CNN architecture designed for CIFAR-10 image classification. It achieves state-of-the-art performance with exceptional efficiency. ### 📊 Performance Metrics - **Test Accuracy:** 87.88% - **Top-3 Accuracy:** 97.74% - **Top-5 Accuracy:** 99.31% - **Model Size:** 174,762 parameters (0.67 MB) ### đŸ—ī¸ Architecture Highlights - **Depthwise Separable Convolutions** for parameter efficiency - **Residual Connections** in C2 and C3 layers for improved gradient flow - **Spatial Dropout** for better regularization - **Dilated Convolutions** (C3 layer with dilation=4) for larger receptive field ### đŸŽ¯ Best Performing Classes - Ship: 92.95% F1-score - Truck: 92.19% F1-score - Automobile: 93.77% F1-score - Frog: 90.28% F1-score ### đŸ”Ŧ Training Details - **Training Set:** CIFAR-10 (50,000 images) - **Test Set:** CIFAR-10 (10,000 images) - **Epochs:** 250 with cosine annealing - **Optimizer:** SGD with Nesterov momentum - **Augmentation:** HorizontalFlip, ShiftScaleRotate, CoarseDropout, ColorJitter ### 📚 Classes The model classifies images into 10 categories: - âœˆī¸ Airplane - 🚗 Automobile - đŸĻ Bird - 🐱 Cat - đŸĻŒ Deer - 🐕 Dog - 🐸 Frog - 🐴 Horse - đŸšĸ Ship - 🚚 Truck ### 💡 Tips - Upload clear images for best results - The model works best with images containing the main object centered - Images are automatically resized to 32×32 pixels (CIFAR-10 standard) - Try different angles or lighting conditions to see how the model performs ### 🔗 Links - [GitHub Repository](https://github.com/yourusername/CIFAR10-MLTraining) - [Model Architecture Details](https://github.com/yourusername/CIFAR10-MLTraining#model-architecture) - [Training Logs & Metrics](https://github.com/yourusername/CIFAR10-MLTraining#performance-results) """ # Example images (3 per class) examples = [ ["examples/airplane_1.jpg"], ["examples/airplane_2.jpg"], ["examples/automobile_1.jpg"], ["examples/ship_1.jpg"], ["examples/ship_2.jpg"], ["examples/cat_1.jpg"], ["examples/dog_1.jpg"], ["examples/horse_1.jpg"], ["examples/frog_1.jpg"], ["examples/truck_1.jpg"], ] # Custom CSS for better styling custom_css = """ .gradio-container { font-family: 'Inter', sans-serif; } .output-html { font-family: 'Inter', sans-serif; } """ # Create Gradio interface with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.Markdown("# đŸŽ¯ CIFAR-10 Image Classifier") gr.Markdown("### Advanced CNN achieving 87.88% test accuracy") with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil") predict_btn = gr.Button("🚀 Classify Image", variant="primary", size="lg") gr.Markdown("### 📤 Try it out!") gr.Markdown("Upload an image containing one of the 10 CIFAR-10 classes.") with gr.Column(scale=1): label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions") html_output = gr.HTML(label="Detailed Results") # Add examples section gr.Markdown("---") gr.Markdown("## 💡 How to use") gr.Markdown(""" 1. **Upload an image** using the upload box on the left 2. **Click 'Classify Image'** to get predictions 3. **View results** showing the top predictions with confidence scores The model works best with images from these categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. """) gr.Markdown("### 📸 Try These Examples") gr.Examples( examples=examples, inputs=image_input, outputs=[label_output, html_output], fn=predict, cache_examples=False, ) # Model information gr.Markdown("---") with gr.Accordion("📖 Model Information & Performance Metrics", open=False): gr.Markdown(model_description) # Connect the prediction function to button click only predict_btn.click( fn=predict, inputs=image_input, outputs=[label_output, html_output] ) # Launch the app if __name__ == "__main__": demo.launch()