N-I-M-I committed (verified)
Commit 233caeb · 1 parent: cd7755a

Upload folder using huggingface_hub

Files changed (15):
  1. .gitignore +63 -0
  2. README.md +173 -0
  3. app.py +166 -0
  4. config.py +50 -0
  5. data_loader.py +100 -0
  6. debug_train.py +88 -0
  7. evaluate.py +96 -0
  8. model.py +93 -0
  9. requirements.txt +8 -0
  10. static/script.js +163 -0
  11. static/style.css +391 -0
  12. templates/index.html +103 -0
  13. train.py +220 -0
  14. train_log.txt +0 -0
  15. utils.py +186 -0
.gitignore ADDED
@@ -0,0 +1,63 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+ env/
+ .venv
+
+ # PyTorch
+ *.pth
+ *.pt
+ checkpoints/
+ !checkpoints/.gitkeep
+
+ # Data
+ data/
+ *.pkl
+ *.pickle
+
+ # Plots and visualizations
+ plots/
+ !plots/.gitkeep
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Flask
+ instance/
+ .webassets-cache
+
+ # Logs
+ *.log
README.md CHANGED
@@ -1,3 +1,176 @@
  ---
  license: mit
+ datasets:
+ - cifar10
+ metrics:
+ - accuracy
+ library_name: pytorch
+ tags:
+ - image-classification
+ - sequence-classification
  ---
+
+ # CIFAR-10 RNN Image Classifier
+
+ An end-to-end deep learning project for classifying CIFAR-10 images using a recurrent neural network (LSTM) built with PyTorch. Includes a modern web interface for real-time image classification.
+
+ ## 🌟 Features
+
+ - **Custom RNN Architecture**: Bidirectional LSTM layers with dropout
+ - **Complete Training Pipeline**: Automated training with validation, checkpointing, and visualization
+ - **Comprehensive Evaluation**: Confusion matrix, classification reports, and prediction visualizations
+ - **Modern Web Interface**: Flask web app for real-time image classification
+ - **CIFAR-10 Dataset**: Automatically downloads and preprocesses the dataset
+
+ ## 📊 Model Architecture
+
+ The model treats each 32x32 RGB image as a sequence of 32 rows, where each row has 96 features (32 pixels × 3 channels).
+
+ ```
+ Input (Batch, 3, 32, 32)
+
+ Reshape (Batch, 32, 96)
+
+ Bidirectional LSTM (Hidden: 256, Layers: 2, Dropout: 0.2)
+
+ Last Time Step Output
+
+ Fully Connected (512) → ReLU → Dropout(0.3)
+
+ Output (10 classes)
+ ```
+
+ ## 🚀 Quick Start
+
+ ### 1. Install Dependencies
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Train the Model
+
+ ```bash
+ python train.py
+ ```
+
+ This will:
+ - Download the CIFAR-10 dataset automatically
+ - Train the model for 5 epochs (set by `EPOCHS` in `config.py`)
+ - Save checkpoints in `./checkpoints/`
+ - Generate training plots in `./plots/`
+
+ ### 3. Evaluate the Model
+
+ ```bash
+ python evaluate.py
+ ```
+
+ This will:
+ - Load the best model checkpoint
+ - Evaluate on the test set
+ - Generate a confusion matrix
+ - Create prediction visualizations
+
+ ### 4. Run the Web Application
+
+ ```bash
+ python app.py
+ ```
+
+ Then open your browser and navigate to `http://localhost:5000`.
+
+ ## 📁 Project Structure
+
+ ```
+ CNN/
+ ├── config.py          # Configuration and hyperparameters
+ ├── data_loader.py     # Data loading and preprocessing
+ ├── model.py           # RNN model architecture
+ ├── train.py           # Training script
+ ├── evaluate.py        # Evaluation script
+ ├── utils.py           # Utility functions
+ ├── app.py             # Flask web application
+ ├── requirements.txt   # Python dependencies
+ ├── templates/
+ │   └── index.html     # Web interface HTML
+ ├── static/
+ │   ├── style.css      # Web interface CSS
+ │   └── script.js      # Web interface JavaScript
+ ├── checkpoints/       # Model checkpoints (created during training)
+ ├── plots/             # Training visualizations (created during training)
+ └── data/              # CIFAR-10 dataset (downloaded automatically)
+ ```
+
+ ## 🎯 CIFAR-10 Classes
+
+ The model classifies images into 10 categories:
+ 1. Airplane
+ 2. Automobile
+ 3. Bird
+ 4. Cat
+ 5. Deer
+ 6. Dog
+ 7. Frog
+ 8. Horse
+ 9. Ship
+ 10. Truck
+
+ ## ⚙️ Configuration
+
+ Edit `config.py` to customize:
+ - **Training**: epochs, batch size, learning rate
+ - **Model**: number of classes, architecture parameters
+ - **Data**: augmentation settings, normalization values
+ - **Paths**: checkpoint and plot directories
+
+ ## 📈 Training Details
+
+ - **Optimizer**: SGD with momentum (0.9) and weight decay (5e-4)
+ - **Loss Function**: Cross-entropy loss
+ - **Learning Rate**: 0.01 with step decay
+ - **Batch Size**: 32
+ - **Data Augmentation**: Random crop and horizontal flip
+ - **Regularization**: Dropout and weight decay
+
+ ## 🎨 Web Interface Features
+
+ - **Drag & Drop**: Upload images via drag-and-drop
+ - **Random Samples**: Test with random CIFAR-10 images
+ - **Real-time Classification**: Instant predictions with confidence scores
+ - **Top-5 Predictions**: View the probability distribution
+ - **Modern UI**: Dark theme with smooth animations
+
+ ## 📊 Expected Performance
+
+ With the default configuration, the model typically achieves:
+ - **Training Accuracy**: ~90%
+ - **Validation Accuracy**: ~85%
+
+ ## 🛠️ Requirements
+
+ - Python 3.7+
+ - PyTorch 2.0+
+ - torchvision
+ - Flask
+ - NumPy
+ - Matplotlib
+ - scikit-learn
+ - Pillow
+ - tqdm
+
+ ## 📝 License
+
+ This project is open source and available for educational purposes.
+
+ ## 🤝 Contributing
+
+ Feel free to fork this project and submit pull requests for improvements!
+
+ ## 📧 Contact
+
+ For questions or feedback, please open an issue on the repository.
+
+ ---
+
+ **Built with ❤️ using PyTorch and Flask**
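The row-to-sequence reshape described in the README's architecture section can be sketched in NumPy (illustrative only; the committed model does the equivalent with `torch.permute`/`view` on batched tensors):

```python
import numpy as np

# One CHW image; values chosen so every (channel, row, col) cell is unique.
x = np.arange(3 * 32 * 32, dtype=np.float32).reshape(3, 32, 32)

# (3, 32, 32) -> (32, 3, 32) -> (32, 96): each image row becomes one
# 96-feature time step (channels concatenated within the row).
seq = x.transpose(1, 0, 2).reshape(32, 96)

# Feature index c*32 + col of step r holds pixel (c, r, col) of the image.
assert seq.shape == (32, 96)
assert seq[5, 1 * 32 + 7] == x[1, 5, 7]
```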
app.py ADDED
@@ -0,0 +1,166 @@
+ """
+ Flask web application for CIFAR-10 image classification
+ """
+ import os
+ import io
+ import base64
+ import torch
+ from PIL import Image
+ from flask import Flask, render_template, request, jsonify
+ import torchvision.transforms as transforms
+ import numpy as np
+
+ import config
+ from model import get_model
+ from utils import load_checkpoint
+
+ app = Flask(__name__)
+
+ # Global model variable
+ model = None
+
+
+ def load_model():
+     """Load the trained model"""
+     global model
+
+     if not os.path.exists(config.BEST_MODEL_PATH):
+         print(f"Warning: Model checkpoint not found at {config.BEST_MODEL_PATH}")
+         return False
+
+     model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
+     epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
+     model.eval()
+
+     print(f"Model loaded from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
+     return True
+
+
+ def preprocess_image(image):
+     """
+     Preprocess image for model prediction
+
+     Args:
+         image: PIL Image
+
+     Returns:
+         torch.Tensor: Preprocessed image tensor
+     """
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.4914, 0.4822, 0.4465],
+             std=[0.2470, 0.2435, 0.2616]
+         )
+     ])
+
+     return transform(image).unsqueeze(0)
+
+
+ @app.route('/')
+ def index():
+     """Render the main page"""
+     return render_template('index.html', class_names=config.CLASS_NAMES)
+
+
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     """Handle prediction requests"""
+     if model is None:
+         return jsonify({'error': 'Model not loaded'}), 500
+
+     try:
+         # Get image from request
+         if 'file' not in request.files:
+             return jsonify({'error': 'No file provided'}), 400
+
+         file = request.files['file']
+
+         if file.filename == '':
+             return jsonify({'error': 'No file selected'}), 400
+
+         # Read and preprocess image
+         image = Image.open(file.stream).convert('RGB')
+         input_tensor = preprocess_image(image).to(config.DEVICE)
+
+         # Make prediction
+         with torch.no_grad():
+             output = model(input_tensor)
+             probabilities = torch.nn.functional.softmax(output[0], dim=0)
+             confidence, predicted = torch.max(probabilities, 0)
+
+         # Get top 5 predictions
+         top5_prob, top5_idx = torch.topk(probabilities, 5)
+         top5_predictions = [
+             {
+                 'class': config.CLASS_NAMES[idx],
+                 'probability': float(prob * 100)
+             }
+             for idx, prob in zip(top5_idx.cpu().numpy(), top5_prob.cpu().numpy())
+         ]
+
+         # Prepare response
+         response = {
+             'predicted_class': config.CLASS_NAMES[predicted.item()],
+             'confidence': float(confidence.item() * 100),
+             'top5_predictions': top5_predictions
+         }
+
+         return jsonify(response)
+
+     except Exception as e:
+         return jsonify({'error': str(e)}), 500
+
+
+ @app.route('/random_sample', methods=['GET'])
+ def random_sample():
+     """Get a random sample from CIFAR-10 test set or generate dummy if missing"""
+     try:
+         from data_loader import get_data_loaders
+         # Check if dataset exists
+         dataset_path = os.path.join(config.DATA_DIR, 'cifar-10-batches-py')
+
+         if os.path.exists(dataset_path):
+             _, test_loader = get_data_loaders()
+             dataset = test_loader.dataset
+             idx = np.random.randint(0, len(dataset))
+             image, label = dataset[idx]
+
+             # Denormalize image
+             mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
+             std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
+             image_denorm = image * std + mean
+             image_denorm = torch.clamp(image_denorm, 0, 1)
+
+             # Convert to PIL Image
+             image_np = (image_denorm.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+             label_name = config.CLASS_NAMES[label]
+         else:
+             # Generate dummy image for demonstration
+             image_np = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
+             label_name = "Dummy Sample (Dataset still downloading)"
+
+         pil_image = Image.fromarray(image_np)
+
+         # Convert to base64
+         buffered = io.BytesIO()
+         pil_image.save(buffered, format="PNG")
+         img_str = base64.b64encode(buffered.getvalue()).decode()
+
+         return jsonify({
+             'image': f'data:image/png;base64,{img_str}',
+             'true_label': label_name
+         })
+
+     except Exception as e:
+         return jsonify({'error': str(e)}), 500
+
+
+ if __name__ == '__main__':
+     # Load model
+     if load_model():
+         print("Starting Flask application...")
+         app.run(debug=True, host='0.0.0.0', port=5000)
+     else:
+         print("Failed to load model. Please train the model first using train.py")
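`predict()` above turns raw logits into percentages with softmax and picks the top 5 with `torch.topk`; the same arithmetic in plain NumPy, with made-up toy logits purely for illustration:

```python
import numpy as np

# Hypothetical logits for the 10 CIFAR-10 classes.
logits = np.array([2.0, 1.0, 0.5, -1.0, 0.0, -2.0, 0.2, 1.5, -0.5, 0.1])

# Softmax, with the max subtracted for numerical stability.
exp = np.exp(logits - logits.max())
probs = exp / exp.sum()

# Top-5 class indices, highest probability first (mirrors torch.topk).
top5_idx = np.argsort(probs)[::-1][:5]

assert np.isclose(probs.sum(), 1.0)
assert top5_idx[0] == 0               # index of the largest logit
assert probs[top5_idx[0]] == probs.max()
```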
config.py ADDED
@@ -0,0 +1,50 @@
+ """
+ Configuration file for the CIFAR-10 RNN project
+ """
+ import torch
+
+ # Data configuration
+ DATA_DIR = './data'
+ BATCH_SIZE = 32
+ NUM_WORKERS = 0  # Set to 0 for better stability on some systems without GPU
+
+ # Model configuration
+ NUM_CLASSES = 10
+ INPUT_CHANNELS = 3
+ IMAGE_SIZE = 32
+ HIDDEN_SIZE = 256
+ NUM_LAYERS = 2
+ RNN_DROPOUT = 0.2
+
+ # Training configuration
+ EPOCHS = 5
+ LEARNING_RATE = 0.01  # Increased slightly for faster convergence in few epochs
+ WEIGHT_DECAY = 5e-4
+ MOMENTUM = 0.9
+
+ # Device configuration
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Checkpoint configuration
+ CHECKPOINT_DIR = './checkpoints'
+ BEST_MODEL_PATH = './checkpoints/best_model.pth'
+ LAST_MODEL_PATH = './checkpoints/last_model.pth'
+
+ # Visualization configuration
+ PLOTS_DIR = './plots'
+
+ # CIFAR-10 class names
+ CLASS_NAMES = [
+     'airplane', 'automobile', 'bird', 'cat', 'deer',
+     'dog', 'frog', 'horse', 'ship', 'truck'
+ ]
+
+ # Data augmentation settings
+ USE_AUGMENTATION = True
+ RANDOM_CROP_PADDING = 4
+ RANDOM_HORIZONTAL_FLIP = 0.5
+
+ # Learning rate scheduler
+ USE_SCHEDULER = True
+ SCHEDULER_STEP_SIZE = 20
+ SCHEDULER_GAMMA = 0.1
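With `SCHEDULER_STEP_SIZE = 20` and `SCHEDULER_GAMMA = 0.1`, PyTorch's `StepLR` multiplies the learning rate by 0.1 every 20 epochs. A quick sketch of the resulting schedule (note that the 5-epoch default never reaches the first decay):

```python
import math

def stepped_lr(epoch, base_lr=0.01, step_size=20, gamma=0.1):
    """Learning rate that StepLR would yield at a given epoch."""
    return base_lr * gamma ** (epoch // step_size)

assert math.isclose(stepped_lr(0), 0.01)    # epochs 0-19
assert math.isclose(stepped_lr(19), 0.01)
assert math.isclose(stepped_lr(20), 0.001)  # first decay
assert math.isclose(stepped_lr(40), 0.0001)
```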
data_loader.py ADDED
@@ -0,0 +1,100 @@
+ """
+ Data loading and preprocessing for CIFAR-10 dataset
+ """
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ import config
+
+
+ def get_transforms(train=True):
+     """
+     Get data transformations for training or testing
+
+     Args:
+         train (bool): If True, returns training transforms with augmentation
+
+     Returns:
+         torchvision.transforms.Compose: Composed transforms
+     """
+     if train and config.USE_AUGMENTATION:
+         transform = transforms.Compose([
+             transforms.RandomCrop(32, padding=config.RANDOM_CROP_PADDING),
+             transforms.RandomHorizontalFlip(p=config.RANDOM_HORIZONTAL_FLIP),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.4914, 0.4822, 0.4465],
+                 std=[0.2470, 0.2435, 0.2616]
+             )
+         ])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.4914, 0.4822, 0.4465],
+                 std=[0.2470, 0.2435, 0.2616]
+             )
+         ])
+
+     return transform
+
+
+ def get_data_loaders():
+     """
+     Create train and test data loaders for CIFAR-10
+
+     Returns:
+         tuple: (train_loader, test_loader)
+     """
+     # Get transforms
+     train_transform = get_transforms(train=True)
+     test_transform = get_transforms(train=False)
+
+     # Load datasets
+     train_dataset = datasets.CIFAR10(
+         root=config.DATA_DIR,
+         train=True,
+         download=True,
+         transform=train_transform
+     )
+
+     test_dataset = datasets.CIFAR10(
+         root=config.DATA_DIR,
+         train=False,
+         download=True,
+         transform=test_transform
+     )
+
+     # Create data loaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.BATCH_SIZE,
+         shuffle=True,
+         num_workers=config.NUM_WORKERS,
+         pin_memory=True if config.DEVICE.type == 'cuda' else False
+     )
+
+     test_loader = DataLoader(
+         test_dataset,
+         batch_size=config.BATCH_SIZE,
+         shuffle=False,
+         num_workers=config.NUM_WORKERS,
+         pin_memory=True if config.DEVICE.type == 'cuda' else False
+     )
+
+     return train_loader, test_loader
+
+
+ def denormalize(tensor):
+     """
+     Denormalize a tensor image for visualization
+
+     Args:
+         tensor: Normalized tensor image
+
+     Returns:
+         tensor: Denormalized tensor image
+     """
+     mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
+     std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
+     return tensor * std + mean
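`denormalize()` inverts the per-channel `transforms.Normalize` used above; the round trip can be checked with a NumPy stand-in for the tensor version (same means/stds, fake image data):

```python
import numpy as np

mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
std = np.array([0.2470, 0.2435, 0.2616]).reshape(3, 1, 1)

rng = np.random.default_rng(0)
img = rng.random((3, 32, 32))          # a fake CHW image in [0, 1]

normalized = (img - mean) / std        # what transforms.Normalize does
recovered = normalized * std + mean    # what denormalize() does

assert np.allclose(recovered, img)
```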
debug_train.py ADDED
@@ -0,0 +1,88 @@
+ """
+ Debug training script using dummy data to test the pipeline without downloading CIFAR-10
+ """
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+ from tqdm import tqdm
+
+ import config
+ from model import get_model, count_parameters
+ from utils import save_checkpoint, plot_training_history
+
+
+ def get_dummy_data_loaders():
+     """Create dummy data loaders for testing"""
+     # Create random images (32x32) and labels (0-9)
+     train_size = 100
+     test_size = 20
+
+     train_images = torch.randn(train_size, 3, 32, 32)
+     train_labels = torch.randint(0, 10, (train_size,))
+
+     test_images = torch.randn(test_size, 3, 32, 32)
+     test_labels = torch.randint(0, 10, (test_size,))
+
+     train_dataset = TensorDataset(train_images, train_labels)
+     test_dataset = TensorDataset(test_images, test_labels)
+
+     train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
+     test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
+
+     return train_loader, test_loader
+
+
+ def debug_train():
+     """Debug training function"""
+     os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
+     os.makedirs(config.PLOTS_DIR, exist_ok=True)
+
+     print("Creating dummy data loaders...")
+     train_loader, test_loader = get_dummy_data_loaders()
+
+     print(f"Creating model on {config.DEVICE}...")
+     model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
+
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+     history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
+
+     print("Starting debug training for 2 epochs...")
+     for epoch in range(2):
+         model.train()
+         running_loss = 0.0
+         correct = 0
+         total = 0
+
+         for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
+             inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+             _, predicted = outputs.max(1)
+             total += labels.size(0)
+             correct += predicted.eq(labels).sum().item()
+
+         train_loss = running_loss / len(train_loader)
+         train_acc = 100. * correct / total
+
+         history['train_loss'].append(train_loss)
+         history['train_acc'].append(train_acc)
+         history['val_loss'].append(train_loss)  # Just use train loss for dummy validation
+         history['val_acc'].append(train_acc)
+
+         print(f"Epoch {epoch+1}: Loss {train_loss:.4f}, Acc {train_acc:.2f}%")
+
+     # Save "best" model for app testing
+     save_checkpoint(model, optimizer, epoch, train_acc, config.BEST_MODEL_PATH)
+     plot_training_history(history, config.PLOTS_DIR)
+
+     print("\nDebug training complete. 'best_model.pth' created for testing the web app.")
+
+
+ if __name__ == "__main__":
+     debug_train()
evaluate.py ADDED
@@ -0,0 +1,96 @@
+ """
+ Evaluation script for the CIFAR-10 RNN
+ """
+ import os
+ import torch
+ from tqdm import tqdm
+
+ import config
+ from model import get_model
+ from data_loader import get_data_loaders
+ from utils import (
+     load_checkpoint, plot_confusion_matrix,
+     print_classification_report, visualize_predictions
+ )
+
+
+ def evaluate():
+     """
+     Evaluate the trained model
+     """
+     # Create plots directory
+     os.makedirs(config.PLOTS_DIR, exist_ok=True)
+
+     # Get data loaders
+     print("Loading CIFAR-10 dataset...")
+     _, test_loader = get_data_loaders()
+     print(f"Test samples: {len(test_loader.dataset)}")
+
+     # Create model
+     print(f"\nLoading model from {config.BEST_MODEL_PATH}")
+     model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
+
+     # Load checkpoint
+     if not os.path.exists(config.BEST_MODEL_PATH):
+         print(f"Error: Model checkpoint not found at {config.BEST_MODEL_PATH}")
+         print("Please train the model first using train.py")
+         return
+
+     epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
+     print(f"Loaded model from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
+
+     # Evaluate
+     model.eval()
+     correct = 0
+     total = 0
+     all_predictions = []
+     all_labels = []
+
+     print("\nEvaluating model...")
+     with torch.no_grad():
+         pbar = tqdm(test_loader, desc='Evaluating')
+         for inputs, labels in pbar:
+             inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
+
+             # Forward pass
+             outputs = model(inputs)
+             _, predicted = outputs.max(1)
+
+             # Statistics
+             total += labels.size(0)
+             correct += predicted.eq(labels).sum().item()
+
+             # Store predictions and labels
+             all_predictions.extend(predicted.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+
+             # Update progress bar
+             pbar.set_postfix({'acc': f'{100. * correct / total:.2f}%'})
+
+     # Calculate final accuracy
+     final_accuracy = 100. * correct / total
+
+     # Print results
+     print("\n" + "=" * 80)
+     print(f"Final Test Accuracy: {final_accuracy:.2f}%")
+     print(f"Correct predictions: {correct}/{total}")
+     print("=" * 80)
+
+     # Print classification report
+     print_classification_report(all_labels, all_predictions)
+
+     # Plot confusion matrix
+     print("\nGenerating confusion matrix...")
+     cm_path = os.path.join(config.PLOTS_DIR, 'confusion_matrix.png')
+     plot_confusion_matrix(all_labels, all_predictions, cm_path)
+     print(f"Confusion matrix saved to {cm_path}")
+
+     # Visualize predictions
+     print("\nGenerating prediction visualizations...")
+     visualize_predictions(model, test_loader, config.DEVICE, num_images=16)
+
+     print("\nEvaluation completed!")
+
+
+ if __name__ == '__main__':
+     evaluate()
model.py ADDED
@@ -0,0 +1,93 @@
+ """
+ RNN Model Architecture for CIFAR-10 Classification
+ """
+ import torch
+ import torch.nn as nn
+ import config
+
+
+ class CIFAR10RNN(nn.Module):
+     """
+     Recurrent Neural Network (LSTM) for CIFAR-10 classification
+
+     Architecture:
+     - Input sequence: 32 rows of 32x3 pixels (= 96 features per step)
+     - Bidirectional LSTM layers
+     - Fully connected layer for classification
+     """
+
+     def __init__(self, input_size=96, hidden_size=256, num_layers=2, num_classes=10):
+         super(CIFAR10RNN, self).__init__()
+
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+
+         # LSTM layer
+         # batch_first=True means input shape is (batch, seq, feature)
+         self.lstm = nn.LSTM(
+             input_size,
+             hidden_size,
+             num_layers,
+             batch_first=True,
+             bidirectional=True,
+             dropout=config.RNN_DROPOUT if num_layers > 1 else 0
+         )
+
+         # Fully connected layer
+         # * 2 because of bidirectional
+         self.fc = nn.Sequential(
+             nn.Linear(hidden_size * 2, 512),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(512, num_classes)
+         )
+
+     def forward(self, x):
+         # x shape: (batch, 3, 32, 32)
+         # Convert to: (batch, seq_len=32, input_size=96)
+         batch_size = x.size(0)
+
+         # Rearrange image rows into a sequence
+         # (batch, 3, 32, 32) -> (batch, 32, 3, 32) -> (batch, 32, 96)
+         x = x.permute(0, 2, 1, 3).contiguous()
+         x = x.view(batch_size, 32, -1)
+
+         # LSTM forward pass
+         # out: tensor of shape (batch, seq_len, hidden_size * 2)
+         out, _ = self.lstm(x)
+
+         # Take the output of the last time step
+         out = out[:, -1, :]
+
+         # Classification
+         out = self.fc(out)
+
+         return out
+
+
+ def get_model(num_classes=10, device='cpu'):
+     """
+     Create and return the RNN model
+
+     Args:
+         num_classes (int): Number of output classes
+         device (str or torch.device): Device to load the model on
+
+     Returns:
+         CIFAR10RNN: The RNN model
+     """
+     model = CIFAR10RNN(
+         input_size=32*3,
+         hidden_size=config.HIDDEN_SIZE,
+         num_layers=config.NUM_LAYERS,
+         num_classes=num_classes
+     )
+     model = model.to(device)
+     return model
+
+
+ def count_parameters(model):
+     """
+     Count the number of trainable parameters in the model
+     """
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
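`count_parameters()` should agree with a hand count: each `nn.LSTM` layer/direction holds `weight_ih` (4H×I), `weight_hh` (4H×H), and two 4H biases, and layers after the first see 2H inputs because the LSTM is bidirectional. A sketch for the configured sizes (H=256, 2 layers, 96 input features; the FC head is 512→512→10):

```python
def lstm_params(input_size, hidden, layers, bidirectional=True):
    """Parameter count of nn.LSTM: per layer and direction,
    weight_ih (4H x I) + weight_hh (4H x H) + biases b_ih, b_hh (4H each)."""
    dirs = 2 if bidirectional else 1
    total = 0
    for layer in range(layers):
        i = input_size if layer == 0 else hidden * dirs
        total += dirs * 4 * (hidden * i + hidden * hidden + 2 * hidden)
    return total

H = 256
lstm = lstm_params(96, H, 2)
# Linear layers of the FC head, including biases: (2H -> 512) and (512 -> 10).
fc = (2 * H * 512 + 512) + (512 * 10 + 10)

assert lstm == 2_301_952
assert lstm + fc == 2_569_738   # total trainable parameters
```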
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ pillow>=9.5.0
+ flask>=2.3.0
+ tqdm>=4.65.0
+ scikit-learn>=1.3.0
static/script.js ADDED
@@ -0,0 +1,163 @@
+ // CIFAR-10 Classifier JavaScript
+
+ let selectedFile = null;
+
+ // DOM Elements
+ const uploadArea = document.getElementById('uploadArea');
+ const fileInput = document.getElementById('fileInput');
+ const previewCard = document.getElementById('previewCard');
+ const previewImage = document.getElementById('previewImage');
+ const classifyBtn = document.getElementById('classifyBtn');
+ const randomBtn = document.getElementById('randomBtn');
+ const resultsSection = document.getElementById('resultsSection');
+ const loadingOverlay = document.getElementById('loadingOverlay');
+
+ // Upload area click handler
+ uploadArea.addEventListener('click', () => {
+     fileInput.click();
+ });
+
+ // File input change handler
+ fileInput.addEventListener('change', (e) => {
+     const file = e.target.files[0];
+     if (file) {
+         handleFile(file);
+     }
+ });
+
+ // Drag and drop handlers
+ uploadArea.addEventListener('dragover', (e) => {
+     e.preventDefault();
+     uploadArea.classList.add('drag-over');
+ });
+
+ uploadArea.addEventListener('dragleave', () => {
+     uploadArea.classList.remove('drag-over');
+ });
+
+ uploadArea.addEventListener('drop', (e) => {
+     e.preventDefault();
+     uploadArea.classList.remove('drag-over');
+
+     const file = e.dataTransfer.files[0];
+     if (file && file.type.startsWith('image/')) {
+         handleFile(file);
+     }
+ });
+
+ // Handle file selection
+ function handleFile(file) {
+     selectedFile = file;
+
+     const reader = new FileReader();
+     reader.onload = (e) => {
+         previewImage.src = e.target.result;
+         previewCard.style.display = 'block';
+         resultsSection.style.display = 'none';
+     };
+     reader.readAsDataURL(file);
+ }
+
+ // Classify button handler
+ classifyBtn.addEventListener('click', async () => {
+     if (!selectedFile) return;
+
+     const formData = new FormData();
+     formData.append('file', selectedFile);
+
+     try {
+         loadingOverlay.style.display = 'flex';
+
+         const response = await fetch('/predict', {
+             method: 'POST',
+             body: formData
+         });
+
+         const data = await response.json();
+
+         if (data.error) {
+             alert('Error: ' + data.error);
+             return;
+         }
+
+         displayResults(data);
+
+     } catch (error) {
+         alert('Error: ' + error.message);
+     } finally {
+         loadingOverlay.style.display = 'none';
+     }
+ });
+
+ // Random sample button handler
+ randomBtn.addEventListener('click', async () => {
+     try {
+         loadingOverlay.style.display = 'flex';
+
+         const response = await fetch('/random_sample');
+         const data = await response.json();
+
+         if (data.error) {
+             alert('Error: ' + data.error);
+             return;
+         }
+
+         // Convert base64 to blob
+         const blob = await fetch(data.image).then(r => r.blob());
+         const file = new File([blob], 'random_sample.png', { type: 'image/png' });
+
+         handleFile(file);
+
+     } catch (error) {
+         alert('Error: ' + error.message);
+     } finally {
+         loadingOverlay.style.display = 'none';
+     }
+ });
+
+ // Display classification results
+ function displayResults(data) {
+     document.getElementById('predictedClass').textContent = data.predicted_class;
+     document.getElementById('confidenceValue').textContent = data.confidence.toFixed(2) + '%';
+
+     // Update confidence badge color based on confidence level
+     const badge = document.getElementById('confidenceBadge');
+     if (data.confidence >= 80) {
+         badge.style.background = 'rgba(79, 172, 254, 0.2)';
+         badge.style.borderColor = 'rgba(79, 172, 254, 0.4)';
+         badge.style.color = '#4facfe';
+     } else if (data.confidence >= 60) {
+         badge.style.background = 'rgba(240, 147, 251, 0.2)';
+         badge.style.borderColor = 'rgba(240, 147, 251, 0.4)';
+         badge.style.color = '#f093fb';
+     } else {
+         badge.style.background = 'rgba(245, 87, 108, 0.2)';
+         badge.style.borderColor = 'rgba(245, 87, 108, 0.4)';
+         badge.style.color = '#f5576c';
+     }
+
+     // Display top 5 predictions
+     const top5Container = document.getElementById('top5Container');
+     top5Container.innerHTML = '';
+
+     data.top5_predictions.forEach((pred, index) => {
+         const item = document.createElement('div');
+         item.className = 'prediction-item';
+         item.style.animationDelay = `${index * 0.1}s`;
+
+         item.innerHTML = `
+             <span class="prediction-item-name">${pred.class}</span>
+             <div class="prediction-item-bar">
+                 <div class="prediction-item-fill" style="width: ${pred.probability}%"></div>
+             </div>
+             <span class="prediction-item-value">${pred.probability.toFixed(2)}%</span>
+         `;
+
+         top5Container.appendChild(item);
+     });
+
+     resultsSection.style.display = 'grid';
+
+     // Scroll to results
+     resultsSection.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
+ }
static/style.css ADDED
@@ -0,0 +1,391 @@
+/* Modern CSS for CIFAR-10 Classifier */
+
+:root {
+    --primary-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    --secondary-gradient: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+    --success-gradient: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
+    --bg-primary: #0a0e27;
+    --bg-secondary: #151932;
+    --bg-card: rgba(255, 255, 255, 0.05);
+    --bg-card-hover: rgba(255, 255, 255, 0.08);
+    --text-primary: #ffffff;
+    --text-secondary: #a0aec0;
+    --text-muted: #718096;
+    --border-color: rgba(255, 255, 255, 0.1);
+    --shadow-lg: 0 8px 32px rgba(0, 0, 0, 0.3);
+    --spacing-xs: 0.5rem;
+    --spacing-sm: 1rem;
+    --spacing-md: 1.5rem;
+    --spacing-lg: 2rem;
+    --spacing-xl: 3rem;
+    --radius-sm: 8px;
+    --radius-md: 12px;
+    --radius-lg: 16px;
+    --radius-xl: 24px;
+}
+
+* {
+    margin: 0;
+    padding: 0;
+    box-sizing: border-box;
+}
+
+body {
+    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
+    background: var(--bg-primary);
+    color: var(--text-primary);
+    line-height: 1.6;
+    min-height: 100vh;
+}
+
+body::before {
+    content: '';
+    position: fixed;
+    top: 0;
+    left: 0;
+    width: 100%;
+    height: 100%;
+    background: radial-gradient(circle at 20% 50%, rgba(102, 126, 234, 0.1) 0%, transparent 50%),
+                radial-gradient(circle at 80% 80%, rgba(118, 75, 162, 0.1) 0%, transparent 50%);
+    z-index: -1;
+}
+
+.container {
+    max-width: 1400px;
+    margin: 0 auto;
+    padding: var(--spacing-lg);
+}
+
+.header {
+    text-align: center;
+    margin-bottom: var(--spacing-xl);
+    padding: var(--spacing-xl) 0;
+}
+
+.title {
+    font-size: 3.5rem;
+    font-weight: 700;
+    margin-bottom: var(--spacing-sm);
+}
+
+.gradient-text {
+    background: var(--primary-gradient);
+    -webkit-background-clip: text;
+    -webkit-text-fill-color: transparent;
+    background-clip: text;
+}
+
+.subtitle {
+    font-size: 1.25rem;
+    color: var(--text-secondary);
+}
+
+.upload-section {
+    display: grid;
+    grid-template-columns: 1fr 1fr;
+    gap: var(--spacing-lg);
+    margin-bottom: var(--spacing-xl);
+}
+
+.card {
+    background: var(--bg-card);
+    backdrop-filter: blur(10px);
+    border: 1px solid var(--border-color);
+    border-radius: var(--radius-lg);
+    padding: var(--spacing-lg);
+    transition: all 0.3s ease;
+}
+
+.card:hover {
+    background: var(--bg-card-hover);
+    transform: translateY(-2px);
+    box-shadow: var(--shadow-lg);
+}
+
+.card-title {
+    font-size: 1.5rem;
+    font-weight: 600;
+    margin-bottom: var(--spacing-md);
+}
+
+.upload-area {
+    border: 2px dashed var(--border-color);
+    border-radius: var(--radius-md);
+    padding: var(--spacing-xl);
+    text-align: center;
+    cursor: pointer;
+    transition: all 0.3s ease;
+    margin-bottom: var(--spacing-md);
+}
+
+.upload-area:hover {
+    border-color: #667eea;
+    background: rgba(102, 126, 234, 0.05);
+}
+
+.upload-icon {
+    width: 64px;
+    height: 64px;
+    margin: 0 auto var(--spacing-md);
+    color: #667eea;
+}
+
+.upload-text {
+    font-size: 1.125rem;
+    font-weight: 500;
+    margin-bottom: var(--spacing-xs);
+}
+
+.upload-subtext {
+    font-size: 0.875rem;
+    color: var(--text-muted);
+}
+
+.image-preview {
+    width: 100%;
+    height: 300px;
+    border-radius: var(--radius-md);
+    overflow: hidden;
+    margin-bottom: var(--spacing-md);
+    background: var(--bg-secondary);
+    display: flex;
+    align-items: center;
+    justify-content: center;
+}
+
+.image-preview img {
+    max-width: 100%;
+    max-height: 100%;
+    object-fit: contain;
+}
+
+.btn {
+    display: inline-flex;
+    align-items: center;
+    justify-content: center;
+    gap: var(--spacing-xs);
+    padding: 0.875rem 1.75rem;
+    font-size: 1rem;
+    font-weight: 600;
+    border: none;
+    border-radius: var(--radius-md);
+    cursor: pointer;
+    transition: all 0.3s ease;
+    width: 100%;
+}
+
+.btn svg {
+    width: 20px;
+    height: 20px;
+}
+
+.btn-primary {
+    background: var(--primary-gradient);
+    color: white;
+}
+
+.btn-primary:hover {
+    transform: translateY(-2px);
+    box-shadow: 0 8px 24px rgba(102, 126, 234, 0.4);
+}
+
+.btn-secondary {
+    background: var(--bg-secondary);
+    color: var(--text-primary);
+    border: 1px solid var(--border-color);
+}
+
+.btn-secondary:hover {
+    background: var(--bg-card);
+    border-color: #667eea;
+}
+
+.results-section {
+    display: grid;
+    grid-template-columns: 2fr 1fr;
+    gap: var(--spacing-lg);
+}
+
+.prediction-main {
+    text-align: center;
+    padding: var(--spacing-lg);
+    background: var(--bg-secondary);
+    border-radius: var(--radius-md);
+    margin-bottom: var(--spacing-lg);
+}
+
+.prediction-label {
+    font-size: 0.875rem;
+    text-transform: uppercase;
+    letter-spacing: 0.1em;
+    color: var(--text-muted);
+    margin-bottom: var(--spacing-sm);
+}
+
+.prediction-class {
+    font-size: 2.5rem;
+    font-weight: 700;
+    background: var(--success-gradient);
+    -webkit-background-clip: text;
+    -webkit-text-fill-color: transparent;
+    background-clip: text;
+    margin-bottom: var(--spacing-md);
+}
+
+.confidence-badge {
+    display: inline-block;
+    padding: var(--spacing-xs) var(--spacing-md);
+    background: rgba(79, 172, 254, 0.2);
+    border: 1px solid rgba(79, 172, 254, 0.4);
+    border-radius: var(--radius-xl);
+    font-size: 0.875rem;
+    font-weight: 600;
+    color: #4facfe;
+}
+
+.predictions-title {
+    font-size: 1.125rem;
+    font-weight: 600;
+    margin-bottom: var(--spacing-md);
+    color: var(--text-secondary);
+}
+
+.prediction-item {
+    display: flex;
+    align-items: center;
+    justify-content: space-between;
+    padding: var(--spacing-sm);
+    background: var(--bg-secondary);
+    border-radius: var(--radius-sm);
+    margin-bottom: var(--spacing-sm);
+    transition: all 0.3s ease;
+}
+
+.prediction-item:hover {
+    background: var(--bg-card);
+    transform: translateX(4px);
+}
+
+.prediction-item-name {
+    font-weight: 500;
+}
+
+.prediction-item-bar {
+    flex: 1;
+    height: 8px;
+    background: var(--bg-primary);
+    border-radius: var(--radius-xl);
+    margin: 0 var(--spacing-md);
+    overflow: hidden;
+}
+
+.prediction-item-fill {
+    height: 100%;
+    background: var(--primary-gradient);
+    border-radius: var(--radius-xl);
+    transition: width 0.8s ease;
+}
+
+.prediction-item-value {
+    font-weight: 600;
+    color: var(--text-secondary);
+    min-width: 50px;
+    text-align: right;
+}
+
+.info-title {
+    font-size: 1.125rem;
+    font-weight: 600;
+    margin-bottom: var(--spacing-md);
+    color: var(--text-secondary);
+}
+
+.classes-grid {
+    display: grid;
+    grid-template-columns: 1fr 1fr;
+    gap: var(--spacing-sm);
+}
+
+.class-item {
+    display: flex;
+    align-items: center;
+    gap: var(--spacing-sm);
+    padding: var(--spacing-sm);
+    background: var(--bg-secondary);
+    border-radius: var(--radius-sm);
+    transition: all 0.3s ease;
+}
+
+.class-item:hover {
+    background: var(--bg-card);
+    transform: translateX(4px);
+}
+
+.class-icon {
+    width: 32px;
+    height: 32px;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    background: var(--primary-gradient);
+    border-radius: var(--radius-sm);
+    font-weight: 600;
+    font-size: 0.875rem;
+}
+
+.class-name {
+    font-weight: 500;
+    text-transform: capitalize;
+}
+
+.loading-overlay {
+    position: fixed;
+    top: 0;
+    left: 0;
+    width: 100%;
+    height: 100%;
+    background: rgba(10, 14, 39, 0.9);
+    backdrop-filter: blur(8px);
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    justify-content: center;
+    z-index: 1000;
+}
+
+.spinner {
+    width: 64px;
+    height: 64px;
+    border: 4px solid var(--border-color);
+    border-top-color: #667eea;
+    border-radius: 50%;
+    animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+    to { transform: rotate(360deg); }
+}
+
+.loading-text {
+    margin-top: var(--spacing-md);
+    font-size: 1.125rem;
+    color: var(--text-secondary);
+}
+
+.footer {
+    text-align: center;
+    padding: var(--spacing-xl) 0;
+    color: var(--text-muted);
+    font-size: 0.875rem;
+    border-top: 1px solid var(--border-color);
+    margin-top: var(--spacing-xl);
+}
+
+@media (max-width: 1024px) {
+    .upload-section, .results-section {
+        grid-template-columns: 1fr;
+    }
+    .title {
+        font-size: 2.5rem;
+    }
+}
templates/index.html ADDED
@@ -0,0 +1,103 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>CIFAR-10 Image Classifier</title>
+    <meta name="description" content="Deep learning powered CIFAR-10 image classification using Convolutional Neural Networks">
+    <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
+    <link rel="preconnect" href="https://fonts.googleapis.com">
+    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
+</head>
+<body>
+    <div class="container">
+        <header class="header">
+            <div class="header-content">
+                <h1 class="title">
+                    <span class="gradient-text">CIFAR-10</span> Image Classifier
+                </h1>
+                <p class="subtitle">Powered by Deep Learning & Convolutional Neural Networks</p>
+            </div>
+        </header>
+
+        <main class="main-content">
+            <div class="upload-section">
+                <div class="card upload-card">
+                    <h2 class="card-title">Upload Image</h2>
+                    <div class="upload-area" id="uploadArea">
+                        <svg class="upload-icon" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12" />
+                        </svg>
+                        <p class="upload-text">Drag & drop an image here</p>
+                        <p class="upload-subtext">or click to browse</p>
+                        <input type="file" id="fileInput" accept="image/*" hidden>
+                    </div>
+
+                    <button class="btn btn-secondary" id="randomBtn">
+                        <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
+                        </svg>
+                        Try Random Sample
+                    </button>
+                </div>
+
+                <div class="card preview-card" id="previewCard" style="display: none;">
+                    <h2 class="card-title">Preview</h2>
+                    <div class="image-preview">
+                        <img id="previewImage" src="" alt="Preview">
+                    </div>
+                    <button class="btn btn-primary" id="classifyBtn">
+                        <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
+                        </svg>
+                        Classify Image
+                    </button>
+                </div>
+            </div>
+
+            <div class="results-section" id="resultsSection" style="display: none;">
+                <div class="card results-card">
+                    <h2 class="card-title">Classification Results</h2>
+
+                    <div class="prediction-main">
+                        <div class="prediction-label">Predicted Class</div>
+                        <div class="prediction-class" id="predictedClass">-</div>
+                        <div class="confidence-badge" id="confidenceBadge">
+                            <span id="confidenceValue">0%</span> confidence
+                        </div>
+                    </div>
+
+                    <div class="top-predictions">
+                        <h3 class="predictions-title">Top 5 Predictions</h3>
+                        <div id="top5Container"></div>
+                    </div>
+                </div>
+
+                <div class="card info-card">
+                    <h3 class="info-title">CIFAR-10 Classes</h3>
+                    <div class="classes-grid">
+                        {% for class_name in class_names %}
+                        <div class="class-item">
+                            <div class="class-icon">{{ loop.index0 }}</div>
+                            <div class="class-name">{{ class_name }}</div>
+                        </div>
+                        {% endfor %}
+                    </div>
+                </div>
+            </div>
+
+            <div class="loading-overlay" id="loadingOverlay" style="display: none;">
+                <div class="spinner"></div>
+                <p class="loading-text">Classifying image...</p>
+            </div>
+        </main>
+
+        <footer class="footer">
+            <p>Built with PyTorch & Flask | CNN Architecture with 3 Convolutional Blocks</p>
+        </footer>
+    </div>
+
+    <script src="{{ url_for('static', filename='script.js') }}"></script>
+</body>
+</html>
train.py ADDED
@@ -0,0 +1,220 @@
+"""
+Training script for CIFAR-10 CNN
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+import config
+from model import get_model, count_parameters
+from data_loader import get_data_loaders
+from utils import save_checkpoint, load_checkpoint, plot_training_history
+
+
+def train_epoch(model, train_loader, criterion, optimizer, device):
+    """
+    Train the model for one epoch
+
+    Args:
+        model: PyTorch model
+        train_loader: Training data loader
+        criterion: Loss function
+        optimizer: Optimizer
+        device: Device to train on
+
+    Returns:
+        tuple: (average_loss, accuracy)
+    """
+    model.train()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    pbar = tqdm(train_loader, desc='Training')
+    for inputs, labels in pbar:
+        inputs, labels = inputs.to(device), labels.to(device)
+
+        # Zero the parameter gradients
+        optimizer.zero_grad()
+
+        # Forward pass
+        outputs = model(inputs)
+        loss = criterion(outputs, labels)
+
+        # Backward pass and optimize
+        loss.backward()
+        optimizer.step()
+
+        # Statistics
+        running_loss += loss.item()
+        _, predicted = outputs.max(1)
+        total += labels.size(0)
+        correct += predicted.eq(labels).sum().item()
+
+        # Update progress bar
+        pbar.set_postfix({
+            'loss': f'{running_loss / (pbar.n + 1):.4f}',
+            'acc': f'{100. * correct / total:.2f}%'
+        })
+
+    epoch_loss = running_loss / len(train_loader)
+    epoch_acc = 100. * correct / total
+
+    return epoch_loss, epoch_acc
+
+
+def validate(model, test_loader, criterion, device):
+    """
+    Validate the model
+
+    Args:
+        model: PyTorch model
+        test_loader: Test data loader
+        criterion: Loss function
+        device: Device to validate on
+
+    Returns:
+        tuple: (average_loss, accuracy)
+    """
+    model.eval()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        pbar = tqdm(test_loader, desc='Validation')
+        for inputs, labels in pbar:
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+
+            # Statistics
+            running_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += labels.size(0)
+            correct += predicted.eq(labels).sum().item()
+
+            # Update progress bar
+            pbar.set_postfix({
+                'loss': f'{running_loss / (pbar.n + 1):.4f}',
+                'acc': f'{100. * correct / total:.2f}%'
+            })
+
+    epoch_loss = running_loss / len(test_loader)
+    epoch_acc = 100. * correct / total
+
+    return epoch_loss, epoch_acc
+
+
+def train():
+    """
+    Main training function
+    """
+    # Create directories
+    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
+    os.makedirs(config.PLOTS_DIR, exist_ok=True)
+
+    # Get data loaders
+    print("Loading CIFAR-10 dataset...")
+    train_loader, test_loader = get_data_loaders()
+    print(f"Training samples: {len(train_loader.dataset)}")
+    print(f"Test samples: {len(test_loader.dataset)}")
+
+    # Create model
+    print(f"\nCreating model on device: {config.DEVICE}")
+    model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
+    print(f"Model parameters: {count_parameters(model):,}")
+
+    # Loss function and optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(),
+        lr=config.LEARNING_RATE,
+        momentum=config.MOMENTUM,
+        weight_decay=config.WEIGHT_DECAY
+    )
+
+    # Learning rate scheduler
+    scheduler = None
+    if config.USE_SCHEDULER:
+        scheduler = optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=config.SCHEDULER_STEP_SIZE,
+            gamma=config.SCHEDULER_GAMMA
+        )
+
+    # Training history
+    history = {
+        'train_loss': [],
+        'train_acc': [],
+        'val_loss': [],
+        'val_acc': []
+    }
+
+    best_acc = 0.0
+    start_epoch = 0
+
+    # Training loop
+    print(f"\nStarting training for {config.EPOCHS} epochs...")
+    for epoch in range(start_epoch, config.EPOCHS):
+        print(f"\nEpoch {epoch + 1}/{config.EPOCHS}")
+        print("-" * 50)
+
+        # Train
+        train_loss, train_acc = train_epoch(
+            model, train_loader, criterion, optimizer, config.DEVICE
+        )
+
+        # Validate
+        val_loss, val_acc = validate(
+            model, test_loader, criterion, config.DEVICE
+        )
+
+        # Update learning rate
+        if scheduler:
+            scheduler.step()
+            current_lr = scheduler.get_last_lr()[0]
+            print(f"Learning rate: {current_lr:.6f}")
+
+        # Save history
+        history['train_loss'].append(train_loss)
+        history['train_acc'].append(train_acc)
+        history['val_loss'].append(val_loss)
+        history['val_acc'].append(val_acc)
+
+        # Print epoch summary
+        print(f"\nEpoch {epoch + 1} Summary:")
+        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
+        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
+
+        # Save best model
+        if val_acc > best_acc:
+            best_acc = val_acc
+            save_checkpoint(
+                model, optimizer, epoch, val_acc,
+                config.BEST_MODEL_PATH
+            )
+            print(f"✓ Best model saved with accuracy: {best_acc:.2f}%")
+
+        # Save last model
+        save_checkpoint(
+            model, optimizer, epoch, val_acc,
+            config.LAST_MODEL_PATH
+        )
+
+    # Plot training history
+    plot_training_history(history, config.PLOTS_DIR)
+
+    print("\n" + "=" * 50)
+    print(f"Training completed!")
+    print(f"Best validation accuracy: {best_acc:.2f}%")
+    print("=" * 50)
+
+
+if __name__ == '__main__':
+    train()
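The accuracy bookkeeping in `train_epoch`/`validate` above reduces to a running correct/total ratio accumulated batch by batch. A dependency-free sketch of that arithmetic (the label lists are made-up data, not CIFAR-10 output):

```python
# Minimal sketch of the running-accuracy bookkeeping from train.py,
# without torch: compare predicted vs. true labels batch by batch.
def running_accuracy(batches):
    """batches: iterable of (predicted_labels, true_labels) pairs of lists."""
    correct = 0
    total = 0
    for preds, labels in batches:
        total += len(labels)                              # labels.size(0)
        correct += sum(p == t for p, t in zip(preds, labels))  # predicted.eq(labels).sum()
    return 100.0 * correct / total

# Made-up example batches: 3 of 5 predictions correct overall.
acc = running_accuracy([([1, 2, 3], [1, 2, 0]), ([0, 0], [0, 1])])
print(f"{acc:.2f}%")  # 60.00%
```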
train_log.txt ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,186 @@
+"""
+Utility functions for the CIFAR-10 CNN project
+"""
+import os
+import torch
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import confusion_matrix, classification_report
+import seaborn as sns
+
+import config
+
+
+def save_checkpoint(model, optimizer, epoch, accuracy, filepath):
+    """
+    Save model checkpoint
+
+    Args:
+        model: PyTorch model
+        optimizer: Optimizer
+        epoch: Current epoch
+        accuracy: Current accuracy
+        filepath: Path to save checkpoint
+    """
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'accuracy': accuracy
+    }
+    torch.save(checkpoint, filepath)
+
+
+def load_checkpoint(model, optimizer, filepath):
+    """
+    Load model checkpoint
+
+    Args:
+        model: PyTorch model
+        optimizer: Optimizer
+        filepath: Path to checkpoint file
+
+    Returns:
+        tuple: (epoch, accuracy)
+    """
+    checkpoint = torch.load(filepath, map_location=config.DEVICE)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    if optimizer:
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    accuracy = checkpoint['accuracy']
+    return epoch, accuracy
+
+
+def plot_training_history(history, save_dir):
+    """
+    Plot training history
+
+    Args:
+        history: Dictionary containing training history
+        save_dir: Directory to save plots
+    """
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+
+    # Plot loss
+    ax1.plot(history['train_loss'], label='Train Loss', linewidth=2)
+    ax1.plot(history['val_loss'], label='Validation Loss', linewidth=2)
+    ax1.set_xlabel('Epoch', fontsize=12)
+    ax1.set_ylabel('Loss', fontsize=12)
+    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
+    ax1.legend(fontsize=10)
+    ax1.grid(True, alpha=0.3)
+
+    # Plot accuracy
+    ax2.plot(history['train_acc'], label='Train Accuracy', linewidth=2)
+    ax2.plot(history['val_acc'], label='Validation Accuracy', linewidth=2)
+    ax2.set_xlabel('Epoch', fontsize=12)
+    ax2.set_ylabel('Accuracy (%)', fontsize=12)
+    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
+    ax2.legend(fontsize=10)
+    ax2.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+    plt.savefig(os.path.join(save_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+def plot_confusion_matrix(y_true, y_pred, save_path):
+    """
+    Plot confusion matrix
+
+    Args:
+        y_true: True labels
+        y_pred: Predicted labels
+        save_path: Path to save the plot
+    """
+    cm = confusion_matrix(y_true, y_pred)
+
+    plt.figure(figsize=(12, 10))
+    sns.heatmap(
+        cm, annot=True, fmt='d', cmap='Blues',
+        xticklabels=config.CLASS_NAMES,
+        yticklabels=config.CLASS_NAMES,
+        cbar_kws={'label': 'Count'}
+    )
+    plt.xlabel('Predicted Label', fontsize=12)
+    plt.ylabel('True Label', fontsize=12)
+    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+def print_classification_report(y_true, y_pred):
+    """
+    Print classification report
+
+    Args:
+        y_true: True labels
+        y_pred: Predicted labels
+    """
+    report = classification_report(
+        y_true, y_pred,
+        target_names=config.CLASS_NAMES,
+        digits=4
+    )
+    print("\nClassification Report:")
+    print("=" * 80)
+    print(report)
+    print("=" * 80)
+
+
+def visualize_predictions(model, test_loader, device, num_images=16):
+    """
+    Visualize model predictions
+
+    Args:
+        model: PyTorch model
+        test_loader: Test data loader
+        device: Device to run on
+        num_images: Number of images to visualize
+    """
+    model.eval()
+
+    # Get a batch of images
+    images, labels = next(iter(test_loader))
+    images, labels = images[:num_images], labels[:num_images]
+    images_device = images.to(device)
+
+    # Get predictions
+    with torch.no_grad():
+        outputs = model(images_device)
+        _, predicted = outputs.max(1)
+
+    # Plot
+    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
+    axes = axes.ravel()
+
+    for idx in range(num_images):
+        # Denormalize image
+        img = images[idx].cpu().numpy().transpose(1, 2, 0)
+        mean = np.array([0.4914, 0.4822, 0.4465])
+        std = np.array([0.2470, 0.2435, 0.2616])
+        img = img * std + mean
+        img = np.clip(img, 0, 1)
+
+        # Plot
+        axes[idx].imshow(img)
+        axes[idx].axis('off')
+
+        true_label = config.CLASS_NAMES[labels[idx]]
+        pred_label = config.CLASS_NAMES[predicted[idx].cpu()]
+
+        color = 'green' if labels[idx] == predicted[idx].cpu() else 'red'
+        axes[idx].set_title(
+            f'True: {true_label}\nPred: {pred_label}',
+            color=color, fontsize=10
+        )
+
+    plt.tight_layout()
+    plt.savefig(os.path.join(config.PLOTS_DIR, 'predictions.png'), dpi=300, bbox_inches='tight')
+    plt.close()
+
+    print(f"Predictions visualization saved to {config.PLOTS_DIR}/predictions.png")
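`visualize_predictions` above undoes the dataset normalization with `img * std + mean` followed by `np.clip(img, 0, 1)`. The same per-channel arithmetic on plain floats, using the channel statistics from utils.py (the sample pixel values are made up):

```python
# Sketch of the per-channel denormalization step from visualize_predictions,
# written in plain Python instead of numpy.
MEAN = [0.4914, 0.4822, 0.4465]  # CIFAR-10 channel means (as in utils.py)
STD = [0.2470, 0.2435, 0.2616]   # CIFAR-10 channel stds

def denormalize_pixel(pixel):
    """pixel: [r, g, b] in normalized space -> values clipped to [0, 1]."""
    out = []
    for value, mean, std in zip(pixel, MEAN, STD):
        v = value * std + mean
        out.append(min(max(v, 0.0), 1.0))  # equivalent of np.clip(v, 0, 1)
    return out

print(denormalize_pixel([0.0, 0.0, 0.0]))  # a zero pixel recovers the channel means
```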