Upload folder using huggingface_hub
Browse files- .gitignore +63 -0
- README.md +173 -0
- app.py +166 -0
- config.py +50 -0
- data_loader.py +100 -0
- debug_train.py +88 -0
- evaluate.py +96 -0
- model.py +93 -0
- requirements.txt +8 -0
- static/script.js +163 -0
- static/style.css +391 -0
- templates/index.html +103 -0
- train.py +220 -0
- train_log.txt +0 -0
- utils.py +186 -0
.gitignore
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual Environment
|
| 24 |
+
venv/
|
| 25 |
+
ENV/
|
| 26 |
+
env/
|
| 27 |
+
.venv
|
| 28 |
+
|
| 29 |
+
# PyTorch
|
| 30 |
+
*.pth
|
| 31 |
+
*.pt
|
| 32 |
+
checkpoints/
|
| 33 |
+
!checkpoints/.gitkeep
|
| 34 |
+
|
| 35 |
+
# Data
|
| 36 |
+
data/
|
| 37 |
+
*.pkl
|
| 38 |
+
*.pickle
|
| 39 |
+
|
| 40 |
+
# Plots and visualizations
|
| 41 |
+
plots/
|
| 42 |
+
!plots/.gitkeep
|
| 43 |
+
|
| 44 |
+
# IDE
|
| 45 |
+
.vscode/
|
| 46 |
+
.idea/
|
| 47 |
+
*.swp
|
| 48 |
+
*.swo
|
| 49 |
+
*~
|
| 50 |
+
|
| 51 |
+
# OS
|
| 52 |
+
.DS_Store
|
| 53 |
+
Thumbs.db
|
| 54 |
+
|
| 55 |
+
# Jupyter Notebook
|
| 56 |
+
.ipynb_checkpoints
|
| 57 |
+
|
| 58 |
+
# Flask
|
| 59 |
+
instance/
|
| 60 |
+
.webassets-cache
|
| 61 |
+
|
| 62 |
+
# Logs
|
| 63 |
+
*.log
|
README.md
CHANGED
|
@@ -1,3 +1,176 @@
|
|
| 1 |
---
|
| 2 |
license: mit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: mit
|
| 3 |
+
datasets:
|
| 4 |
+
- cifar10
|
| 5 |
+
metrics:
|
| 6 |
+
- accuracy
|
| 7 |
+
library_name: pytorch
|
| 8 |
+
tags:
|
| 9 |
+
- image-classification
|
| 10 |
+
- sequence-classification
|
| 11 |
---
|
| 12 |
+
|
| 13 |
+
# CIFAR-10 RNN Image Classifier
|
| 14 |
+
|
| 15 |
+
An end-to-end deep learning project for classifying CIFAR-10 images using a Recurrent Neural Network (LSTM) built with PyTorch. Includes a modern web interface for real-time image classification.
|
| 16 |
+
|
| 17 |
+
## 🌟 Features
|
| 18 |
+
|
| 19 |
+
- **Custom RNN Architecture**: Bidirectional LSTM layers with dropout
|
| 20 |
+
- **Complete Training Pipeline**: Automated training with validation, checkpointing, and visualization
|
| 21 |
+
- **Comprehensive Evaluation**: Confusion matrix, classification reports, and prediction visualizations
|
| 22 |
+
- **Modern Web Interface**: Beautiful Flask web app for real-time image classification
|
| 23 |
+
- **CIFAR-10 Dataset**: Automatically downloads and preprocesses the dataset
|
| 24 |
+
|
| 25 |
+
## 📊 Model Architecture
|
| 26 |
+
|
| 27 |
+
The model treats each 32x32 RGB image as a sequence of 32 rows, where each row has 96 features (32 pixels * 3 channels).
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
Input (Batch, 3, 32, 32)
|
| 31 |
+
↓
|
| 32 |
+
Reshape (Batch, 32, 96)
|
| 33 |
+
↓
|
| 34 |
+
Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
|
| 35 |
+
↓
|
| 36 |
+
Last Time Step Output
|
| 37 |
+
↓
|
| 38 |
+
Fully Connected (512) → ReLU → Dropout(0.3)
|
| 39 |
+
↓
|
| 40 |
+
Output (10 classes)
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## 🚀 Quick Start
|
| 44 |
+
|
| 45 |
+
### 1. Install Dependencies
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
pip install -r requirements.txt
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### 2. Train the Model
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
python train.py
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
This will:
|
| 58 |
+
- Download the CIFAR-10 dataset automatically
|
| 59 |
+
- Train the model for the number of epochs set in `config.py` (default: 5)
|
| 60 |
+
- Save checkpoints in `./checkpoints/`
|
| 61 |
+
- Generate training plots in `./plots/`
|
| 62 |
+
|
| 63 |
+
### 3. Evaluate the Model
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
python evaluate.py
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
This will:
|
| 70 |
+
- Load the best model checkpoint
|
| 71 |
+
- Evaluate on the test set
|
| 72 |
+
- Generate confusion matrix
|
| 73 |
+
- Create prediction visualizations
|
| 74 |
+
|
| 75 |
+
### 4. Run the Web Application
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
python app.py
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
Then open your browser and navigate to `http://localhost:5000`
|
| 82 |
+
|
| 83 |
+
## 📁 Project Structure
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
CNN/
|
| 87 |
+
├── config.py # Configuration and hyperparameters
|
| 88 |
+
├── data_loader.py # Data loading and preprocessing
|
| 89 |
+
├── model.py              # RNN (LSTM) model architecture
|
| 90 |
+
├── train.py # Training script
|
| 91 |
+
├── evaluate.py # Evaluation script
|
| 92 |
+
├── utils.py # Utility functions
|
| 93 |
+
├── app.py # Flask web application
|
| 94 |
+
├── requirements.txt # Python dependencies
|
| 95 |
+
├── templates/
|
| 96 |
+
│ └── index.html # Web interface HTML
|
| 97 |
+
├── static/
|
| 98 |
+
│ ├── style.css # Web interface CSS
|
| 99 |
+
│ └── script.js # Web interface JavaScript
|
| 100 |
+
├── checkpoints/ # Model checkpoints (created during training)
|
| 101 |
+
├── plots/ # Training visualizations (created during training)
|
| 102 |
+
└── data/ # CIFAR-10 dataset (downloaded automatically)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## 🎯 CIFAR-10 Classes
|
| 106 |
+
|
| 107 |
+
The model classifies images into 10 categories:
|
| 108 |
+
1. Airplane
|
| 109 |
+
2. Automobile
|
| 110 |
+
3. Bird
|
| 111 |
+
4. Cat
|
| 112 |
+
5. Deer
|
| 113 |
+
6. Dog
|
| 114 |
+
7. Frog
|
| 115 |
+
8. Horse
|
| 116 |
+
9. Ship
|
| 117 |
+
10. Truck
|
| 118 |
+
|
| 119 |
+
## ⚙️ Configuration
|
| 120 |
+
|
| 121 |
+
Edit `config.py` to customize:
|
| 122 |
+
- **Training**: epochs, batch size, learning rate
|
| 123 |
+
- **Model**: number of classes, architecture parameters
|
| 124 |
+
- **Data**: augmentation settings, normalization values
|
| 125 |
+
- **Paths**: checkpoint and plot directories
|
| 126 |
+
|
| 127 |
+
## 📈 Training Details
|
| 128 |
+
|
| 129 |
+
- **Optimizer**: SGD with momentum (0.9) and weight decay (5e-4)
|
| 130 |
+
- **Loss Function**: Cross-Entropy Loss
|
| 131 |
+
- **Learning Rate**: 0.001 with step decay
|
| 132 |
+
- **Batch Size**: 128
|
| 133 |
+
- **Data Augmentation**: Random crop and horizontal flip
|
| 134 |
+
- **Regularization**: Batch normalization and dropout
|
| 135 |
+
|
| 136 |
+
## 🎨 Web Interface Features
|
| 137 |
+
|
| 138 |
+
- **Drag & Drop**: Upload images via drag-and-drop
|
| 139 |
+
- **Random Samples**: Test with random CIFAR-10 images
|
| 140 |
+
- **Real-time Classification**: Instant predictions with confidence scores
|
| 141 |
+
- **Top-5 Predictions**: View probability distribution
|
| 142 |
+
- **Modern UI**: Dark theme with smooth animations
|
| 143 |
+
|
| 144 |
+
## 📊 Expected Performance
|
| 145 |
+
|
| 146 |
+
With the default configuration, the model typically achieves:
|
| 147 |
+
- **Training Accuracy**: ~90%
|
| 148 |
+
- **Validation Accuracy**: ~85%
|
| 149 |
+
|
| 150 |
+
## 🛠️ Requirements
|
| 151 |
+
|
| 152 |
+
- Python 3.7+
|
| 153 |
+
- PyTorch 2.0+
|
| 154 |
+
- torchvision
|
| 155 |
+
- Flask
|
| 156 |
+
- NumPy
|
| 157 |
+
- Matplotlib
|
| 158 |
+
- scikit-learn
|
| 159 |
+
- Pillow
|
| 160 |
+
- tqdm
|
| 161 |
+
|
| 162 |
+
## 📝 License
|
| 163 |
+
|
| 164 |
+
This project is open source and available for educational purposes.
|
| 165 |
+
|
| 166 |
+
## 🤝 Contributing
|
| 167 |
+
|
| 168 |
+
Feel free to fork this project and submit pull requests for improvements!
|
| 169 |
+
|
| 170 |
+
## 📧 Contact
|
| 171 |
+
|
| 172 |
+
For questions or feedback, please open an issue on the repository.
|
| 173 |
+
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
**Built with ❤️ using PyTorch and Flask**
|
app.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flask web application for CIFAR-10 image classification
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import io
|
| 6 |
+
import base64
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from flask import Flask, render_template, request, jsonify
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import config
|
| 14 |
+
from model import get_model
|
| 15 |
+
from utils import load_checkpoint
|
| 16 |
+
|
| 17 |
+
app = Flask(__name__)

