"""
CIFAR-10 Image Classifier - Hugging Face Space
===============================================

Advanced CNN with Residual Connections achieving 87.88% test accuracy

Architecture: CIFARNet
- Depthwise Separable Convolutions
- Residual Connections in C2 and C3 layers
- Spatial Dropout for regularization
- 174,762 parameters (0.67 MB model size)
"""

import html
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Import model architecture and preprocessing constants (project-local)
from model_cifar import CIFARNet
from preprocess import CIFAR10_MEAN, CIFAR10_STD

# CIFAR-10 class names, in label-index order (index i <-> model output i)
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Class descriptions for better UX (shown with the top prediction)
CLASS_DESCRIPTIONS = {
    "airplane": "✈️ Commercial or military aircraft",
    "automobile": "🚗 Cars, sedans, and vehicles",
    "bird": "🐦 Various bird species",
    "cat": "🐱 Domestic and wild cats",
    "deer": "🦌 Deer and similar animals",
    "dog": "🐕 Domestic dogs of various breeds",
    "frog": "🐸 Frogs and similar amphibians",
    "horse": "🐴 Horses and equines",
    "ship": "🚢 Ships, boats, and vessels",
    "truck": "🚚 Trucks and large vehicles",
}

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(checkpoint_path: Optional[str] = None) -> CIFARNet:
    """Build CIFARNet on ``device`` and load trained weights when available.

    Args:
        checkpoint_path: Path to a checkpoint saved either as a dict holding a
            ``'model_state_dict'`` entry (optionally with ``'epoch'``) or as a
            bare state dict. ``None`` or ``""`` skips loading.

    Returns:
        The model in eval mode. If the checkpoint is missing or cannot be
        loaded, a message is printed and the randomly initialized model is
        returned so the Space still starts (deliberate best effort).
    """
    model = CIFARNet(num_classes=10).to(device)

    if not checkpoint_path:
        print("ℹ️ No checkpoint provided, using randomly initialized model")
    elif not Path(checkpoint_path).exists():
        # A path was given but is wrong: say so instead of "not provided".
        print(f"⚠️ Checkpoint not found at {checkpoint_path}, using randomly initialized model")
    else:
        try:
            # NOTE: torch.load is pickle-based; only load trusted checkpoint files.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if "model_state_dict" in checkpoint:
                model.load_state_dict(checkpoint["model_state_dict"])
                print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights from {checkpoint_path}")
        except Exception as e:  # best effort: keep the demo running
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")

    model.eval()
    return model


# Initialize model
print(f"Device: {device}")
model = load_model("./snapshots_complete/cifar_epoch_249.pth")

# Preprocessing pipeline (matches training): resize to CIFAR resolution,
# convert to a [0, 1] tensor, then normalize with the dataset statistics.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])


def _message_html(text: str, color: str) -> str:
    """Wrap a plain-text status/error message in a styled, escaped HTML div."""
    return (
        f'<div style="padding: 12px; border-radius: 8px; color: {color}; '
        f'font-weight: 600;">{html.escape(text)}</div>'
    )


def _build_results_html(probabilities: np.ndarray, sorted_indices: np.ndarray) -> str:
    """Render the top prediction and a per-class probability table as HTML.

    ``sorted_indices`` must order class indices from most to least probable.
    All interpolated text comes from the module's own class tables.
    """
    predicted_class = CIFAR10_CLASSES[sorted_indices[0]]
    confidence = float(probabilities[sorted_indices[0]])
    rows = "".join(
        f"<tr><td>{CIFAR10_CLASSES[i]}</td>"
        f'<td style="text-align: right;">{float(probabilities[i]):.2%}</td></tr>'
        for i in sorted_indices
    )
    return f"""
<div style="padding: 16px; border-radius: 8px;">
  <h3 style="margin-top: 0;">Prediction: {predicted_class.capitalize()}</h3>
  <p>{CLASS_DESCRIPTIONS[predicted_class]}</p>
  <p><strong>Confidence:</strong> {confidence:.2%}</p>
  <table style="width: 100%;">
    <tr><th style="text-align: left;">Class</th><th style="text-align: right;">Probability</th></tr>
    {rows}
  </table>
</div>
"""


