---
license: mit
datasets:
- cifar10
metrics:
- accuracy
library_name: pytorch
tags:
- image-classification
- sequence-classification
---
# CIFAR-10 RNN Image Classifier
An end-to-end deep learning project for classifying CIFAR-10 images using a Recurrent Neural Network (LSTM) built with PyTorch. Includes a modern web interface for real-time image classification.
## 🌟 Features
- **Custom RNN Architecture**: Bidirectional LSTM layers with dropout
- **Complete Training Pipeline**: Automated training with validation, checkpointing, and visualization
- **Comprehensive Evaluation**: Confusion matrix, classification reports, and prediction visualizations
- **Modern Web Interface**: Beautiful Flask web app for real-time image classification
- **CIFAR-10 Dataset**: Automatically downloads and preprocesses the dataset
## 📊 Model Architecture
The model treats each 32×32 RGB image as a sequence of 32 rows (time steps), where each row has 96 features (32 pixels × 3 channels). The LSTM therefore reads the image from top to bottom, one row per step.
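A minimal pure-Python sketch of this row-sequence conversion, assuming a channels-first `(3, 32, 32)` input as produced by torchvision's `ToTensor`. For each row's 96 features to actually be 32 pixels × 3 channels, the channel axis has to be moved last before flattening (in PyTorch, `x.permute(0, 2, 3, 1).reshape(B, 32, 96)` on a batch). A bare `reshape` of the channels-first tensor would instead pack three consecutive rows of a single channel into each time step. The function name `image_to_row_sequence` is hypothetical and not taken from `model.py`.

```python
from typing import List

Image = List[List[List[float]]]  # indexed as image[channel][row][col]


def image_to_row_sequence(image: Image) -> List[List[float]]:
    """Turn a channels-first (C, H, W) image into H time steps of W * C features.

    Matches x.permute(1, 2, 0).reshape(H, W * C) for a single image tensor:
    feature index w * C + c holds pixel (row, w) of channel c.
    """
    channels = len(image)
    height = len(image[0])
    width = len(image[0][0])
    sequence = []
    for row in range(height):
        features = []
        for col in range(width):
            for ch in range(channels):
                features.append(image[ch][row][col])
        sequence.append(features)
    return sequence


# Synthetic 3x32x32 image whose values encode their own (channel, row, col) position.
img = [[[c * 10000 + h * 100 + w for w in range(32)] for h in range(32)] for c in range(3)]
seq = image_to_row_sequence(img)
print(len(seq), len(seq[0]))  # 32 time steps, 96 features each
```

Interleaving the RGB values pixel by pixel keeps each time step spatially coherent: step `t` contains exactly image row `t`.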
```
Input (Batch, 3, 32, 32)
↓
Reshape (Batch, 32, 96)
↓
Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
↓
Last Time Step Output
↓
Fully Connected (512) → ReLU → Dropout(0.3)
↓
Output (10 classes)
```
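As a sanity check on the diagram, the parameter count can be worked out by hand. This sketch assumes PyTorch's `nn.LSTM` parameterisation (per layer and direction: an input-to-hidden matrix, a hidden-to-hidden matrix and two bias vectors, each covering 4 gates), that the second LSTM layer receives the 512-dimensional concatenated output of both directions, and that "Fully Connected (512)" means a 512 → 512 linear layer followed by a 512 → 10 output layer. The figures are derived from the diagram, not read from `model.py`.

```python
def lstm_layer_params(input_size: int, hidden_size: int) -> int:
    """Parameters of one direction of one nn.LSTM layer (4 gates, two bias vectors)."""
    gates = 4 * hidden_size
    return gates * input_size + gates * hidden_size + 2 * gates


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus bias of an nn.Linear layer."""
    return in_features * out_features + out_features


def count_parameters(input_size: int = 96, hidden: int = 256, layers: int = 2,
                     bidirectional: bool = True, fc_hidden: int = 512,
                     num_classes: int = 10) -> int:
    directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(layers):
        total += directions * lstm_layer_params(layer_input, hidden)
        # Deeper layers see the concatenated outputs of every direction.
        layer_input = directions * hidden
    # Classifier head: last-time-step features -> FC(512) -> 10 classes.
    total += linear_params(directions * hidden, fc_hidden)
    total += linear_params(fc_hidden, num_classes)
    return total


print(count_parameters())  # 2569738
```

With a real model, `sum(p.numel() for p in model.parameters())` should report the same total if the assumptions above match the actual implementation.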
## 🚀 Quick Start
### 1. Install Dependencies
```bash
pip install -r requirements.txt
```
### 2. Train the Model
```bash
python train.py
```
This will:
- Download the CIFAR-10 dataset automatically
- Train the model for 50 epochs
- Save checkpoints in `./checkpoints/`
- Generate training plots in `./plots/`
### 3. Evaluate the Model
```bash
python evaluate.py
```
This will:
- Load the best model checkpoint
- Evaluate on the test set
- Generate confusion matrix
- Create prediction visualizations
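The confusion matrix tabulates true labels (rows) against predicted labels (columns); overall accuracy is the diagonal sum over the total, and per-class recall is each diagonal entry over its row sum. A minimal stdlib sketch of that computation follows. The helper names are hypothetical; scikit-learn's `sklearn.metrics.confusion_matrix` uses the same rows-true, columns-predicted convention.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int],
                     num_classes: int = 10) -> List[List[int]]:
    """matrix[t][p] counts samples whose true class is t and predicted class is p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix: List[List[int]]) -> List[float]:
    """Fraction of each true class that was predicted correctly (diagonal / row sum)."""
    recalls = []
    for i, row in enumerate(matrix):
        total = sum(row)
        recalls.append(row[i] / total if total else 0.0)
    return recalls


def accuracy(matrix: List[List[int]]) -> float:
    """Correct predictions (the diagonal) divided by all predictions."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total if total else 0.0


# Toy example with 3 classes: one sample of class 2 is mistaken for class 1.
cm = confusion_matrix([0, 1, 2, 2], [0, 1, 1, 2], num_classes=3)
print(cm)                    # [[1, 0, 0], [0, 1, 0], [0, 1, 1]]
print(per_class_recall(cm))  # [1.0, 1.0, 0.5]
print(accuracy(cm))          # 0.75
```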
### 4. Run the Web Application
```bash
python app.py
```
Then open your browser and navigate to `http://localhost:5000`.
## 📁 Project Structure
```
CNN/
├── config.py           # Configuration and hyperparameters
├── data_loader.py      # Data loading and preprocessing
├── model.py            # RNN (LSTM) model architecture
├── train.py            # Training script
├── evaluate.py         # Evaluation script
├── utils.py            # Utility functions
├── app.py              # Flask web application
├── requirements.txt    # Python dependencies
├── templates/
│   └── index.html      # Web interface HTML
├── static/
│   ├── style.css       # Web interface CSS
│   └── script.js       # Web interface JavaScript
├── checkpoints/        # Model checkpoints (created during training)
├── plots/              # Training visualizations (created during training)
└── data/               # CIFAR-10 dataset (downloaded automatically)
```
## 🎯 CIFAR-10 Classes
The model classifies images into 10 categories:
1. Airplane
2. Automobile
3. Bird
4. Cat
5. Deer
6. Dog
7. Frog
8. Horse
9. Ship
10. Truck
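The list above is numbered from 1 for readability, but the model's 10 outputs are indexed from 0: assuming labels come from `torchvision.datasets.CIFAR10`, whose `classes` attribute uses this same order, output 0 is Airplane and output 9 is Truck. A small sketch of mapping the highest-scoring output to a class name; `predicted_class` is a hypothetical helper.

```python
from typing import Sequence

# Same order as torchvision.datasets.CIFAR10(...).classes; index i is label i.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def predicted_class(logits: Sequence[float]) -> str:
    """Return the class name for the highest-scoring output unit."""
    if len(logits) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} logits, got {len(logits)}")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return CIFAR10_CLASSES[best]


print(predicted_class([0.1, 0.0, 0.2, 3.5, 0.4, 1.2, 0.0, 0.3, 0.1, 0.0]))  # cat
```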
## ⚙️ Configuration
Edit `config.py` to customize:
- **Training**: epochs, batch size, learning rate
- **Model**: number of classes, architecture parameters
- **Data**: augmentation settings, normalization values
- **Paths**: checkpoint and plot directories
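One way such settings could be grouped is sketched below as nested dataclasses. This is an assumption about layout only: the contents of `config.py` are not shown in this README, so every field and class name here is hypothetical. The default values are the ones stated elsewhere in this README; normalization values and step-decay parameters are omitted because they are not given.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class TrainingConfig:
    # Values mirror the "Training Details" section; names are illustrative.
    epochs: int = 50
    batch_size: int = 128
    learning_rate: float = 0.001
    momentum: float = 0.9
    weight_decay: float = 5e-4


@dataclass
class ModelConfig:
    # Values mirror the architecture diagram.
    num_classes: int = 10
    input_size: int = 96          # 32 pixels x 3 channels per row
    hidden_size: int = 256
    num_layers: int = 2
    bidirectional: bool = True
    lstm_dropout: float = 0.2
    fc_size: int = 512
    fc_dropout: float = 0.3


@dataclass
class DataConfig:
    # Augmentations listed under "Training Details".
    random_crop: bool = True
    horizontal_flip: bool = True


@dataclass
class PathConfig:
    checkpoint_dir: str = "./checkpoints"
    plot_dir: str = "./plots"
    data_dir: str = "./data"


@dataclass
class Config:
    # default_factory gives each Config its own nested instances.
    training: TrainingConfig = field(default_factory=TrainingConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    paths: PathConfig = field(default_factory=PathConfig)


# Override a single setting (e.g. for a quick run) without touching the rest.
cfg = Config(training=TrainingConfig(epochs=10))
print(asdict(cfg)["training"])
```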
## 📈 Training Details
- **Optimizer**: SGD with momentum (0.9) and weight decay (5e-4)
- **Loss Function**: Cross-Entropy Loss
- **Learning Rate**: 0.001 with step decay
- **Batch Size**: 128
- **Data Augmentation**: Random crop and horizontal flip
- **Regularization**: Batch normalization and dropout
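The optimizer settings above correspond to `torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=5e-4)`, and step decay is the rule implemented by `torch.optim.lr_scheduler.StepLR`. The sketch below spells out both rules on a single scalar parameter, following PyTorch's documented SGD update with zero dampening and no Nesterov momentum. The decay interval (`step_size`) and factor (`gamma`) are not stated in this README, so the values used here are placeholders.

```python
from typing import Tuple


def sgd_momentum_step(param: float, grad: float, buf: float, lr: float = 0.001,
                      momentum: float = 0.9,
                      weight_decay: float = 5e-4) -> Tuple[float, float]:
    """One SGD update in the style of torch.optim.SGD (dampening=0, nesterov=False).

    Returns the updated parameter and momentum buffer.
    """
    g = grad + weight_decay * param   # L2 penalty folded into the gradient
    buf = momentum * buf + g          # buffer starts at 0, so the first step uses g
    return param - lr * buf, buf


def step_decay_lr(base_lr: float, epoch: int, step_size: int, gamma: float) -> float:
    """Closed form of StepLR: multiply the rate by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Placeholder schedule: step_size and gamma are assumptions, not project values.
STEP_SIZE, GAMMA = 20, 0.1
for epoch in (0, 19, 20, 40, 49):
    print(epoch, step_decay_lr(0.001, epoch, STEP_SIZE, GAMMA))

# Three updates of one scalar weight under a constant gradient of 0.5.
p, b = 1.0, 0.0
for _ in range(3):
    p, b = sgd_momentum_step(p, grad=0.5, buf=b)
print(round(p, 6))
```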
## 🎨 Web Interface Features
- **Drag & Drop**: Upload images via drag-and-drop
- **Random Samples**: Test with random CIFAR-10 images
- **Real-time Classification**: Instant predictions with confidence scores
- **Top-5 Predictions**: View probability distribution
- **Modern UI**: Dark theme with smooth animations
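Top-5 confidence scores like those the interface displays are commonly obtained by applying a softmax to the model's 10 logits and keeping the five largest probabilities (in PyTorch, `torch.softmax` followed by `torch.topk`). A stdlib sketch, assuming the class order listed under CIFAR-10 Classes; `top_k_predictions` is a hypothetical helper, not a function from `app.py`.

```python
import math
from typing import List, Sequence, Tuple

CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_predictions(logits: Sequence[float], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs, highest first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(CLASSES[i], probs[i]) for i in ranked[:k]]


logits = [0.5, -1.0, 1.5, 4.0, 0.2, 2.5, -0.3, 0.9, -2.0, 0.0]
for name, prob in top_k_predictions(logits):
    print(f"{name:<10} {prob:6.1%}")
```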
## 📊 Expected Performance
With the default configuration, the model typically achieves:
- **Training Accuracy**: ~90%
- **Validation Accuracy**: ~85%
## 🛠️ Requirements
- Python 3.8+ (PyTorch 2.0 requires Python 3.8 or newer)
- PyTorch 2.0+
- torchvision
- Flask
- NumPy
- Matplotlib
- scikit-learn
- Pillow
- tqdm
## 📝 License
This project is open source under the MIT License and is intended for educational purposes.
## 🀝 Contributing
Feel free to fork this project and submit pull requests for improvements!
## 📧 Contact
For questions or feedback, please open an issue on the repository.
---
**Built with ❤️ using PyTorch and Flask**