# Global model variable.
# Populated once by load_model() before the server starts; stays None when the
# checkpoint is missing, in which case /predict responds with an error.
model = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_model():
    """Load the best saved checkpoint into the module-level ``model``.

    Returns:
        bool: True when the checkpoint exists and was loaded, else False.
    """
    global model

    checkpoint_path = config.BEST_MODEL_PATH
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Model checkpoint not found at {checkpoint_path}")
        return False

    model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
    # Inference only — no optimizer state is restored (hence the None argument).
    epoch, accuracy = load_checkpoint(model, None, checkpoint_path)
    model.eval()

    print(f"Model loaded from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
    return True
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def preprocess_image(image):
    """Convert a PIL image into a normalized CIFAR-10 input batch.

    Args:
        image: PIL Image (any size; resized to 32x32).

    Returns:
        torch.Tensor: Tensor of shape (1, 3, 32, 32) ready for the model.
    """
    # Same normalization statistics used at training time (see data_loader.py).
    pipeline = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        ),
    ])

    tensor = pipeline(image)
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@app.route('/')
def index():
    """Serve the landing page, passing the CIFAR-10 class names to the template."""
    return render_template('index.html', class_names=config.CLASS_NAMES)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the top-5 class probabilities.

    Expects a multipart form upload under the key ``file``. Responds with JSON
    containing the predicted class, its confidence (percent) and the five most
    likely classes; HTTP 400/500 with an ``error`` field on failure.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        # Decode the upload and build a normalized (1, 3, 32, 32) batch.
        img = Image.open(upload.stream).convert('RGB')
        batch = preprocess_image(img).to(config.DEVICE)

        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(batch)
            probs = torch.nn.functional.softmax(logits[0], dim=0)
            confidence, predicted = torch.max(probs, 0)

        # Collect the five highest-probability classes.
        top5_prob, top5_idx = torch.topk(probs, 5)
        top5_predictions = []
        for idx, prob in zip(top5_idx.cpu().numpy(), top5_prob.cpu().numpy()):
            top5_predictions.append({
                'class': config.CLASS_NAMES[idx],
                'probability': float(prob * 100),
            })

        return jsonify({
            'predicted_class': config.CLASS_NAMES[predicted.item()],
            'confidence': float(confidence.item() * 100),
            'top5_predictions': top5_predictions,
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@app.route('/random_sample', methods=['GET'])
def random_sample():
    """Get a random sample from CIFAR-10 test set or generate dummy if missing.

    Returns JSON with a base64 PNG data URI (``image``) and the ground-truth
    class name (``true_label``); HTTP 500 with an ``error`` field on failure.
    """
    try:
        # Imported lazily so the app can start even while the dataset download
        # machinery is unavailable.
        from data_loader import get_data_loaders
        # Check if dataset exists on disk before touching the loaders.
        dataset_path = os.path.join(config.DATA_DIR, 'cifar-10-batches-py')

        if os.path.exists(dataset_path):
            _, test_loader = get_data_loaders()
            dataset = test_loader.dataset
            idx = np.random.randint(0, len(dataset))
            image, label = dataset[idx]

            # Denormalize image (inverse of the training-time Normalize).
            mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
            std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
            image_denorm = image * std + mean
            image_denorm = torch.clamp(image_denorm, 0, 1)

            # Convert to PIL Image: CHW float -> HWC uint8.
            image_np = (image_denorm.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
            label_name = config.CLASS_NAMES[label]
        else:
            # Generate dummy image for demonstration while the dataset is absent.
            image_np = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
            label_name = "Dummy Sample (Dataset still downloading)"

        pil_image = Image.fromarray(image_np)

        # Convert to base64 so the browser can render it inline as a data URI.
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return jsonify({
            'image': f'data:image/png;base64,{img_str}',
            'true_label': label_name
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == '__main__':
    # Load model once at startup; refuse to serve without a trained checkpoint.
    if load_model():
        print("Starting Flask application...")
        # NOTE(review): debug=True plus host 0.0.0.0 exposes the Werkzeug
        # debugger to the network — acceptable for a local demo, not production.
        app.run(debug=True, host='0.0.0.0', port=5000)
    else:
        print("Failed to load model. Please train the model first using train.py")
|
config.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for CIFAR-10 CNN project
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# Data configuration
|
| 7 |
+
DATA_DIR = './data'
|
| 8 |
+
BATCH_SIZE = 32
|
| 9 |
+
NUM_WORKERS = 0 # Set to 0 for better stability on some systems without GPU
|
| 10 |
+
|
| 11 |
+
# Model configuration
|
| 12 |
+
NUM_CLASSES = 10
|
| 13 |
+
INPUT_CHANNELS = 3
|
| 14 |
+
IMAGE_SIZE = 32
|
| 15 |
+
HIDDEN_SIZE = 256
|
| 16 |
+
NUM_LAYERS = 2
|
| 17 |
+
RNN_DROPOUT = 0.2
|
| 18 |
+
|
| 19 |
+
# Training configuration
|
| 20 |
+
EPOCHS = 5
|
| 21 |
+
LEARNING_RATE = 0.01 # Increased slightly for faster convergence in few epochs
|
| 22 |
+
WEIGHT_DECAY = 5e-4
|
| 23 |
+
MOMENTUM = 0.9
|
| 24 |
+
|
| 25 |
+
# Device configuration
|
| 26 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 27 |
+
|
| 28 |
+
# Checkpoint configuration
|
| 29 |
+
CHECKPOINT_DIR = './checkpoints'
|
| 30 |
+
BEST_MODEL_PATH = './checkpoints/best_model.pth'
|
| 31 |
+
LAST_MODEL_PATH = './checkpoints/last_model.pth'
|
| 32 |
+
|
| 33 |
+
# Visualization configuration
|
| 34 |
+
PLOTS_DIR = './plots'
|
| 35 |
+
|
| 36 |
+
# CIFAR-10 class names
|
| 37 |
+
CLASS_NAMES = [
|
| 38 |
+
'airplane', 'automobile', 'bird', 'cat', 'deer',
|
| 39 |
+
'dog', 'frog', 'horse', 'ship', 'truck'
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
# Data augmentation settings
|
| 43 |
+
USE_AUGMENTATION = True
|
| 44 |
+
RANDOM_CROP_PADDING = 4
|
| 45 |
+
RANDOM_HORIZONTAL_FLIP = 0.5
|
| 46 |
+
|
| 47 |
+
# Learning rate scheduler
|
| 48 |
+
USE_SCHEDULER = True
|
| 49 |
+
SCHEDULER_STEP_SIZE = 20
|
| 50 |
+
SCHEDULER_GAMMA = 0.1
|
data_loader.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading and preprocessing for CIFAR-10 dataset
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torchvision import datasets, transforms
|
| 7 |
+
import config
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_transforms(train=True):
    """Build the CIFAR-10 preprocessing pipeline.

    Args:
        train (bool): When True (and augmentation is enabled in config),
            prepend random-crop and horizontal-flip augmentation.

    Returns:
        torchvision.transforms.Compose: Composed transform pipeline.
    """
    # Normalization statistics shared by the train and test pipelines.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )

    steps = []
    if train and config.USE_AUGMENTATION:
        steps.append(transforms.RandomCrop(32, padding=config.RANDOM_CROP_PADDING))
        steps.append(transforms.RandomHorizontalFlip(p=config.RANDOM_HORIZONTAL_FLIP))
    steps.append(transforms.ToTensor())
    steps.append(normalize)

    return transforms.Compose(steps)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_data_loaders():
    """Create the CIFAR-10 train and test DataLoaders.

    Downloads the dataset into ``config.DATA_DIR`` on first use.

    Returns:
        tuple: (train_loader, test_loader)
    """
    # Pinned host memory only helps when transferring to a CUDA device.
    use_pin_memory = config.DEVICE.type == 'cuda'

    train_dataset = datasets.CIFAR10(
        root=config.DATA_DIR,
        train=True,
        download=True,
        transform=get_transforms(train=True),
    )
    test_dataset = datasets.CIFAR10(
        root=config.DATA_DIR,
        train=False,
        download=True,
        transform=get_transforms(train=False),
    )

    # Loader settings shared by both splits; only shuffling differs.
    shared_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=use_pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_kwargs)

    return train_loader, test_loader
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def denormalize(tensor):
    """Undo CIFAR-10 normalization for visualization.

    Args:
        tensor: Normalized image tensor with channel-first layout, e.g.
            (3, H, W) or any shape broadcastable against (3, 1, 1).

    Returns:
        torch.Tensor: Denormalized tensor on the same device/dtype as input.
    """
    # Build the statistics on the input's device and dtype so the function also
    # works for CUDA tensors and non-float32 inputs (the original always built
    # CPU float32 tensors, which raises on a CUDA input).
    mean = torch.tensor([0.4914, 0.4822, 0.4465],
                        device=tensor.device, dtype=tensor.dtype).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616],
                       device=tensor.device, dtype=tensor.dtype).view(3, 1, 1)
    return tensor * std + mean
|
debug_train.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Debug training script using dummy data to test the pipeline without downloading CIFAR-10
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
import config
|
| 12 |
+
from model import get_model, count_parameters
|
| 13 |
+
from utils import save_checkpoint, plot_training_history
|
| 14 |
+
|
| 15 |
+
def get_dummy_data_loaders():
    """Build small random-tensor loaders so the pipeline can run offline.

    Returns:
        tuple: (train_loader, test_loader) over random 32x32 RGB tensors
        with random labels in [0, 10).
    """
    train_size = 100
    test_size = 20

    # Random "images" and labels drawn in the same order as before:
    # train images, train labels, then test images, test labels.
    train_ds = TensorDataset(
        torch.randn(train_size, 3, 32, 32),
        torch.randint(0, 10, (train_size,)),
    )
    test_ds = TensorDataset(
        torch.randn(test_size, 3, 32, 32),
        torch.randint(0, 10, (test_size,)),
    )

    return (
        DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True),
        DataLoader(test_ds, batch_size=config.BATCH_SIZE, shuffle=False),
    )
|
| 34 |
+
|
| 35 |
+
def debug_train():
    """Run a 2-epoch smoke-test training loop on dummy data.

    Produces a ``best_model.pth`` checkpoint and training plots so the web
    app and evaluation tooling can be exercised without the real dataset.
    """
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(config.PLOTS_DIR, exist_ok=True)

    print("Creating dummy data loaders...")
    # NOTE(review): test_loader is created but never used below.
    train_loader, test_loader = get_dummy_data_loaders()

    print(f"Creating model on {config.DEVICE}...")
    model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    print("Starting debug training for 2 epochs...")
    for epoch in range(2):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate loss and accuracy statistics for this epoch.
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(train_loss)  # Just use train loss for dummy validation
        history['val_acc'].append(train_acc)

        print(f"Epoch {epoch+1}: Loss {train_loss:.4f}, Acc {train_acc:.2f}%")

    # Save "best" model for app testing (uses the last epoch's values).
    save_checkpoint(model, optimizer, epoch, train_acc, config.BEST_MODEL_PATH)
    plot_training_history(history, config.PLOTS_DIR)

    print("\nDebug training complete. 'best_model.pth' created for testing the web app.")
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
    # Entry point: run the dummy-data smoke test.
    debug_train()
|
evaluate.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation script for CIFAR-10 CNN
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
import config
|
| 9 |
+
from model import get_model
|
| 10 |
+
from data_loader import get_data_loaders
|
| 11 |
+
from utils import (
|
| 12 |
+
load_checkpoint, plot_confusion_matrix,
|
| 13 |
+
print_classification_report, visualize_predictions
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def evaluate():
    """Evaluate the best saved checkpoint on the CIFAR-10 test set.

    Prints overall accuracy and a classification report, and writes a
    confusion matrix plus sample-prediction images to ``config.PLOTS_DIR``.
    Returns early (with a message) if no checkpoint exists.
    """
    # Create plots directory
    os.makedirs(config.PLOTS_DIR, exist_ok=True)

    # Get data loaders (downloads the dataset if needed)
    print("Loading CIFAR-10 dataset...")
    _, test_loader = get_data_loaders()
    print(f"Test samples: {len(test_loader.dataset)}")

    # Create model
    print(f"\nLoading model from {config.BEST_MODEL_PATH}")
    model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)

    # Load checkpoint — bail out if training has not produced one yet.
    if not os.path.exists(config.BEST_MODEL_PATH):
        print(f"Error: Model checkpoint not found at {config.BEST_MODEL_PATH}")
        print("Please train the model first using train.py")
        return

    # No optimizer is needed for evaluation (hence the None argument).
    epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
    print(f"Loaded model from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    print("\nEvaluating model...")
    with torch.no_grad():
        pbar = tqdm(test_loader, desc='Evaluating')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)

            # Forward pass
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            # Statistics
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Store predictions and labels for the report/confusion matrix
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Update progress bar with running accuracy
            pbar.set_postfix({'acc': f'{100. * correct / total:.2f}%'})

    # Calculate final accuracy
    final_accuracy = 100. * correct / total

    # Print results
    print("\n" + "=" * 80)
    print(f"Final Test Accuracy: {final_accuracy:.2f}%")
    print(f"Correct predictions: {correct}/{total}")
    print("=" * 80)

    # Print classification report (per-class precision/recall/F1)
    print_classification_report(all_labels, all_predictions)

    # Plot confusion matrix
    print("\nGenerating confusion matrix...")
    cm_path = os.path.join(config.PLOTS_DIR, 'confusion_matrix.png')
    plot_confusion_matrix(all_labels, all_predictions, cm_path)
    print(f"Confusion matrix saved to {cm_path}")

    # Visualize predictions on a grid of sample images
    print("\nGenerating prediction visualizations...")
    visualize_predictions(model, test_loader, config.DEVICE, num_images=16)

    print("\nEvaluation completed!")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
    # Entry point: evaluate the best checkpoint on the test set.
    evaluate()
|
model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RNN Model Architecture for CIFAR-10 Classification
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import config
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CIFAR10RNN(nn.Module):
    """
    Recurrent Neural Network (LSTM) for CIFAR-10 classification.

    Each image is treated as a sequence of pixel rows: a (3, 32, 32) image
    becomes 32 time steps of 96 features (32 pixels x 3 channels).

    Architecture:
    - Input sequence: H rows of W x C pixels (96 features per step by default)
    - Bidirectional LSTM layers
    - Fully connected head for classification
    """

    def __init__(self, input_size=96, hidden_size=256, num_layers=2, num_classes=10):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM layer.
        # batch_first=True means input shape is (batch, seq, feature).
        # PyTorch applies inter-layer dropout only when num_layers > 1,
        # so pass 0 otherwise to avoid a warning.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=config.RNN_DROPOUT if num_layers > 1 else 0
        )

        # Classification head; hidden_size * 2 because the LSTM is bidirectional.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape (batch, channels, height, width); each row of
               pixels becomes one LSTM time step.

        Returns:
            Tensor of shape (batch, num_classes) with raw class logits.
        """
        batch_size = x.size(0)
        # Generalized: use the actual image height as the sequence length
        # instead of the previous hard-coded 32, so any (C, H, W) input with
        # W * C == input_size works.
        seq_len = x.size(2)

        # Rearrange image rows into a sequence:
        # (batch, C, H, W) -> (batch, H, C, W) -> (batch, H, C*W)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, seq_len, -1)

        # LSTM forward pass; out has shape (batch, seq_len, hidden_size * 2).
        out, _ = self.lstm(x)

        # Keep only the output of the last time step for classification.
        out = out[:, -1, :]

        # Classification head.
        out = self.fc(out)

        return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_model(num_classes=10, device='cpu'):
    """
    Build the CIFAR-10 RNN and place it on the requested device.

    Args:
        num_classes (int): Number of output classes.
        device (str or torch.device): Device to load the model on.

    Returns:
        CIFAR10RNN: The model, already moved to ``device``.
    """
    # 32 pixels per row * 3 color channels = 96 features per time step;
    # hidden size and layer count come from the project config.
    return CIFAR10RNN(
        input_size=32 * 3,
        hidden_size=config.HIDDEN_SIZE,
        num_layers=config.NUM_LAYERS,
        num_classes=num_classes,
    ).to(device)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def count_parameters(model):
    """
    Count the number of trainable parameters in the model.

    Args:
        model (nn.Module): Model whose parameters are tallied.

    Returns:
        int: Total element count over all parameters with requires_grad=True.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
matplotlib>=3.7.0
|
| 5 |
+
pillow>=9.5.0
|
| 6 |
+
flask>=2.3.0
|
| 7 |
+
tqdm>=4.65.0
|
| 8 |
+
scikit-learn>=1.3.0
|
static/script.js
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// CIFAR-10 Classifier JavaScript

// Currently selected image (a File object); written by handleFile() and
// read by the classify-button handler.
let selectedFile = null;

// DOM Elements
const uploadArea = document.getElementById('uploadArea');
const fileInput = document.getElementById('fileInput');
const previewCard = document.getElementById('previewCard');
const previewImage = document.getElementById('previewImage');
const classifyBtn = document.getElementById('classifyBtn');
const randomBtn = document.getElementById('randomBtn');
const resultsSection = document.getElementById('resultsSection');
const loadingOverlay = document.getElementById('loadingOverlay');

// Upload area click handler: forward clicks on the drop zone to the
// hidden <input type="file">.
uploadArea.addEventListener('click', () => {
    fileInput.click();
});

// File input change handler
fileInput.addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (file) {
        handleFile(file);
    }
});

