CIFAR-10 RNN Image Classifier

An end-to-end deep learning project for classifying CIFAR-10 images using a Recurrent Neural Network (LSTM) built with PyTorch. Includes a modern web interface for real-time image classification.

🌟 Features

  • Custom RNN Architecture: Bidirectional LSTM layers with dropout
  • Complete Training Pipeline: Automated training with validation, checkpointing, and visualization
  • Comprehensive Evaluation: Confusion matrix, classification reports, and prediction visualizations
  • Modern Web Interface: Beautiful Flask web app for real-time image classification
  • CIFAR-10 Dataset: Automatically downloads and preprocesses the dataset

📊 Model Architecture

The model treats each 32x32 RGB image as a sequence of 32 time steps, one per image row, where each row has 96 features (32 pixels × 3 channels).

Input (Batch, 3, 32, 32)
    ↓
Reshape (Batch, 32, 96)
    ↓
Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
    ↓
Last Time Step Output
    ↓
Fully Connected (512) → ReLU → Dropout(0.3)
    ↓
Output (10 classes)
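
The diagram above could be expressed in PyTorch roughly as follows. This is a minimal sketch, not the contents of model.py: the class name RowLSTMClassifier is hypothetical, and reading "Fully Connected (512)" as a 512-unit hidden layer is an assumption. Note the permute before the reshape; reshaping (Batch, 3, 32, 32) directly would mix channels and rows instead of yielding one image row per time step.

```python
import torch
import torch.nn as nn


class RowLSTMClassifier(nn.Module):
    """Hypothetical name; layer sizes follow the architecture diagram."""

    def __init__(self, num_classes=10, hidden_size=256, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=96,            # 32 pixels x 3 channels per row
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=0.2,              # applied between the stacked LSTM layers
            bidirectional=True,
            batch_first=True,
        )
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_size, 512),  # 2x: forward + backward states
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        batch = x.size(0)
        # (B, 3, 32, 32) -> (B, 32, 3, 32): put the row axis ahead of the
        # channel axis so each sequence step holds exactly one image row.
        rows = x.permute(0, 2, 1, 3).reshape(batch, 32, 96)
        out, _ = self.lstm(rows)        # (B, 32, 512)
        return self.head(out[:, -1])    # logits from the last time step


model = RowLSTMClassifier().eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])
```

With a bidirectional LSTM, the last time step carries the forward direction's summary of all 32 rows, while the backward direction has only seen the final row at that position.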

🚀 Quick Start

1. Install Dependencies

pip install -r requirements.txt

2. Train the Model

python train.py

This will:

  • Download the CIFAR-10 dataset automatically
  • Train the model for 50 epochs
  • Save checkpoints in ./checkpoints/
  • Generate training plots in ./plots/
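
The checkpointing step might look something like the sketch below, which keeps only the best model seen so far. The helper name save_if_best, the file name best_model.pth, and the dictionary keys are all assumptions for illustration; train.py may organize this differently. The validation accuracies in the demo are made up.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_if_best(model, optimizer, epoch, val_acc, best_acc, ckpt_dir):
    """Write a checkpoint when val_acc beats best_acc; return the new best."""
    if val_acc <= best_acc:
        return best_acc
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "val_acc": val_acc,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        os.path.join(ckpt_dir, "best_model.pth"),  # hypothetical file name
    )
    return val_acc


# Demo with a stand-in model and made-up validation accuracies.
model = nn.Linear(96, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
ckpt_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")

best = 0.0
for epoch, val_acc in enumerate([0.41, 0.48, 0.45], start=1):
    best = save_if_best(model, optimizer, epoch, val_acc, best, ckpt_dir)

state = torch.load(os.path.join(ckpt_dir, "best_model.pth"))
print(state["epoch"], state["val_acc"])  # 2 0.48
```

Saving the optimizer state alongside the weights makes it possible to resume training, at the cost of a larger file.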

3. Evaluate the Model

python evaluate.py

This will:

  • Load the best model checkpoint
  • Evaluate on the test set
  • Generate confusion matrix
  • Create prediction visualizations
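
For reference, a confusion matrix and per-class accuracy can be computed as in the NumPy sketch below. This is an illustration of the idea only; evaluate.py may well rely on scikit-learn instead, and the labels here are toy values rather than real model output.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=10):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # add.at accumulates correctly even when an index pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_accuracy(cm):
    """Fraction of each true class predicted correctly (recall)."""
    totals = cm.sum(axis=1)
    return np.divide(cm.diagonal(), totals,
                     out=np.zeros(len(cm)), where=totals > 0)


# Toy labels for illustration only, not real model output.
y_true = [3, 3, 5, 5, 5, 8]
y_pred = [3, 5, 5, 5, 3, 8]
cm = confusion_matrix(y_true, y_pred)
print(cm[3, 5], cm[5, 3], cm[5, 5])  # 1 1 2
print(per_class_accuracy(cm)[[3, 5, 8]])
```

Off-diagonal cells such as (3, 5) and (5, 3), which correspond to cat and dog in the usual CIFAR-10 label order, show which pairs of classes the model confuses.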

4. Run the Web Application

python app.py

Then open your browser and navigate to http://localhost:5000

📁 Project Structure

CNN/
├── config.py              # Configuration and hyperparameters
├── data_loader.py         # Data loading and preprocessing
├── model.py               # RNN (LSTM) model architecture
├── train.py               # Training script
├── evaluate.py            # Evaluation script
├── utils.py               # Utility functions
├── app.py                 # Flask web application
├── requirements.txt       # Python dependencies
├── templates/
│   └── index.html         # Web interface HTML
├── static/
│   ├── style.css          # Web interface CSS
│   └── script.js          # Web interface JavaScript
├── checkpoints/           # Model checkpoints (created during training)
├── plots/                 # Training visualizations (created during training)
└── data/                  # CIFAR-10 dataset (downloaded automatically)

🎯 CIFAR-10 Classes

The model classifies images into 10 categories:

  1. Airplane
  2. Automobile
  3. Bird
  4. Cat
  5. Deer
  6. Dog
  7. Frog
  8. Horse
  9. Ship
  10. Truck

⚙️ Configuration

Edit config.py to customize:

  • Training: epochs, batch size, learning rate
  • Model: number of classes, architecture parameters
  • Data: augmentation settings, normalization values
  • Paths: checkpoint and plot directories
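
As an illustration, the settings listed above could be grouped as in the sketch below. The Config class and its field names are hypothetical and may not match config.py; the values are the ones quoted elsewhere in this README. Augmentation and normalization values are left out because the README does not state them.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    # Training (values quoted elsewhere in this README)
    epochs: int = 50
    batch_size: int = 128
    learning_rate: float = 0.001
    momentum: float = 0.9
    weight_decay: float = 5e-4
    # Model
    num_classes: int = 10
    hidden_size: int = 256
    num_layers: int = 2
    lstm_dropout: float = 0.2
    fc_dropout: float = 0.3
    # Paths
    data_dir: str = "./data"
    checkpoint_dir: str = "./checkpoints"
    plot_dir: str = "./plots"


config = Config()
quick = Config(epochs=5, batch_size=64)   # override for a short trial run
print(quick.epochs, quick.learning_rate)  # 5 0.001
```

A frozen dataclass keeps the defaults in one place and prevents accidental mutation during a run.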

📈 Training Details

  • Optimizer: SGD with momentum (0.9) and weight decay (5e-4)
  • Loss Function: Cross-Entropy Loss
  • Learning Rate: 0.001 with step decay
  • Batch Size: 128
  • Data Augmentation: Random crop and horizontal flip
  • Regularization: Batch normalization and dropout
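
In PyTorch terms, the optimizer, loss, and learning-rate schedule listed above might be set up as follows. This is a sketch under assumptions: the README only says "step decay", so step_size=20 and gamma=0.1 are placeholders, the linear model stands in for the LSTM classifier, and the random batches stand in for CIFAR-10 data.

```python
import torch
import torch.nn as nn

model = nn.Linear(96, 10)  # stand-in for the LSTM classifier

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4
)
# step_size and gamma are assumptions; the README only says "step decay".
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

for epoch in range(50):
    # One random batch per epoch, just to exercise the update order.
    inputs = torch.randn(8, 96)
    targets = torch.randint(0, 10, (8,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()   # step the scheduler once per epoch, after the optimizer

print(scheduler.get_last_lr())
```

With these placeholder values the learning rate drops by a factor of 10 at epochs 20 and 40, ending near 1e-5 after 50 epochs.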

🎨 Web Interface Features

  • Drag & Drop: Upload images via drag-and-drop
  • Random Samples: Test with random CIFAR-10 images
  • Real-time Classification: Instant predictions with confidence scores
  • Top-5 Predictions: View probability distribution
  • Modern UI: Dark theme with smooth animations
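
The top-5 display can be derived from the model's raw outputs with a softmax followed by a sort. The pure-Python sketch below illustrates the idea; the function names are hypothetical, the logits are made up, and app.py may compute this differently (for example with PyTorch's own softmax and top-k functions).

```python
import math

# torchvision's CIFAR-10 label order (zero-based indices 0..9).
CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """Return the k most likely (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only.
example_logits = [0.1, -1.2, 2.0, 3.5, 0.4, 2.8, -0.3, 0.0, -2.0, -0.7]
for name, prob in top_k(example_logits):
    print(f"{name:<10} {prob:.1%}")
```

Subtracting the largest logit before exponentiating avoids overflow without changing the resulting probabilities.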

📊 Expected Performance

With the default configuration, the model typically achieves:

  • Training Accuracy: ~90%
  • Validation Accuracy: ~85%

🛠️ Requirements

  • Python 3.8+ (the minimum supported by PyTorch 2.0)
  • PyTorch 2.0+
  • torchvision
  • Flask
  • NumPy
  • Matplotlib
  • scikit-learn
  • Pillow
  • tqdm

📝 License

This project is open source and available for educational purposes.

🤝 Contributing

Feel free to fork this project and submit pull requests for improvements!

📧 Contact

For questions or feedback, please open an issue on the repository.


Built with ❤️ using PyTorch and Flask
