---
license: mit
datasets:
- cifar10
metrics:
- accuracy
library_name: pytorch
tags:
- image-classification
- sequence-classification
---

# CIFAR-10 RNN Image Classifier

An end-to-end deep learning project that classifies CIFAR-10 images with a recurrent neural network (a bidirectional LSTM) built in PyTorch. It includes a Flask web interface for real-time image classification.

## Features

- **Custom RNN Architecture**: Bidirectional LSTM layers with dropout
- **Complete Training Pipeline**: Automated training with validation, checkpointing, and visualization
- **Comprehensive Evaluation**: Confusion matrix, classification reports, and prediction visualizations
- **Modern Web Interface**: Beautiful Flask web app for real-time image classification
- **CIFAR-10 Dataset**: Automatically downloads and preprocesses the dataset

## Model Architecture

The model treats each 32x32 RGB image as a sequence of 32 rows, where each row has 96 features (32 pixels × 3 channels).

```
Input (Batch, 3, 32, 32)
    ↓
Reshape (Batch, 32, 96)
    ↓
Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
    ↓
Last Time Step Output
    ↓
Fully Connected (512) → ReLU → Dropout(0.3)
    ↓
Output (10 classes)
```

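The diagram above can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the code in `model.py`: the class name `RowLSTMClassifier` is hypothetical, the `permute` order used to build the 96-feature rows is an assumption, and "Fully Connected (512)" is read here as a hidden layer with 512 units.

```python
import torch
from torch import nn


class RowLSTMClassifier(nn.Module):
    """Bidirectional LSTM that reads a 32x32 RGB image one row at a time."""

    def __init__(self, num_classes=10, hidden_size=256, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=96,          # 32 pixels * 3 channels per row
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=0.2,            # applied between stacked LSTM layers
            bidirectional=True,
            batch_first=True,
        )
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_size, 512),  # assumed width of the FC layer
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, images):
        batch = images.size(0)
        # (B, 3, 32, 32) -> (B, 32, 32, 3) -> (B, 32, 96): one time step per row.
        rows = images.permute(0, 2, 3, 1).reshape(batch, 32, 96)
        outputs, _ = self.lstm(rows)          # (B, 32, 2 * hidden_size)
        return self.head(outputs[:, -1, :])   # logits from the last time step


model = RowLSTMClassifier()
logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])
```

Note that calling `view(batch, 32, 96)` directly on the `(B, 3, 32, 32)` tensor would not give one image row per time step, which is why the sketch moves the channel axis last before reshaping.
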
## Quick Start

### 1. Install Dependencies

```bash
pip install -r requirements.txt
```

### 2. Train the Model

```bash
python train.py
```

This will:
- Download the CIFAR-10 dataset automatically
- Train the model for 50 epochs
- Save checkpoints in `./checkpoints/`
- Generate training plots in `./plots/`

### 3. Evaluate the Model

```bash
python evaluate.py
```

This will:
- Load the best model checkpoint
- Evaluate on the test set
- Generate a confusion matrix
- Create prediction visualizations

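To make the confusion-matrix step concrete, here is a small dependency-free sketch of how such a matrix and the per-class accuracy can be computed from labels and predictions. It is an illustration only; `evaluate.py` may well rely on scikit-learn (listed in the requirements) instead, and the function names below are hypothetical.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Return a matrix where entry [t][p] counts samples of class t predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_accuracy(matrix):
    """Diagonal count divided by row total (recall) for each class."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


# Toy labels for three classes; real usage would pass test-set labels and predictions.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)                      # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_accuracy(cm))  # [0.5, 1.0, 0.5]
```
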
### 4. Run the Web Application

```bash
python app.py
```

Then open your browser and navigate to `http://localhost:5000`.

## Project Structure

```
CNN/
├── config.py              # Configuration and hyperparameters
├── data_loader.py         # Data loading and preprocessing
├── model.py               # RNN (LSTM) model architecture
├── train.py               # Training script
├── evaluate.py            # Evaluation script
├── utils.py               # Utility functions
├── app.py                 # Flask web application
├── requirements.txt       # Python dependencies
├── templates/
│   └── index.html         # Web interface HTML
├── static/
│   ├── style.css          # Web interface CSS
│   └── script.js          # Web interface JavaScript
├── checkpoints/           # Model checkpoints (created during training)
├── plots/                 # Training visualizations (created during training)
└── data/                  # CIFAR-10 dataset (downloaded automatically)
```

## 🎯 CIFAR-10 Classes

The model classifies images into 10 categories:
1. Airplane
2. Automobile
3. Bird
4. Cat
5. Deer
6. Dog
7. Frog
8. Horse
9. Ship
10. Truck

## ⚙️ Configuration

Edit `config.py` to customize:
- **Training**: epochs, batch size, learning rate
- **Model**: number of classes, architecture parameters
- **Data**: augmentation settings, normalization values
- **Paths**: checkpoint and plot directories

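As a rough picture of what such a configuration could look like, the sketch below groups the settings mentioned in this README into a dataclass. The field names and the `Config` class are assumptions for illustration; the actual `config.py` may be organized differently. The default values are the ones stated elsewhere in this document.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Training
    epochs: int = 50
    batch_size: int = 128
    learning_rate: float = 0.001
    momentum: float = 0.9
    weight_decay: float = 5e-4
    # Model
    num_classes: int = 10
    input_size: int = 96        # 32 pixels * 3 channels per image row
    hidden_size: int = 256
    num_layers: int = 2
    lstm_dropout: float = 0.2
    fc_dropout: float = 0.3
    # Data (augmentation switches; names are hypothetical)
    random_crop: bool = True
    horizontal_flip: bool = True
    # Paths
    data_dir: str = "./data"
    checkpoint_dir: str = "./checkpoints"
    plot_dir: str = "./plots"


config = Config()
quick = Config(epochs=5, batch_size=64)  # override defaults for a short trial run
print(quick.epochs, quick.batch_size, quick.learning_rate)  # 5 64 0.001
```
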
## Training Details

- **Optimizer**: SGD with momentum (0.9) and weight decay (5e-4)
- **Loss Function**: Cross-Entropy Loss
- **Learning Rate**: 0.001 with step decay
- **Batch Size**: 128
- **Data Augmentation**: Random crop and horizontal flip
- **Regularization**: Batch normalization and dropout

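The optimizer and learning-rate settings listed above map onto standard PyTorch components roughly as shown below. This is a sketch under assumptions: the README does not say how often or by how much the learning rate decays, so `step_size=20` and `gamma=0.1` are placeholder values, and a single `nn.Linear` layer with one random batch stands in for the real model and data loader.

```python
import torch
from torch import nn

model = nn.Linear(96, 10)  # stand-in for the real classifier
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4
)
# step_size and gamma are assumed values; the README only says "step decay".
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

inputs = torch.randn(128, 96)            # one random batch of 128 samples
targets = torch.randint(0, 10, (128,))   # random class labels in [0, 10)

for epoch in range(50):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()                     # decay the learning rate once per epoch

print(scheduler.get_last_lr())
```

With these placeholder values the learning rate is multiplied by 0.1 after epochs 20 and 40, so it ends two orders of magnitude below the initial 0.001.
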
## 🎨 Web Interface Features

- **Drag & Drop**: Upload images via drag-and-drop
- **Random Samples**: Test with random CIFAR-10 images
- **Real-time Classification**: Instant predictions with confidence scores
- **Top-5 Predictions**: View probability distribution
- **Modern UI**: Dark theme with smooth animations

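The confidence scores and top-5 list can be derived from the model's raw outputs with a softmax followed by a sort. The sketch below shows the idea in plain Python; the helper names `softmax` and `top_k` and the example logits are made up for illustration, and `app.py` may compute this differently (for instance with `torch.softmax` and `torch.topk`).

```python
import math

CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Convert raw scores to probabilities, subtracting the max for stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """Return the k most likely (class name, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASS_NAMES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration; a real request would use the model's output.
example_logits = [0.1, 0.3, 1.2, 4.0, 0.8, 2.5, 0.2, 0.9, 0.0, 0.4]
for name, prob in top_k(example_logits):
    print(f"{name}: {prob:.1%}")
```
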
## Expected Performance

With the default configuration, the model typically achieves:
- **Training Accuracy**: ~90%
- **Validation Accuracy**: ~85%

## 🛠️ Requirements

- Python 3.7+
- PyTorch 2.0+
- torchvision
- Flask
- NumPy
- Matplotlib
- scikit-learn
- Pillow
- tqdm

## License

This project is open source under the MIT license declared in the model card metadata and is available for educational purposes.

## 🤝 Contributing

Feel free to fork this project and submit pull requests for improvements!

## 📧 Contact

For questions or feedback, please open an issue on the repository.

---

**Built with ❤️ using PyTorch and Flask**