CIFAR-10 RNN Image Classifier
An end-to-end deep learning project for classifying CIFAR-10 images using a Recurrent Neural Network (LSTM) built with PyTorch. Includes a modern web interface for real-time image classification.
Features
- Custom RNN Architecture: Bidirectional LSTM layers with dropout
- Complete Training Pipeline: Automated training with validation, checkpointing, and visualization
- Comprehensive Evaluation: Confusion matrix, classification reports, and prediction visualizations
- Modern Web Interface: Beautiful Flask web app for real-time image classification
- CIFAR-10 Dataset: Automatically downloads and preprocesses the dataset
Model Architecture
The model treats each 32x32 RGB image as a sequence of 32 rows, where each row has 96 features (32 pixels * 3 channels).
Input (Batch, 3, 32, 32)
↓
Reshape (Batch, 32, 96)
↓
Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
↓
Last Time Step Output
↓
Fully Connected (512) → ReLU → Dropout(0.3)
↓
Output (10 classes)
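The pipeline above can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the contents of model.py: the class name RowLSTMClassifier is hypothetical, the permute-then-reshape step is one plausible way to turn each image row into a time step, and "Fully Connected (512)" is assumed to mean a 512-to-512 linear layer followed by the 10-class output layer.

```python
import torch
import torch.nn as nn


class RowLSTMClassifier(nn.Module):
    """Hypothetical name: reads each image as 32 rows of 96 features."""

    def __init__(self, input_size=96, hidden_size=256, num_layers=2,
                 num_classes=10, lstm_dropout=0.2, fc_dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,   # applied between the stacked LSTM layers
            bidirectional=True,
        )
        # A bidirectional LSTM concatenates both directions: 2 * 256 = 512.
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_size, 512),  # assumed layer width
            nn.ReLU(),
            nn.Dropout(fc_dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # (B, 3, 32, 32) -> (B, 32, 3, 32) -> (B, 32, 96): one row per time step.
        b, c, h, w = x.shape
        seq = x.permute(0, 2, 1, 3).reshape(b, h, c * w)
        out, _ = self.lstm(seq)          # (B, 32, 512)
        return self.head(out[:, -1, :])  # logits from the last time step


model = RowLSTMClassifier().eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)
```

Note that at the last time step the forward direction has seen all 32 rows, while the backward direction has seen only the final row; that is a property of the "last time step" design described above.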
Quick Start
1. Install Dependencies
pip install -r requirements.txt
2. Train the Model
python train.py
This will:
- Download the CIFAR-10 dataset automatically
- Train the model for 50 epochs
- Save checkpoints in ./checkpoints/
- Generate training plots in ./plots/
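As a rough illustration of the checkpointing step, the sketch below saves the model weights whenever validation accuracy improves. The helper name save_if_best, the file name best_model.pth, and the accuracy values are all placeholders; train.py may organize this differently.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_if_best(model, val_acc, best_acc, checkpoint_dir, filename="best_model.pth"):
    """Save the model weights when validation accuracy improves.

    Returns the (possibly updated) best accuracy. The file name is hypothetical.
    """
    if val_acc > best_acc:
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, filename))
        return val_acc
    return best_acc


# Tiny stand-in model and made-up accuracies, just to exercise the helper.
demo_model = nn.Linear(4, 2)
demo_dir = tempfile.mkdtemp()
best = 0.0
for epoch_val_acc in [0.40, 0.55, 0.50]:
    best = save_if_best(demo_model, epoch_val_acc, best, demo_dir)
print(best)
```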
3. Evaluate the Model
python evaluate.py
This will:
- Load the best model checkpoint
- Evaluate on the test set
- Generate confusion matrix
- Create prediction visualizations
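To make the confusion-matrix step concrete, here is a small NumPy-only sketch. It is an illustration under assumptions: evaluate.py may well rely on scikit-learn instead, and the toy labels below are not model output.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=10):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_accuracy(cm):
    """Diagonal over row sums; classes with no samples get NaN."""
    totals = cm.sum(axis=1).astype(float)
    totals[totals == 0] = np.nan
    return np.diag(cm) / totals


# Toy labels over 3 classes, only to show the shapes involved.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)
print(per_class_accuracy(cm))
```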
4. Run the Web Application
python app.py
Then open your browser and navigate to http://localhost:5000
Project Structure
CNN/
├── config.py             # Configuration and hyperparameters
├── data_loader.py        # Data loading and preprocessing
├── model.py              # RNN (LSTM) model architecture
├── train.py              # Training script
├── evaluate.py           # Evaluation script
├── utils.py              # Utility functions
├── app.py                # Flask web application
├── requirements.txt      # Python dependencies
├── templates/
│   └── index.html        # Web interface HTML
├── static/
│   ├── style.css         # Web interface CSS
│   └── script.js         # Web interface JavaScript
├── checkpoints/          # Model checkpoints (created during training)
├── plots/                # Training visualizations (created during training)
└── data/                 # CIFAR-10 dataset (downloaded automatically)
CIFAR-10 Classes
The model classifies images into 10 categories:
- Airplane
- Automobile
- Bird
- Cat
- Deer
- Dog
- Frog
- Horse
- Ship
- Truck
Configuration
Edit config.py to customize:
- Training: epochs, batch size, learning rate
- Model: number of classes, architecture parameters
- Data: augmentation settings, normalization values
- Paths: checkpoint and plot directories
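For orientation, a configuration module along these lines could look like the sketch below. The field names and the use of a dataclass are assumptions for illustration; the actual names in config.py may differ. The default values are the ones quoted elsewhere in this README.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    # Training (values quoted from this README)
    epochs: int = 50
    batch_size: int = 128
    learning_rate: float = 0.001
    momentum: float = 0.9
    weight_decay: float = 5e-4
    # Model
    num_classes: int = 10
    hidden_size: int = 256
    num_layers: int = 2
    # Paths
    checkpoint_dir: str = "./checkpoints"
    plot_dir: str = "./plots"


default = Config()
# Override a couple of fields for a quick experiment without editing defaults.
quick = Config(epochs=5, batch_size=64)
print(quick)
```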
Training Details
- Optimizer: SGD with momentum (0.9) and weight decay (5e-4)
- Loss Function: Cross-Entropy Loss
- Learning Rate: 0.001 with step decay
- Batch Size: 128
- Data Augmentation: Random crop and horizontal flip
- Regularization: Batch normalization and dropout
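The optimizer and schedule listed above might be wired up as in this sketch. The step_size and gamma values are assumptions (the list only says "step decay"), and the linear layer and random batches are stand-ins for the real model and data.

```python
import torch
import torch.nn as nn

model = nn.Linear(96, 10)  # stand-in for the real network

optimizer = torch.optim.SGD(
    model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4
)
# step_size and gamma are assumed; the README only says "step decay".
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
criterion = nn.CrossEntropyLoss()

lrs = []
for epoch in range(50):
    lrs.append(optimizer.param_groups[0]["lr"])
    # One dummy batch per epoch so optimizer.step() runs before scheduler.step().
    inputs, targets = torch.randn(8, 96), torch.randint(0, 10, (8,))
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()

print(lrs[0], lrs[20], lrs[40])
```

With these assumed settings the learning rate drops by a factor of 10 at epochs 20 and 40.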
Web Interface Features
- Drag & Drop: Upload images via drag-and-drop
- Random Samples: Test with random CIFAR-10 images
- Real-time Classification: Instant predictions with confidence scores
- Top-5 Predictions: View probability distribution
- Modern UI: Dark theme with smooth animations
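The confidence scores and top-5 list can be derived from the model's raw outputs with a softmax, as in this dependency-free sketch. The function names are hypothetical and the logits are placeholders rather than real model output; the class order follows the standard CIFAR-10 label indices.

```python
import math

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """Return the k most likely (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Placeholder logits, not real model output.
example_logits = [0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, -2.0, -0.7]
for name, prob in top_k(example_logits):
    print(f"{name:<10} {prob:.1%}")
```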
Expected Performance
With the default configuration, the model typically achieves:
- Training Accuracy: ~90%
- Validation Accuracy: ~85%
Requirements
- Python 3.7+
- PyTorch 2.0+
- torchvision
- Flask
- NumPy
- Matplotlib
- scikit-learn
- Pillow
- tqdm
License
This project is open source and available for educational purposes.
Contributing
Feel free to fork this project and submit pull requests for improvements!
Contact
For questions or feedback, please open an issue on the repository.
Built with ❤️ using PyTorch and Flask