CIFAR-10 Image Classifier - Detailed Explanation
Overview
This application provides a user-friendly interface for running predictions on a trained PyTorch neural network model. The model is based on the implementation from the PyTorch CIFAR-10 Tutorial, which trains a convolutional neural network to classify images from the CIFAR-10 dataset.
Model Architecture Breakdown
The neural network implements the architecture from the PyTorch CIFAR-10 tutorial:
- Input Layer: Accepts RGB images of size 32×32 pixels (3 channels)
- First Convolutional Block:
  - Conv2d layer: 3 input channels → 6 output channels, 5×5 kernel
  - ReLU activation function
  - MaxPool2d layer: 2×2 pooling window
- Second Convolutional Block:
  - Conv2d layer: 6 input channels → 16 output channels, 5×5 kernel
  - ReLU activation function
  - MaxPool2d layer: 2×2 pooling window
- Fully Connected Layers:
  - First FC layer: 400 inputs → 120 outputs with ReLU activation
  - Second FC layer: 120 inputs → 84 outputs with ReLU activation
  - Output layer: 84 inputs → 10 outputs (for 10 CIFAR-10 classes)
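The 400 inputs to the first fully connected layer follow from the layer sizes listed above. A minimal sketch of that arithmetic, assuming unpadded stride-1 convolutions and non-overlapping 2×2 pooling (the `Conv2d` defaults and the tutorial's pooling setup); `conv_out` is an illustrative helper, not part of the application:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32
size = conv_out(size, kernel=5)            # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_out(size, kernel=5)            # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5

flattened = 16 * size * size  # 16 channels of 5x5 feature maps
print(flattened)  # 400
```

This is why a model with different kernel sizes or input resolution would need a different first FC layer.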
CIFAR-10 Dataset
The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. The 10 classes are:
- Airplane - Aircraft flying in the sky
- Automobile - Cars and vehicles on the road
- Bird - Flying or perched birds
- Cat - Domestic cats and felines
- Deer - Wild deer and similar animals
- Dog - Domestic dogs and canines
- Frog - Amphibians like frogs
- Horse - Horses and similar animals
- Ship - Boats and ships on water
- Truck - Trucks and heavy vehicles
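The model's 10 outputs are indexed in the standard CIFAR-10 label order, which is the order of the list above. A small sketch of mapping a probability vector back to labels; the name `cifar10_classes` is taken from this document, while `top_k` and the example probabilities are hypothetical:

```python
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

def top_k(probabilities, k=3):
    """Return the k most likely (label, probability) pairs, best first."""
    ranked = sorted(
        zip(cifar10_classes, probabilities),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]

# Made-up probabilities for illustration only.
example = [0.01, 0.02, 0.05, 0.60, 0.03, 0.20, 0.04, 0.02, 0.02, 0.01]
print(top_k(example, k=2))  # [('cat', 0.6), ('dog', 0.2)]
```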
How the Application Works
1. Model Loading
When the application starts, it attempts to load your trained model weights from a file named model.pth. This file should contain the state dictionary of a model with the exact architecture defined in the Net class, matching the PyTorch CIFAR-10 tutorial.
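A minimal sketch of what this loading step might look like. The `Net` class below follows the PyTorch CIFAR-10 tutorial; the body and signature of `load_model()` are assumptions, since the actual code in app.py is not shown here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Architecture from the PyTorch CIFAR-10 tutorial."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # raw class scores (logits)

def load_model(model_path: str = "model.pth") -> Net:
    """Assumed loader: build Net, load the saved state dict, switch to eval mode."""
    model = Net()
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)  # raises if keys or shapes differ
    model.eval()
    return model
```

Because `load_state_dict` is strict by default, a checkpoint saved from a different architecture fails at load time rather than producing silent mispredictions.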
2. Image Preprocessing
Before making predictions, any input image goes through preprocessing:
- Maintained as RGB (3 channels) - no color conversion
- Resized to 32×32 pixels to match the model's expected input size
- Converted to a PyTorch tensor
- Batch dimension added (required by PyTorch)
3. Prediction Process
When you submit an image for classification, the process follows the PyTorch tutorial:
model.eval()
with torch.no_grad():
    output = model(input_tensor)
    probabilities = F.softmax(output, dim=1)
    probabilities = probabilities.numpy()[0]
This implementation:
- Sets the model to evaluation mode with model.eval()
- Disables gradient computation with torch.no_grad() for efficiency
- Applies softmax to convert raw outputs to probabilities
- Extracts the first (and only) batch result
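To make the softmax step concrete, here is a small worked example in plain Python. It is only an illustration of the formula that `F.softmax` applies along `dim=1`; the input scores are made up:

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

The outputs are non-negative and sum to 1, which is what lets the interface display them as per-class probabilities.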
4. User Interface Features
The Gradio interface provides several ways to interact with the model:
- Image Upload: Upload any image file from your computer
- Drawing Tool: Draw an image directly in the browser
- Example Images: Use pre-made examples representing each CIFAR-10 class
- Real-time Results: See prediction probabilities for all 10 classes
- Responsive Design: Works well on both desktop and mobile devices
Image Input Capabilities
Supported Image Formats
The application accepts all common image formats:
- JPEG, PNG, BMP, TIFF, GIF, and WebP
- Color images (maintained as RGB with 3 channels)
- Images of any resolution (automatically resized to 32×32)
Robustness Features
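The preprocessing and the convolutional architecture give the application some tolerance to varied image conditions, although accuracy drops as inputs diverge from CIFAR-10-style images:
- Resolution Independence: Works with images of any size (resized to 32×32)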
- Color Preservation: Maintains RGB color information
- Contrast Handling: Works with both high and low contrast images
- Noise Tolerance: Can handle some image noise
- Rotation Tolerance: Some tolerance to slight rotations
- Scale Invariance: Works with objects of different sizes
Best Practices for Good Results
To get the best classification results:
- Center the object in the image area
- Use clear contrast between the object and background
- Fill most of the image area with the object
- Avoid excessive noise or artifacts
- Ensure the object is clearly visible
Image Preprocessing Pipeline
The complete preprocessing pipeline:
- Image upload or drawing
- Resize to 32×32 pixels using bilinear interpolation
- Conversion to PyTorch tensor with values scaled to [0,1]
- Addition of batch dimension for model inference
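The pipeline above can be sketched with Pillow, NumPy, and PyTorch as follows. The function name `preprocess_image` appears in this document, but its body here is an assumption about how the steps might be implemented, not the code from app.py:

```python
import numpy as np
import torch
from PIL import Image

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Return a float tensor of shape (1, 3, 32, 32) with values in [0, 1]."""
    image = image.convert("RGB")                    # keep 3 colour channels
    image = image.resize((32, 32), Image.BILINEAR)  # bilinear resize
    array = np.asarray(image, dtype=np.float32) / 255.0  # H x W x C in [0, 1]
    tensor = torch.from_numpy(array).permute(2, 0, 1).contiguous()  # C x H x W
    return tensor.unsqueeze(0)                      # add batch dimension

# Usage with a synthetic solid-red image of arbitrary size.
input_tensor = preprocess_image(Image.new("RGB", (100, 60), (255, 0, 0)))
print(input_tensor.shape)  # torch.Size([1, 3, 32, 32])
```

Note that resizing to a square does not preserve aspect ratio, so non-square inputs are stretched.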
Technical Implementation Details
Custom CSS Styling
The application features a modern UI with:
- Animated gradient background
- Glass-morphism design elements
- Responsive layout that adapts to different screen sizes
- Interactive buttons with hover effects
- Clean typography using Google Fonts
Error Handling
The application gracefully handles:
- Missing model files (shows error message)
- Empty inputs (returns zero probabilities)
- Various image formats (maintained as RGB)
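The empty-input case could be handled along these lines. This is a sketch under assumptions: `safe_predict` and the `classify` callable are hypothetical names standing in for whatever the real `predict()` function does internally:

```python
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

def safe_predict(image, classify):
    """Return zero probabilities for empty input; otherwise delegate to `classify`.

    `classify` is any callable mapping an image to 10 probabilities.
    """
    if image is None:
        return {label: 0.0 for label in cifar10_classes}
    probabilities = classify(image)
    return {label: float(p) for label, p in zip(cifar10_classes, probabilities)}

# With no image, every class gets probability 0.0 instead of raising an error.
empty_result = safe_predict(None, classify=None)
```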
Performance Optimizations
- Model loaded once at startup
- Gradients disabled during inference
- Efficient tensor operations
- Caching of example predictions
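One common way to guarantee that an expensive object is built only once is to memoize its factory function. The sketch below uses `functools.lru_cache` with a stand-in object; whether app.py uses this mechanism or a module-level variable is not known, and `get_model` is a hypothetical name:

```python
from functools import lru_cache

load_count = 0

@lru_cache(maxsize=1)
def get_model():
    """Build the model on first call; later calls reuse the cached object."""
    global load_count
    load_count += 1
    return object()  # stand-in for an expensive model load

first = get_model()
second = get_model()
print(first is second, load_count)  # True 1
```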
Deployment to Hugging Face Spaces
To deploy this application to Hugging Face Spaces:
- Create a new Space with the "Gradio" SDK
- Upload all files from this directory
- Ensure your model.pth file is included
- The Space will automatically install dependencies from requirements.txt
- The application will start automatically
Customization Guide
Using a Different Model File
If your model is saved with a different filename:
- Modify the model_path variable in the load_model() function
- Ensure the model architecture matches the Net class definition exactly
Changing Class Labels
To customize the class labels:
- Modify the cifar10_classes list in the predict() function
- Update the example images in the create_example_images() function to match your new classes
Adjusting Image Preprocessing
To modify how images are preprocessed:
- Edit the preprocess_image() function
- Change the resize dimensions if your model expects a different input size
- Add normalization if your model was trained with normalized inputs
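Normalization matters here because the PyTorch CIFAR-10 tutorial trains with `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, so a model trained exactly as in the tutorial expects inputs in [-1, 1] rather than [0, 1]. A minimal sketch of the equivalent operation on an already-batched tensor; `normalize` is an illustrative helper, and the mean and standard deviation should match whatever your own training used:

```python
import torch

def normalize(batch: torch.Tensor) -> torch.Tensor:
    """Map values from [0, 1] to [-1, 1] per channel (mean 0.5, std 0.5)."""
    mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    return (batch - mean) / std

batch = torch.rand(1, 3, 32, 32)  # stand-in for a preprocessed image in [0, 1]
normalized = normalize(batch)
```

A mismatch between training-time and inference-time normalization is a frequent cause of unexpectedly poor predictions.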
Troubleshooting Common Issues
Model Not Loading
- Verify model.pth is in the same directory as app.py
- Ensure the model architecture matches the Net class definition exactly
- Check that the file is not corrupted
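When an architecture mismatch is suspected, comparing the checkpoint's keys and tensor shapes with those the model expects usually pinpoints the problem. The helper below is a hypothetical diagnostic, demonstrated on a tiny stand-in model rather than the real `Net`:

```python
import torch
import torch.nn as nn

def diff_state_dict(model: nn.Module, state_dict: dict) -> dict:
    """List parameters that are missing, unexpected, or differently shaped."""
    expected = model.state_dict()
    missing = sorted(set(expected) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(expected))
    mismatched = sorted(
        name for name in set(expected) & set(state_dict)
        if expected[name].shape != state_dict[name].shape
    )
    return {"missing": missing, "unexpected": unexpected, "shape_mismatch": mismatched}

# Stand-in example: the checkpoint lacks a bias, has an extra key,
# and stores a weight with the wrong shape.
model = nn.Linear(4, 2)
checkpoint = {"weight": torch.zeros(3, 4), "extra": torch.zeros(1)}
report = diff_state_dict(model, checkpoint)
print(report)
# {'missing': ['bias'], 'unexpected': ['extra'], 'shape_mismatch': ['weight']}
```

An empty report for all three lists suggests the architecture matches and the issue lies elsewhere, such as a corrupted file.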
Poor Prediction Accuracy
- Verify your model was trained on similar data (CIFAR-10 or similar)
- Check if the preprocessing matches what was used during training
- Ensure input images are similar to the training data
UI Display Issues
- Update Gradio to the latest version
- Check browser compatibility
- Clear browser cache if styles aren't loading correctly
File Structure
cifar10-classifier/
├── app.py # Main application file
├── requirements.txt # Python dependencies
├── README.md # User guide
├── EXPLANATION.md # This file
├── model.pth # Your trained model (to be added)
└── space.json # Hugging Face Spaces configuration
Requirements Explanation
- torch>=1.7.0: Core PyTorch library for neural network operations
- torchvision>=0.8.0: Computer vision utilities, including image transforms
- gradio>=4.0.0: Framework for creating machine learning web interfaces
- pillow>=8.0.0: Python Imaging Library for image processing
- numpy>=1.19.0: Numerical computing library for array operations
Example Use Cases
- Object Recognition: Classify images into 10 common object categories
- Educational Tool: Demonstrate how convolutional neural networks work on real image data
- Model Showcase: Present your trained model to others in an interactive way
- Testing Platform: Evaluate model performance on custom inputs
This application provides a complete solution for deploying a PyTorch model trained on CIFAR-10 with an attractive, user-friendly interface that can be easily shared with others through Hugging Face Spaces. The implementation is based on the PyTorch CIFAR-10 tutorial, ensuring compatibility with models trained using the same approach.