// Drag and drop handlers
// preventDefault on dragover is required for the element to accept drops.
uploadArea.addEventListener('dragover', (e) => {
    e.preventDefault();
    uploadArea.classList.add('drag-over');
});

uploadArea.addEventListener('dragleave', () => {
    uploadArea.classList.remove('drag-over');
});

uploadArea.addEventListener('drop', (e) => {
    e.preventDefault();
    uploadArea.classList.remove('drag-over');

    // Only accept files whose MIME type marks them as images.
    const file = e.dataTransfer.files[0];
    if (file && file.type.startsWith('image/')) {
        handleFile(file);
    }
});
|
| 47 |
+
|
| 48 |
+
// Remember the chosen file and show it in the preview card.
function handleFile(file) {
    selectedFile = file;

    const reader = new FileReader();
    reader.onload = (event) => {
        // The data URL of the image becomes the preview source.
        previewImage.src = event.target.result;
        // Hide any stale results from a previous classification.
        resultsSection.style.display = 'none';
        previewCard.style.display = 'block';
    };
    reader.readAsDataURL(file);
}
|
| 60 |
+
|
| 61 |
+
// Classify button handler: POST the selected image to /predict and
// render the returned predictions.
classifyBtn.addEventListener('click', async () => {
    if (!selectedFile) return;

    const payload = new FormData();
    payload.append('file', selectedFile);

    try {
        loadingOverlay.style.display = 'flex';

        const response = await fetch('/predict', {
            method: 'POST',
            body: payload
        });
        const result = await response.json();

        // Server-side failures come back as a JSON {error: ...} body.
        if (result.error) {
            alert('Error: ' + result.error);
            return;
        }
        displayResults(result);
    } catch (err) {
        alert('Error: ' + err.message);
    } finally {
        // Always dismiss the spinner, success or failure.
        loadingOverlay.style.display = 'none';
    }
});
|
| 91 |
+
|
| 92 |
+
// Random sample button handler: fetch a random test image from the server
// and feed it through the normal file-selection path.
randomBtn.addEventListener('click', async () => {
    try {
        loadingOverlay.style.display = 'flex';

        const response = await fetch('/random_sample');
        const payload = await response.json();

        if (payload.error) {
            alert('Error: ' + payload.error);
            return;
        }

        // The server returns a base64 data URL; convert it to a File so the
        // preview/classify flow treats it like an uploaded image.
        const imageResponse = await fetch(payload.image);
        const blob = await imageResponse.blob();
        const file = new File([blob], 'random_sample.png', { type: 'image/png' });

        handleFile(file);
    } catch (err) {
        alert('Error: ' + err.message);
    } finally {
        loadingOverlay.style.display = 'none';
    }
});
|
| 117 |
+
|
| 118 |
+
// Render the classification results returned by the /predict endpoint.
function displayResults(data) {
    const confidence = data.confidence;

    document.getElementById('predictedClass').textContent = data.predicted_class;
    document.getElementById('confidenceValue').textContent = confidence.toFixed(2) + '%';

    // Pick badge colors by confidence band: blue >= 80, pink >= 60, red below.
    const badge = document.getElementById('confidenceBadge');
    let bg, border, fg;
    if (confidence >= 80) {
        bg = 'rgba(79, 172, 254, 0.2)';
        border = 'rgba(79, 172, 254, 0.4)';
        fg = '#4facfe';
    } else if (confidence >= 60) {
        bg = 'rgba(240, 147, 251, 0.2)';
        border = 'rgba(240, 147, 251, 0.4)';
        fg = '#f093fb';
    } else {
        bg = 'rgba(245, 87, 108, 0.2)';
        border = 'rgba(245, 87, 108, 0.4)';
        fg = '#f5576c';
    }
    badge.style.background = bg;
    badge.style.borderColor = border;
    badge.style.color = fg;

    // Rebuild the top-5 list, staggering each row's entry animation.
    const top5Container = document.getElementById('top5Container');
    top5Container.innerHTML = '';

    data.top5_predictions.forEach((pred, index) => {
        const item = document.createElement('div');
        item.className = 'prediction-item';
        item.style.animationDelay = `${index * 0.1}s`;

        item.innerHTML = `
            <span class="prediction-item-name">${pred.class}</span>
            <div class="prediction-item-bar">
                <div class="prediction-item-fill" style="width: ${pred.probability}%"></div>
            </div>
            <span class="prediction-item-value">${pred.probability.toFixed(2)}%</span>
        `;

        top5Container.appendChild(item);
    });

    resultsSection.style.display = 'grid';

    // Bring the freshly shown results into view.
    resultsSection.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
|
static/style.css
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Modern CSS for CIFAR-10 Classifier */
|
| 2 |
+
|
| 3 |
+
:root {
|
| 4 |
+
--primary-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 5 |
+
--secondary-gradient: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
|
| 6 |
+
--success-gradient: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
|
| 7 |
+
--bg-primary: #0a0e27;
|
| 8 |
+
--bg-secondary: #151932;
|
| 9 |
+
--bg-card: rgba(255, 255, 255, 0.05);
|
| 10 |
+
--bg-card-hover: rgba(255, 255, 255, 0.08);
|
| 11 |
+
--text-primary: #ffffff;
|
| 12 |
+
--text-secondary: #a0aec0;
|
| 13 |
+
--text-muted: #718096;
|
| 14 |
+
--border-color: rgba(255, 255, 255, 0.1);
|
| 15 |
+
--shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.3);
|
| 16 |
+
--spacing-xs: 0.5rem;
|
| 17 |
+
--spacing-sm: 1rem;
|
| 18 |
+
--spacing-md: 1.5rem;
|
| 19 |
+
--spacing-lg: 2rem;
|
| 20 |
+
--spacing-xl: 3rem;
|
| 21 |
+
--radius-sm: 8px;
|
| 22 |
+
--radius-md: 12px;
|
| 23 |
+
--radius-lg: 16px;
|
| 24 |
+
--radius-xl: 24px;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
* {
|
| 28 |
+
margin: 0;
|
| 29 |
+
padding: 0;
|
| 30 |
+
box-sizing: border-box;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
body {
|
| 34 |
+
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
| 35 |
+
background: var(--bg-primary);
|
| 36 |
+
color: var(--text-primary);
|
| 37 |
+
line-height: 1.6;
|
| 38 |
+
min-height: 100vh;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
body::before {
|
| 42 |
+
content: '';
|
| 43 |
+
position: fixed;
|
| 44 |
+
top: 0;
|
| 45 |
+
left: 0;
|
| 46 |
+
width: 100%;
|
| 47 |
+
height: 100%;
|
| 48 |
+
background: radial-gradient(circle at 20% 50%, rgba(102, 126, 234, 0.1) 0%, transparent 50%),
|
| 49 |
+
radial-gradient(circle at 80% 80%, rgba(118, 75, 162, 0.1) 0%, transparent 50%);
|
| 50 |
+
z-index: -1;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.container {
|
| 54 |
+
max-width: 1400px;
|
| 55 |
+
margin: 0 auto;
|
| 56 |
+
padding: var(--spacing-lg);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.header {
|
| 60 |
+
text-align: center;
|
| 61 |
+
margin-bottom: var(--spacing-xl);
|
| 62 |
+
padding: var(--spacing-xl) 0;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.title {
|
| 66 |
+
font-size: 3.5rem;
|
| 67 |
+
font-weight: 700;
|
| 68 |
+
margin-bottom: var(--spacing-sm);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.gradient-text {
|
| 72 |
+
background: var(--primary-gradient);
|
| 73 |
+
-webkit-background-clip: text;
|
| 74 |
+
-webkit-text-fill-color: transparent;
|
| 75 |
+
background-clip: text;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
.subtitle {
|
| 79 |
+
font-size: 1.25rem;
|
| 80 |
+
color: var(--text-secondary);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.upload-section {
|
| 84 |
+
display: grid;
|
| 85 |
+
grid-template-columns: 1fr 1fr;
|
| 86 |
+
gap: var(--spacing-lg);
|
| 87 |
+
margin-bottom: var(--spacing-xl);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
.card {
|
| 91 |
+
background: var(--bg-card);
|
| 92 |
+
backdrop-filter: blur(10px);
|
| 93 |
+
border: 1px solid var(--border-color);
|
| 94 |
+
border-radius: var(--radius-lg);
|
| 95 |
+
padding: var(--spacing-lg);
|
| 96 |
+
transition: all 0.3s ease;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
.card:hover {
|
| 100 |
+
background: var(--bg-card-hover);
|
| 101 |
+
transform: translateY(-2px);
|
| 102 |
+
box-shadow: var(--shadow-lg);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
.card-title {
|
| 106 |
+
font-size: 1.5rem;
|
| 107 |
+
font-weight: 600;
|
| 108 |
+
margin-bottom: var(--spacing-md);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.upload-area {
|
| 112 |
+
border: 2px dashed var(--border-color);
|
| 113 |
+
border-radius: var(--radius-md);
|
| 114 |
+
padding: var(--spacing-xl);
|
| 115 |
+
text-align: center;
|
| 116 |
+
cursor: pointer;
|
| 117 |
+
transition: all 0.3s ease;
|
| 118 |
+
margin-bottom: var(--spacing-md);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.upload-area:hover {
|
| 122 |
+
border-color: #667eea;
|
| 123 |
+
background: rgba(102, 126, 234, 0.05);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.upload-icon {
|
| 127 |
+
width: 64px;
|
| 128 |
+
height: 64px;
|
| 129 |
+
margin: 0 auto var(--spacing-md);
|
| 130 |
+
color: #667eea;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
.upload-text {
|
| 134 |
+
font-size: 1.125rem;
|
| 135 |
+
font-weight: 500;
|
| 136 |
+
margin-bottom: var(--spacing-xs);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.upload-subtext {
|
| 140 |
+
font-size: 0.875rem;
|
| 141 |
+
color: var(--text-muted);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
.image-preview {
|
| 145 |
+
width: 100%;
|
| 146 |
+
height: 300px;
|
| 147 |
+
border-radius: var(--radius-md);
|
| 148 |
+
overflow: hidden;
|
| 149 |
+
margin-bottom: var(--spacing-md);
|
| 150 |
+
background: var(--bg-secondary);
|
| 151 |
+
display: flex;
|
| 152 |
+
align-items: center;
|
| 153 |
+
justify-content: center;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
.image-preview img {
|
| 157 |
+
max-width: 100%;
|
| 158 |
+
max-height: 100%;
|
| 159 |
+
object-fit: contain;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.btn {
|
| 163 |
+
display: inline-flex;
|
| 164 |
+
align-items: center;
|
| 165 |
+
justify-content: center;
|
| 166 |
+
gap: var(--spacing-xs);
|
| 167 |
+
padding: 0.875rem 1.75rem;
|
| 168 |
+
font-size: 1rem;
|
| 169 |
+
font-weight: 600;
|
| 170 |
+
border: none;
|
| 171 |
+
border-radius: var(--radius-md);
|
| 172 |
+
cursor: pointer;
|
| 173 |
+
transition: all 0.3s ease;
|
| 174 |
+
width: 100%;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
.btn svg {
|
| 178 |
+
width: 20px;
|
| 179 |
+
height: 20px;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.btn-primary {
|
| 183 |
+
background: var(--primary-gradient);
|
| 184 |
+
color: white;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.btn-primary:hover {
|
| 188 |
+
transform: translateY(-2px);
|
| 189 |
+
box-shadow: 0 8px 24px rgba(102, 126, 234, 0.4);
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
.btn-secondary {
|
| 193 |
+
background: var(--bg-secondary);
|
| 194 |
+
color: var(--text-primary);
|
| 195 |
+
border: 1px solid var(--border-color);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
.btn-secondary:hover {
|
| 199 |
+
background: var(--bg-card);
|
| 200 |
+
border-color: #667eea;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.results-section {
|
| 204 |
+
display: grid;
|
| 205 |
+
grid-template-columns: 2fr 1fr;
|
| 206 |
+
gap: var(--spacing-lg);
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
.prediction-main {
|
| 210 |
+
text-align: center;
|
| 211 |
+
padding: var(--spacing-lg);
|
| 212 |
+
background: var(--bg-secondary);
|
| 213 |
+
border-radius: var(--radius-md);
|
| 214 |
+
margin-bottom: var(--spacing-lg);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
.prediction-label {
|
| 218 |
+
font-size: 0.875rem;
|
| 219 |
+
text-transform: uppercase;
|
| 220 |
+
letter-spacing: 0.1em;
|
| 221 |
+
color: var(--text-muted);
|
| 222 |
+
margin-bottom: var(--spacing-sm);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.prediction-class {
|
| 226 |
+
font-size: 2.5rem;
|
| 227 |
+
font-weight: 700;
|
| 228 |
+
background: var(--success-gradient);
|
| 229 |
+
-webkit-background-clip: text;
|
| 230 |
+
-webkit-text-fill-color: transparent;
|
| 231 |
+
background-clip: text;
|
| 232 |
+
margin-bottom: var(--spacing-md);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
.confidence-badge {
|
| 236 |
+
display: inline-block;
|
| 237 |
+
padding: var(--spacing-xs) var(--spacing-md);
|
| 238 |
+
background: rgba(79, 172, 254, 0.2);
|
| 239 |
+
border: 1px solid rgba(79, 172, 254, 0.4);
|
| 240 |
+
border-radius: var(--radius-xl);
|
| 241 |
+
font-size: 0.875rem;
|
| 242 |
+
font-weight: 600;
|
| 243 |
+
color: #4facfe;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
.predictions-title {
|
| 247 |
+
font-size: 1.125rem;
|
| 248 |
+
font-weight: 600;
|
| 249 |
+
margin-bottom: var(--spacing-md);
|
| 250 |
+
color: var(--text-secondary);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
.prediction-item {
|
| 254 |
+
display: flex;
|
| 255 |
+
align-items: center;
|
| 256 |
+
justify-content: space-between;
|
| 257 |
+
padding: var(--spacing-sm);
|
| 258 |
+
background: var(--bg-secondary);
|
| 259 |
+
border-radius: var(--radius-sm);
|
| 260 |
+
margin-bottom: var(--spacing-sm);
|
| 261 |
+
transition: all 0.3s ease;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.prediction-item:hover {
|
| 265 |
+
background: var(--bg-card);
|
| 266 |
+
transform: translateX(4px);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
.prediction-item-name {
|
| 270 |
+
font-weight: 500;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
.prediction-item-bar {
|
| 274 |
+
flex: 1;
|
| 275 |
+
height: 8px;
|
| 276 |
+
background: var(--bg-primary);
|
| 277 |
+
border-radius: var(--radius-xl);
|
| 278 |
+
margin: 0 var(--spacing-md);
|
| 279 |
+
overflow: hidden;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
.prediction-item-fill {
|
| 283 |
+
height: 100%;
|
| 284 |
+
background: var(--primary-gradient);
|
| 285 |
+
border-radius: var(--radius-xl);
|
| 286 |
+
transition: width 0.8s ease;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
.prediction-item-value {
|
| 290 |
+
font-weight: 600;
|
| 291 |
+
color: var(--text-secondary);
|
| 292 |
+
min-width: 50px;
|
| 293 |
+
text-align: right;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
.info-title {
|
| 297 |
+
font-size: 1.125rem;
|
| 298 |
+
font-weight: 600;
|
| 299 |
+
margin-bottom: var(--spacing-md);
|
| 300 |
+
color: var(--text-secondary);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.classes-grid {
|
| 304 |
+
display: grid;
|
| 305 |
+
grid-template-columns: 1fr 1fr;
|
| 306 |
+
gap: var(--spacing-sm);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
.class-item {
|
| 310 |
+
display: flex;
|
| 311 |
+
align-items: center;
|
| 312 |
+
gap: var(--spacing-sm);
|
| 313 |
+
padding: var(--spacing-sm);
|
| 314 |
+
background: var(--bg-secondary);
|
| 315 |
+
border-radius: var(--radius-sm);
|
| 316 |
+
transition: all 0.3s ease;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
.class-item:hover {
|
| 320 |
+
background: var(--bg-card);
|
| 321 |
+
transform: translateX(4px);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
.class-icon {
|
| 325 |
+
width: 32px;
|
| 326 |
+
height: 32px;
|
| 327 |
+
display: flex;
|
| 328 |
+
align-items: center;
|
| 329 |
+
justify-content: center;
|
| 330 |
+
background: var(--primary-gradient);
|
| 331 |
+
border-radius: var(--radius-sm);
|
| 332 |
+
font-weight: 600;
|
| 333 |
+
font-size: 0.875rem;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
.class-name {
|
| 337 |
+
font-weight: 500;
|
| 338 |
+
text-transform: capitalize;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
.loading-overlay {
|
| 342 |
+
position: fixed;
|
| 343 |
+
top: 0;
|
| 344 |
+
left: 0;
|
| 345 |
+
width: 100%;
|
| 346 |
+
height: 100%;
|
| 347 |
+
background: rgba(10, 14, 39, 0.9);
|
| 348 |
+
backdrop-filter: blur(8px);
|
| 349 |
+
display: flex;
|
| 350 |
+
flex-direction: column;
|
| 351 |
+
align-items: center;
|
| 352 |
+
justify-content: center;
|
| 353 |
+
z-index: 1000;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
.spinner {
|
| 357 |
+
width: 64px;
|
| 358 |
+
height: 64px;
|
| 359 |
+
border: 4px solid var(--border-color);
|
| 360 |
+
border-top-color: #667eea;
|
| 361 |
+
border-radius: 50%;
|
| 362 |
+
animation: spin 1s linear infinite;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
@keyframes spin {
|
| 366 |
+
to { transform: rotate(360deg); }
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
.loading-text {
|
| 370 |
+
margin-top: var(--spacing-md);
|
| 371 |
+
font-size: 1.125rem;
|
| 372 |
+
color: var(--text-secondary);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
.footer {
|
| 376 |
+
text-align: center;
|
| 377 |
+
padding: var(--spacing-xl) 0;
|
| 378 |
+
color: var(--text-muted);
|
| 379 |
+
font-size: 0.875rem;
|
| 380 |
+
border-top: 1px solid var(--border-color);
|
| 381 |
+
margin-top: var(--spacing-xl);
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
@media (max-width: 1024px) {
|
| 385 |
+
.upload-section, .results-section {
|
| 386 |
+
grid-template-columns: 1fr;
|
| 387 |
+
}
|
| 388 |
+
.title {
|
| 389 |
+
font-size: 2.5rem;
|
| 390 |
+
}
|
| 391 |
+
}
|
templates/index.html
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>CIFAR-10 Image Classifier</title>
|
| 7 |
+
<meta name="description" content="Deep learning powered CIFAR-10 image classification using a Recurrent Neural Network (LSTM)">
|
| 8 |
+
<link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 10 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 11 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
| 12 |
+
</head>
|
| 13 |
+
<body>
|
| 14 |
+
<div class="container">
|
| 15 |
+
<header class="header">
|
| 16 |
+
<div class="header-content">
|
| 17 |
+
<h1 class="title">
|
| 18 |
+
<span class="gradient-text">CIFAR-10</span> Image Classifier
|
| 19 |
+
</h1>
|
| 20 |
+
<p class="subtitle">Powered by Deep Learning & Recurrent Neural Networks (LSTM)</p>
|
| 21 |
+
</div>
|
| 22 |
+
</header>
|
| 23 |
+
|
| 24 |
+
<main class="main-content">
|
| 25 |
+
<div class="upload-section">
|
| 26 |
+
<div class="card upload-card">
|
| 27 |
+
<h2 class="card-title">Upload Image</h2>
|
| 28 |
+
<div class="upload-area" id="uploadArea">
|
| 29 |
+
<svg class="upload-icon" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
| 30 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12" />
|
| 31 |
+
</svg>
|
| 32 |
+
<p class="upload-text">Drag & drop an image here</p>
|
| 33 |
+
<p class="upload-subtext">or click to browse</p>
|
| 34 |
+
<input type="file" id="fileInput" accept="image/*" hidden>
|
| 35 |
+
</div>
|
| 36 |
+
|
| 37 |
+
<button class="btn btn-secondary" id="randomBtn">
|
| 38 |
+
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
| 39 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
| 40 |
+
</svg>
|
| 41 |
+
Try Random Sample
|
| 42 |
+
</button>
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
<div class="card preview-card" id="previewCard" style="display: none;">
|
| 46 |
+
<h2 class="card-title">Preview</h2>
|
| 47 |
+
<div class="image-preview">
|
| 48 |
+
<img id="previewImage" src="" alt="Preview">
|
| 49 |
+
</div>
|
| 50 |
+
<button class="btn btn-primary" id="classifyBtn">
|
| 51 |
+
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
| 52 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
|
| 53 |
+
</svg>
|
| 54 |
+
Classify Image
|
| 55 |
+
</button>
|
| 56 |
+
</div>
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
<div class="results-section" id="resultsSection" style="display: none;">
|
| 60 |
+
<div class="card results-card">
|
| 61 |
+
<h2 class="card-title">Classification Results</h2>
|
| 62 |
+
|
| 63 |
+
<div class="prediction-main">
|
| 64 |
+
<div class="prediction-label">Predicted Class</div>
|
| 65 |
+
<div class="prediction-class" id="predictedClass">-</div>
|
| 66 |
+
<div class="confidence-badge" id="confidenceBadge">
|
| 67 |
+
<span id="confidenceValue">0%</span> confidence
|
| 68 |
+
</div>
|
| 69 |
+
</div>
|
| 70 |
+
|
| 71 |
+
<div class="top-predictions">
|
| 72 |
+
<h3 class="predictions-title">Top 5 Predictions</h3>
|
| 73 |
+
<div id="top5Container"></div>
|
| 74 |
+
</div>
|
| 75 |
+
</div>
|
| 76 |
+
|
| 77 |
+
<div class="card info-card">
|
| 78 |
+
<h3 class="info-title">CIFAR-10 Classes</h3>
|
| 79 |
+
<div class="classes-grid">
|
| 80 |
+
{% for class_name in class_names %}
|
| 81 |
+
<div class="class-item">
|
| 82 |
+
<div class="class-icon">{{ loop.index0 }}</div>
|
| 83 |
+
<div class="class-name">{{ class_name }}</div>
|
| 84 |
+
</div>
|
| 85 |
+
{% endfor %}
|
| 86 |
+
</div>
|
| 87 |
+
</div>
|
| 88 |
+
</div>
|
| 89 |
+
|
| 90 |
+
<div class="loading-overlay" id="loadingOverlay" style="display: none;">
|
| 91 |
+
<div class="spinner"></div>
|
| 92 |
+
<p class="loading-text">Classifying image...</p>
|
| 93 |
+
</div>
|
| 94 |
+
</main>
|
| 95 |
+
|
| 96 |
+
<footer class="footer">
|
| 97 |
+
<p>Built with PyTorch & Flask | Bidirectional LSTM (RNN) Architecture</p>
|
| 98 |
+
</footer>
|
| 99 |
+
</div>
|
| 100 |
+
|
| 101 |
+
<script src="{{ url_for('static', filename='script.js') }}"></script>
|
| 102 |
+
</body>
|
| 103 |
+
</html>
|
train.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for the CIFAR-10 RNN (LSTM) classifier
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
import config
|
| 12 |
+
from model import get_model, count_parameters
|
| 13 |
+
from data_loader import get_data_loaders
|
| 14 |
+
from utils import save_checkpoint, load_checkpoint, plot_training_history
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def train_epoch(model, train_loader, criterion, optimizer, device):
    """
    Train the model for one epoch.

    Args:
        model: PyTorch model to optimize (updated in place).
        train_loader: Training data loader yielding (inputs, labels) batches.
        criterion: Loss function.
        optimizer: Optimizer stepping the model's parameters.
        device: Device to train on.

    Returns:
        tuple: (average_loss, accuracy) — mean per-batch loss over the epoch
        and accuracy in percent (0-100).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    # Track the batch index explicitly instead of relying on tqdm's internal
    # ``pbar.n`` counter for the running-mean denominator; the displayed
    # average no longer depends on tqdm's update timing.
    for batch_idx, (inputs, labels) in enumerate(pbar):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar with running mean loss and accuracy so far
        pbar.set_postfix({
            'loss': f'{running_loss / (batch_idx + 1):.4f}',
            'acc': f'{100. * correct / total:.2f}%'
        })

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def validate(model, test_loader, criterion, device):
    """
    Evaluate the model over the test set without updating weights.

    Args:
        model: PyTorch model to evaluate
        test_loader: DataLoader yielding (inputs, labels) batches
        criterion: Loss function
        device: Device on which the forward passes run

    Returns:
        tuple: (average_loss, accuracy) over the whole test set
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # Disable autograd: inference only, no gradients needed.
    with torch.no_grad():
        progress = tqdm(test_loader, desc='Validation')
        for batch_idx, (inputs, labels) in enumerate(progress):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Accumulate running statistics for the progress display.
            loss_sum += loss.item()
            predictions = outputs.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predictions.eq(labels).sum().item()

            progress.set_postfix({
                'loss': f'{loss_sum / (batch_idx + 1):.4f}',
                'acc': f'{100. * n_correct / n_seen:.2f}%'
            })

    return loss_sum / len(test_loader), 100. * n_correct / n_seen
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def train():
    """
    Train the CNN on CIFAR-10.

    Builds the data loaders, model, loss, optimizer and (optionally) an LR
    scheduler from `config`, runs the epoch loop, tracks the best validation
    accuracy, saves best/last checkpoints, and writes a training-history plot.
    """
    # Ensure output directories exist before any save call.
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(config.PLOTS_DIR, exist_ok=True)

    # Data
    print("Loading CIFAR-10 dataset...")
    train_loader, test_loader = get_data_loaders()
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # Model
    print(f"\nCreating model on device: {config.DEVICE}")
    model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
    print(f"Model parameters: {count_parameters(model):,}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.LEARNING_RATE,
        momentum=config.MOMENTUM,
        weight_decay=config.WEIGHT_DECAY
    )

    # Optional step-decay learning-rate scheduler.
    scheduler = None
    if config.USE_SCHEDULER:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.SCHEDULER_STEP_SIZE,
            gamma=config.SCHEDULER_GAMMA
        )

    # Per-epoch metrics, consumed by plot_training_history().
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_acc = 0.0
    start_epoch = 0  # kept for future checkpoint-resume support

    print(f"\nStarting training for {config.EPOCHS} epochs...")
    for epoch in range(start_epoch, config.EPOCHS):
        print(f"\nEpoch {epoch + 1}/{config.EPOCHS}")
        print("-" * 50)

        # Train for one epoch, then evaluate on the held-out test set.
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, config.DEVICE
        )
        val_loss, val_acc = validate(
            model, test_loader, criterion, config.DEVICE
        )

        # Step the scheduler once per epoch (after validation).
        if scheduler:
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]
            print(f"Learning rate: {current_lr:.6f}")

        # Record metrics for plotting.
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Keep a copy of the weights with the best validation accuracy so far.
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(
                model, optimizer, epoch, val_acc,
                config.BEST_MODEL_PATH
            )
            print(f"✓ Best model saved with accuracy: {best_acc:.2f}%")

        # Always save the most recent epoch as well.
        save_checkpoint(
            model, optimizer, epoch, val_acc,
            config.LAST_MODEL_PATH
        )

    # Persist the loss/accuracy curves for later inspection.
    plot_training_history(history, config.PLOTS_DIR)

    print("\n" + "=" * 50)
    # Fixed: was an f-string with no placeholders.
    print("Training completed!")
    print(f"Best validation accuracy: {best_acc:.2f}%")
    print("=" * 50)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    train()
|
train_log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for the CIFAR-10 CNN project
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import matplotlib
|
| 7 |
+
matplotlib.use('Agg')
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
|
| 13 |
+
import config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def save_checkpoint(model, optimizer, epoch, accuracy, filepath):
    """
    Serialize the current training state to disk.

    Args:
        model: PyTorch model whose weights are saved
        optimizer: Optimizer whose state is saved
        epoch: Epoch index at save time
        accuracy: Accuracy recorded at save time
        filepath: Destination path for the checkpoint file
    """
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': accuracy,
        },
        filepath,
    )
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_checkpoint(model, optimizer, filepath):
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: PyTorch model to load weights into
        optimizer: Optimizer to restore, or None/falsy to skip
        filepath: Path to the checkpoint file

    Returns:
        tuple: (epoch, accuracy) recorded in the checkpoint
    """
    # Map tensors onto the configured device regardless of where they were saved.
    state = torch.load(filepath, map_location=config.DEVICE)
    model.load_state_dict(state['model_state_dict'])
    if optimizer:
        optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'], state['accuracy']
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def plot_training_history(history, save_dir):
    """
    Render loss and accuracy curves side by side and save them as one PNG.

    Args:
        history: Dict with 'train_loss', 'val_loss', 'train_acc', 'val_acc' lists
        save_dir: Directory in which 'training_history.png' is written
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: loss curves.
    loss_ax.plot(history['train_loss'], label='Train Loss', linewidth=2)
    loss_ax.plot(history['val_loss'], label='Validation Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    loss_ax.legend(fontsize=10)
    loss_ax.grid(True, alpha=0.3)

    # Right panel: accuracy curves.
    acc_ax.plot(history['train_acc'], label='Train Accuracy', linewidth=2)
    acc_ax.plot(history['val_acc'], label='Validation Accuracy', linewidth=2)
    acc_ax.set_xlabel('Epoch', fontsize=12)
    acc_ax.set_ylabel('Accuracy (%)', fontsize=12)
    acc_ax.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    acc_ax.legend(fontsize=10)
    acc_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def plot_confusion_matrix(y_true, y_pred, save_path):
    """
    Draw a labeled confusion-matrix heatmap and save it to disk.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        save_path: File path for the saved figure
    """
    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(12, 10))
    # Annotated integer counts, class names on both axes.
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=config.CLASS_NAMES,
        yticklabels=config.CLASS_NAMES,
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def print_classification_report(y_true, y_pred):
    """
    Print a per-class precision/recall/F1 report to stdout.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
    """
    separator = "=" * 80
    print("\nClassification Report:")
    print(separator)
    print(classification_report(
        y_true, y_pred,
        target_names=config.CLASS_NAMES,
        digits=4
    ))
    print(separator)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def visualize_predictions(model, test_loader, device, num_images=16):
    """
    Plot a grid of test images with their true vs. predicted labels.

    Correct predictions are titled in green, incorrect ones in red. The
    figure is saved to `<config.PLOTS_DIR>/predictions.png`.

    Args:
        model: PyTorch model used for inference
        test_loader: Test data loader (only the first batch is used)
        device: Device to run inference on
        num_images: Number of images to display. Clamped to the batch size.
            (Previously the grid was hard-coded to 4x4, so any value other
            than 16 either crashed with IndexError or left blank axes.)
    """
    model.eval()

    # Take one batch and clamp the request to what it actually contains.
    images, labels = next(iter(test_loader))
    num_images = min(num_images, len(images))
    images, labels = images[:num_images], labels[:num_images]
    images_device = images.to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(images_device)
        _, predicted = outputs.max(1)

    # Build a near-square grid just large enough for num_images.
    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = int(np.ceil(num_images / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    axes = np.atleast_1d(axes).ravel()

    for idx in range(num_images):
        # Undo the CIFAR-10 normalization so colors render correctly.
        img = images[idx].cpu().numpy().transpose(1, 2, 0)
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2470, 0.2435, 0.2616])
        img = np.clip(img * std + mean, 0, 1)

        axes[idx].imshow(img)
        axes[idx].axis('off')

        true_label = config.CLASS_NAMES[labels[idx]]
        pred_label = config.CLASS_NAMES[predicted[idx].cpu()]

        # Green title = correct prediction, red = wrong.
        color = 'green' if labels[idx] == predicted[idx].cpu() else 'red'
        axes[idx].set_title(
            f'True: {true_label}\nPred: {pred_label}',
            color=color, fontsize=10
        )

    # Hide any unused cells in the grid.
    for idx in range(num_images, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(config.PLOTS_DIR, 'predictions.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Predictions visualization saved to {config.PLOTS_DIR}/predictions.png")
|