
CIFAR-10 Image Classifier - Detailed Explanation

Overview

This application provides a user-friendly interface for running predictions on a trained PyTorch neural network model. The model is based on the implementation from the PyTorch CIFAR-10 Tutorial, which trains a convolutional neural network to classify images from the CIFAR-10 dataset.

Model Architecture Breakdown

The neural network implements the architecture from the PyTorch CIFAR-10 tutorial:

  1. Input Layer: Accepts RGB images of size 32×32 pixels (3 channels)
  2. First Convolutional Block:
    • Conv2d layer: 3 input channels → 6 output channels, 5×5 kernel
    • ReLU activation function
    • MaxPool2d layer: 2×2 pooling window
  3. Second Convolutional Block:
    • Conv2d layer: 6 input channels → 16 output channels, 5×5 kernel
    • ReLU activation function
    • MaxPool2d layer: 2×2 pooling window
  4. Fully Connected Layers:
    • First FC layer: 400 inputs → 120 outputs with ReLU activation
    • Second FC layer: 120 inputs → 84 outputs with ReLU activation
    • Output layer: 84 inputs → 10 outputs (for the 10 CIFAR-10 classes)
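The blocks above correspond to the `Net` module from the PyTorch CIFAR-10 tutorial, which can be written as:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """CNN from the PyTorch CIFAR-10 tutorial."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 pooling, reused by both blocks
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 400 inputs -> 120
        self.fc2 = nn.Linear(120, 84)          # 120 -> 84
        self.fc3 = nn.Linear(84, 10)           # 84 -> 10 class scores

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 32x32 -> 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))   # 14x14 -> 10x10 -> 5x5
        x = torch.flatten(x, 1)                # flatten to (batch, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                     # raw logits, no softmax here

# A random 32x32 RGB batch produces one logit per class.
logits = Net()(torch.randn(1, 3, 32, 32))      # shape: (1, 10)
```

Note that spatial size shrinks from 32×32 to 5×5 across the two blocks, which is where the 16 × 5 × 5 = 400 input count for the first FC layer comes from.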

CIFAR-10 Dataset

The CIFAR-10 dataset consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class. The 10 classes are:

  1. Airplane - Aircraft flying in the sky
  2. Automobile - Cars and vehicles on the road
  3. Bird - Flying or perched birds
  4. Cat - Domestic cats and felines
  5. Deer - Wild deer and similar animals
  6. Dog - Domestic dogs and canines
  7. Frog - Amphibians like frogs
  8. Horse - Horses and similar animals
  9. Ship - Boats and ships on water
  10. Truck - Trucks and heavy vehicles

How the Application Works

1. Model Loading

When the application starts, it attempts to load your trained model weights from a file named model.pth. This file should contain the state dictionary of a model with the exact architecture defined in the Net class, matching the PyTorch CIFAR-10 tutorial.
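Loading a saved state dict generally follows this pattern (sketched here with a stand-in `nn.Linear` module for brevity; the real `load_model()` builds a `Net` instance and reads `model.pth`):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in module for illustration; the app loads weights into the Net class.
model = nn.Linear(4, 2)

# Pretend these saved weights are the app's model.pth.
model_path = os.path.join(tempfile.gettempdir(), "demo_weights.pth")
torch.save(model.state_dict(), model_path)

if os.path.exists(model_path):
    state = torch.load(model_path, map_location="cpu")  # CPU-safe load
    model.load_state_dict(state)   # keys must match the architecture exactly
    model.eval()                   # switch to inference mode
else:
    print("Model file not found")  # the app surfaces this as an error message
```

`map_location="cpu"` makes the load work even when the weights were saved from a GPU run, which matters on CPU-only Spaces hardware.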

2. Image Preprocessing

Before making predictions, any input image goes through preprocessing:

  • Maintained as RGB (3 channels) - no color conversion
  • Resized to 32×32 pixels to match the model's expected input size
  • Converted to a PyTorch tensor
  • Batch dimension added (required by PyTorch)
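A minimal sketch of these steps using PIL and NumPy (the app's actual `preprocess_image()` may use torchvision transforms instead):

```python
import numpy as np
import torch
from PIL import Image

def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a (1, 3, 32, 32) float tensor."""
    img = img.convert("RGB").resize((32, 32))        # keep 3 channels, resize
    arr = np.asarray(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    tensor = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW layout
    return tensor.unsqueeze(0)                       # add batch dimension

# Example: a 640x480 stand-in upload becomes a 1x3x32x32 batch.
batch = preprocess_image(Image.new("RGB", (640, 480), (120, 60, 200)))
```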

3. Prediction Process

When you submit an image for classification, the process follows the PyTorch tutorial:

```python
import torch
import torch.nn.functional as F

model.eval()                                  # evaluation mode (affects dropout/batchnorm)
with torch.no_grad():                         # skip gradient tracking for speed
    output = model(input_tensor)              # raw logits, shape (1, 10)
    probabilities = F.softmax(output, dim=1)  # logits -> probabilities
    probabilities = probabilities.numpy()[0]  # take the single batch item
```

This implementation:

  • Sets the model to evaluation mode with model.eval()
  • Disables gradient computation with torch.no_grad() for efficiency
  • Applies softmax to convert raw outputs to probabilities
  • Extracts the first (and only) batch result
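With stand-in logits in place of a real model call, the resulting probabilities can be paired with the class labels for display:

```python
import torch
import torch.nn.functional as F

cifar10_classes = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

# Stand-in logits; in the app, `output = model(input_tensor)` produces these.
output = torch.tensor([[0.1, 0.0, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
probabilities = F.softmax(output, dim=1).numpy()[0]  # shape (10,), sums to ~1

# Label -> probability mapping, as Gradio's Label output expects.
results = {label: float(p) for label, p in zip(cifar10_classes, probabilities)}
top_class = max(results, key=results.get)
```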

4. User Interface Features

The Gradio interface provides several ways to interact with the model:

  • Image Upload: Upload any image file from your computer
  • Drawing Tool: Draw an image directly in the browser
  • Example Images: Use pre-made examples representing each CIFAR-10 class
  • Real-time Results: See prediction probabilities for all 10 classes
  • Responsive Design: Works well on both desktop and mobile devices

Image Input Capabilities

Supported Image Formats

The application accepts all common image formats:

  • JPEG, PNG, BMP, TIFF, GIF, and WebP
  • Color images (maintained as RGB with 3 channels)
  • Images of any resolution (automatically resized to 32×32)

Robustness Features

The application tolerates a range of image conditions:

  • Resolution Independence: Accepts images of any size (resized to 32×32)
  • Color Preservation: Maintains RGB color information
  • Contrast Handling: Works with both high- and low-contrast images
  • Noise Tolerance: Can handle moderate image noise
  • Rotation Tolerance: Handles slight rotations
  • Scale Tolerance: Copes with objects at somewhat different sizes

Best Practices for Good Results

To get the best classification results:

  1. Center the object in the image area
  2. Use clear contrast between the object and background
  3. Fill most of the image area with the object
  4. Avoid excessive noise or artifacts
  5. Ensure the object is clearly visible

Image Preprocessing Pipeline

The complete preprocessing pipeline:

  1. Image upload or drawing
  2. Resize to 32×32 pixels using bilinear interpolation
  3. Conversion to PyTorch tensor with values scaled to [0,1]
  4. Addition of batch dimension for model inference

Technical Implementation Details

Custom CSS Styling

The application features a modern UI with:

  • Animated gradient background
  • Glass-morphism design elements
  • Responsive layout that adapts to different screen sizes
  • Interactive buttons with hover effects
  • Clean typography using Google Fonts

Error Handling

The application gracefully handles:

  • Missing model files (shows error message)
  • Empty inputs (returns zero probabilities)
  • Various image formats (maintained as RGB)

Performance Optimizations

  • Model loaded once at startup
  • Gradients disabled during inference
  • Efficient tensor operations
  • Caching of example predictions

Deployment to Hugging Face Spaces

To deploy this application to Hugging Face Spaces:

  1. Create a new Space with the "Gradio" SDK
  2. Upload all files from this directory
  3. Ensure your model.pth file is included
  4. The Space will automatically install dependencies from requirements.txt
  5. The application will start automatically

Customization Guide

Using a Different Model File

If your model is saved with a different filename:

  1. Modify the model_path variable in the load_model() function
  2. Ensure the model architecture matches the Net class definition exactly

Changing Class Labels

To customize the class labels:

  1. Modify the cifar10_classes list in the predict() function
  2. Update the example images in the create_example_images() function to match your new classes

Adjusting Image Preprocessing

To modify how images are preprocessed:

  1. Edit the preprocess_image() function
  2. Change the resize dimensions if your model expects a different input size
  3. Add normalization if your model was trained with normalized inputs

Troubleshooting Common Issues

Model Not Loading

  • Verify model.pth is in the same directory as app.py
  • Ensure the model architecture matches the Net class definition exactly
  • Check that the file is not corrupted

Poor Prediction Accuracy

  • Verify your model was trained on similar data (CIFAR-10 or similar)
  • Check if the preprocessing matches what was used during training
  • Ensure input images are similar to the training data

UI Display Issues

  • Update Gradio to the latest version
  • Check browser compatibility
  • Clear browser cache if styles aren't loading correctly

File Structure

cifar10-classifier/
├── app.py              # Main application file
├── requirements.txt    # Python dependencies
├── README.md           # User guide
├── EXPLANATION.md      # This file
├── model.pth           # Your trained model (to be added)
└── space.json          # Hugging Face Spaces configuration

Requirements Explanation

  • torch>=1.7.0: Core PyTorch library for neural network operations
  • torchvision>=0.8.0: Computer vision utilities, including image transforms
  • gradio>=4.0.0: Framework for creating machine learning web interfaces
  • pillow>=8.0.0: Python Imaging Library for image processing
  • numpy>=1.19.0: Numerical computing library for array operations

Example Use Cases

  1. Object Recognition: Classify images into 10 common object categories
  2. Educational Tool: Demonstrate how convolutional neural networks work on real image data
  3. Model Showcase: Present your trained model to others in an interactive way
  4. Testing Platform: Evaluate model performance on custom inputs

This application provides a complete solution for deploying a PyTorch model trained on CIFAR-10 with an attractive, user-friendly interface that can be easily shared with others through Hugging Face Spaces. The implementation is based on the PyTorch CIFAR-10 tutorial, ensuring compatibility with models trained using the same approach.