def predict(image: Optional[Image.Image]) -> Tuple[Dict[str, float], str]:
    """Predict the class of an input image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.

    Returns:
        Tuple of ``(predictions_dict, confidence_html)``: the three most
        probable class names mapped to their softmax probabilities (for
        ``gr.Label``), and a detailed HTML breakdown (for ``gr.HTML``). On
        missing input or an inference error the dict is empty and the HTML
        holds a message instead.
    """
    if image is None:
        return {}, _message_html("Please upload an image first!", color="#b45309")

    try:
        # convert("RGB") normalizes grayscale / RGBA / palette uploads.
        img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # All class indices, most probable first
        sorted_indices = np.argsort(probabilities)[::-1]

        # Top-3 results for the Label component
        top3_results = {
            CIFAR10_CLASSES[i]: float(probabilities[i]) for i in sorted_indices[:3]
        }
        return top3_results, _build_results_html(probabilities, sorted_indices)
    except Exception as e:  # UI boundary: show the error instead of crashing
        # _message_html escapes the text, so exception messages cannot inject markup.
        return {}, _message_html(f"Error during prediction: {str(e)}", color="#b91c1c")


# Model information for display
# TODO: replace the placeholder "yourusername" GitHub links below.
model_description = """
## 📖 About This Model

**CIFARNet** is an advanced CNN architecture designed for CIFAR-10 image classification.
It delivers strong accuracy with exceptional parameter efficiency.

### 📈 Performance Metrics
- **Test Accuracy:** 87.88%
- **Top-3 Accuracy:** 97.74%
- **Top-5 Accuracy:** 99.31%
- **Model Size:** 174,762 parameters (0.67 MB)

### 🏗️ Architecture Highlights
- **Depthwise Separable Convolutions** for parameter efficiency
- **Residual Connections** in C2 and C3 layers for improved gradient flow
- **Spatial Dropout** for better regularization
- **Dilated Convolutions** (C3 layer with dilation=4) for larger receptive field

### 🎯 Best Performing Classes
- Automobile: 93.77% F1-score
- Ship: 92.95% F1-score
- Truck: 92.19% F1-score
- Frog: 90.28% F1-score

### 🔬 Training Details
- **Training Set:** CIFAR-10 (50,000 images)
- **Test Set:** CIFAR-10 (10,000 images)
- **Epochs:** 250 with cosine annealing
- **Optimizer:** SGD with Nesterov momentum
- **Augmentation:** HorizontalFlip, ShiftScaleRotate, CoarseDropout, ColorJitter

### 📚 Classes
The model classifies images into 10 categories:
- ✈️ Airplane
- 🚗 Automobile
- 🐦 Bird
- 🐱 Cat
- 🦌 Deer
- 🐕 Dog
- 🐸 Frog
- 🐴 Horse
- 🚢 Ship
- 🚚 Truck

### 💡 Tips
- Upload clear images for best results
- The model works best with images containing the main object centered
- Images are automatically resized to 32×32 pixels (CIFAR-10 standard)
- Try different angles or lighting conditions to see how the model performs

### 🔗 Links
- [GitHub Repository](https://github.com/yourusername/CIFAR10-MLTraining)
- [Model Architecture Details](https://github.com/yourusername/CIFAR10-MLTraining#model-architecture)
- [Training Logs & Metrics](https://github.com/yourusername/CIFAR10-MLTraining#performance-results)
"""

# Usage instructions (kept flush-left so Markdown doesn't render it as a code block)
HOW_TO_USE = """
1. **Upload an image** using the upload box on the left
2. **Click 'Classify Image'** to get predictions
3. **View results** showing the top predictions with confidence scores

The model works best with images from these categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
"""

# Example images (a sample of classes; files must exist under examples/)
examples = [
    ["examples/airplane_1.jpg"],
    ["examples/airplane_2.jpg"],
    ["examples/automobile_1.jpg"],
    ["examples/ship_1.jpg"],
    ["examples/ship_2.jpg"],
    ["examples/cat_1.jpg"],
    ["examples/dog_1.jpg"],
    ["examples/horse_1.jpg"],
    ["examples/frog_1.jpg"],
    ["examples/truck_1.jpg"],
]

# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.output-html {
    font-family: 'Inter', sans-serif;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 CIFAR-10 Image Classifier")
    gr.Markdown("### Advanced CNN achieving 87.88% test accuracy")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            predict_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg")
            gr.Markdown("### 📤 Try it out!")
            gr.Markdown("Upload an image containing one of the 10 CIFAR-10 classes.")

        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    # Add examples section
    gr.Markdown("---")
    gr.Markdown("## 💡 How to use")
    gr.Markdown(HOW_TO_USE)

    gr.Markdown("### 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=image_input,
        outputs=[label_output, html_output],
        fn=predict,
        cache_examples=False,
    )

    # Model information
    gr.Markdown("---")
    with gr.Accordion("📊 Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    # Connect the prediction function to button click only
    predict_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, html_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()