import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader, random_split from torch.utils.tensorboard import SummaryWriter import matplotlib.pyplot as plt import seaborn as sns import numpy as np import argparse import os import logging from tqdm import tqdm from datetime import datetime import json import random from sklearn.metrics import confusion_matrix, classification_report from pathlib import Path # Setup logging def setup_logging(log_dir): log_dir = Path(log_dir) log_dir.mkdir(parents=True, exist_ok=True) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_dir / 'training.log'), logging.StreamHandler() ] ) return logging.getLogger(__name__) # Set random seeds for reproducibility def set_seed(seed=42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # CNN Model Architecture class ConvNet(nn.Module): """Convolutional Neural Network for MNIST""" def __init__(self, dropout_rate=0.3, num_classes=10): super(ConvNet, self).__init__() # Convolutional layers self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.bn4 = nn.BatchNorm2d(128) self.pool = nn.MaxPool2d(2, 2) self.dropout_conv = nn.Dropout2d(dropout_rate * 0.5) # Fully connected layers self.fc1 = nn.Linear(128 * 7 * 7, 256) self.bn5 = nn.BatchNorm1d(256) self.dropout1 = nn.Dropout(dropout_rate) self.fc2 = nn.Linear(256, 128) self.bn6 = nn.BatchNorm1d(128) self.dropout2 = nn.Dropout(dropout_rate * 0.5) self.fc3 = nn.Linear(128, num_classes) self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x): # Block 1 x = self.conv1(x) x = self.bn1(x) x = torch.relu(x) x = self.conv2(x) x = self.bn2(x) x = torch.relu(x) x = self.pool(x) x = self.dropout_conv(x) # Block 2 x = self.conv3(x) x = self.bn3(x) x = torch.relu(x) x = self.conv4(x) x = self.bn4(x) x = torch.relu(x) x = self.pool(x) x = self.dropout_conv(x) # Flatten x = x.view(x.size(0), -1) # FC layers x = self.fc1(x) x = self.bn5(x) x = torch.relu(x) x = self.dropout1(x) x = self.fc2(x) x = self.bn6(x) x = torch.relu(x) x = self.dropout2(x) x = self.fc3(x) return x # Improved Fully Connected Network class ImprovedNN(nn.Module): """Enhanced fully connected network with configurable architecture""" def __init__(self, input_size=784, hidden_sizes=[512, 256, 128], num_classes=10, dropout_rate=0.3): super(ImprovedNN, self).__init__() layers = [] prev_size = input_size for i, hidden_size in enumerate(hidden_sizes): layers.extend([ nn.Linear(prev_size, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(), nn.Dropout(dropout_rate if i < len(hidden_sizes) - 1 else dropout_rate * 0.5) ]) prev_size = hidden_size layers.append(nn.Linear(prev_size, num_classes)) self.network = nn.Sequential(*layers) self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x): x = x.view(x.size(0), -1) return self.network(x) # Trainer class class Trainer: def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, device, args, logger): self.model = model self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader self.criterion = criterion self.optimizer = optimizer self.scheduler = scheduler self.device = device self.args = args self.logger = logger # Setup TensorBoard self.writer = SummaryWriter(log_dir=args.log_dir) # Training history self.train_losses = [] self.val_losses = [] self.train_accs = [] self.val_accs = [] self.best_val_acc = 0.0 self.patience_counter = 0 # Mixed precision training self.scaler = torch.cuda.amp.GradScaler() if args.use_amp and device.type == 'cuda' else None def train_epoch(self, epoch): self.model.train() running_loss = 0.0 correct = 0 total = 0 progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]") for batch_idx, (images, labels) in enumerate(progress_bar): images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True) self.optimizer.zero_grad(set_to_none=True) # Mixed precision training if self.scaler: with torch.cuda.amp.autocast(): outputs = self.model(images) loss = self.criterion(outputs, labels) self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.scaler.step(self.optimizer) self.scaler.update() else: outputs = self.model(images) loss = self.criterion(outputs, labels) loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.step() running_loss += loss.item() _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() # Log to TensorBoard global_step = epoch * len(self.train_loader) + batch_idx if batch_idx % 50 == 0: self.writer.add_scalar('Train/BatchLoss', loss.item(), global_step) self.writer.add_scalar('Train/BatchAcc', 100. * correct / total, global_step) progress_bar.set_postfix({ 'Loss': f"{loss.item():.4f}", 'Acc': f"{100.*correct/total:.2f}%" }) epoch_loss = running_loss / len(self.train_loader) epoch_acc = 100. * correct / total return epoch_loss, epoch_acc def validate(self, loader, phase="Val"): self.model.eval() running_loss = 0.0 correct = 0 total = 0 all_preds = [] all_labels = [] with torch.no_grad(): progress_bar = tqdm(loader, desc=f"[{phase}]") for images, labels in progress_bar: images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True) if self.scaler: with torch.cuda.amp.autocast(): outputs = self.model(images) loss = self.criterion(outputs, labels) else: outputs = self.model(images) loss = self.criterion(outputs, labels) running_loss += loss.item() _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() all_preds.extend(predicted.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) progress_bar.set_postfix({ 'Loss': f"{loss.item():.4f}", 'Acc': f"{100.*correct/total:.2f}%" }) epoch_loss = running_loss / len(loader) epoch_acc = 100. * correct / total return epoch_loss, epoch_acc, np.array(all_preds), np.array(all_labels) def train(self): self.logger.info(f"Starting training for {self.args.epochs} epochs") self.logger.info(f"Model: {self.args.model_type}, Optimizer: {self.args.optimizer}") self.logger.info(f"Learning rate: {self.args.lr}, Batch size: {self.args.batch_size}") start_time = datetime.now() for epoch in range(self.args.epochs): # Learning rate warmup if epoch < self.args.warmup_epochs: warmup_lr = self.args.lr * (epoch + 1) / self.args.warmup_epochs for param_group in self.optimizer.param_groups: param_group['lr'] = warmup_lr train_loss, train_acc = self.train_epoch(epoch) val_loss, val_acc, val_preds, val_labels = self.validate(self.val_loader, "Val") self.train_losses.append(train_loss) self.val_losses.append(val_loss) self.train_accs.append(train_acc) self.val_accs.append(val_acc) # Step scheduler after warmup if epoch >= self.args.warmup_epochs: self.scheduler.step() current_lr = self.optimizer.param_groups[0]['lr'] # Log to TensorBoard self.writer.add_scalar('Epoch/TrainLoss', train_loss, epoch) self.writer.add_scalar('Epoch/ValLoss', val_loss, epoch) self.writer.add_scalar('Epoch/TrainAcc', train_acc, epoch) self.writer.add_scalar('Epoch/ValAcc', val_acc, epoch) self.writer.add_scalar('Epoch/LearningRate', current_lr, epoch) # Per-class accuracy per_class_acc = self._compute_per_class_accuracy(val_preds, val_labels) for class_idx, acc in enumerate(per_class_acc): self.writer.add_scalar(f'PerClass/Val_Class_{class_idx}', acc, epoch) self.logger.info(f"Epoch {epoch+1}/{self.args.epochs} | LR: {current_lr:.6f}") self.logger.info(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%") self.logger.info(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%") self.logger.info(f"Per-class Val Acc: {[f'{acc:.1f}%' for acc in per_class_acc]}") # Save best model if val_acc > self.best_val_acc: self.best_val_acc = val_acc self.patience_counter = 0 self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=True) self.logger.info(f"✓ New best model saved! Val Acc: {val_acc:.2f}%") else: self.patience_counter += 1 self.logger.info(f"No improvement. Patience: {self.patience_counter}/{self.args.early_stop_patience}") # Save regular checkpoint if (epoch + 1) % self.args.save_freq == 0: self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=False) # Early stopping if self.patience_counter >= self.args.early_stop_patience: self.logger.info(f"Early stopping triggered after {epoch+1} epochs") break print("-" * 70) training_time = datetime.now() - start_time self.logger.info(f"Training complete! Time: {training_time}") self.logger.info(f"Best Val Acc: {self.best_val_acc:.2f}%") # Save training history self.save_training_history() return self.best_val_acc def _compute_per_class_accuracy(self, preds, labels): per_class_acc = [] for class_idx in range(10): mask = labels == class_idx if mask.sum() > 0: class_acc = 100. * (preds[mask] == labels[mask]).sum() / mask.sum() per_class_acc.append(class_acc) else: per_class_acc.append(0.0) return per_class_acc def save_checkpoint(self, epoch, val_acc, val_loss, train_acc, train_loss, is_best=False): checkpoint = { 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'val_acc': val_acc, 'val_loss': val_loss, 'train_acc': train_acc, 'train_loss': train_loss, 'best_val_acc': self.best_val_acc, 'args': vars(self.args) } if is_best: path = Path(self.args.save_dir) / 'best_model.pth' else: path = Path(self.args.save_dir) / f'checkpoint_epoch_{epoch+1}.pth' torch.save(checkpoint, path) def save_training_history(self): history = { 'train_losses': self.train_losses, 'val_losses': self.val_losses, 'train_accs': self.train_accs, 'val_accs': self.val_accs, 'best_val_acc': self.best_val_acc } path = Path(self.args.save_dir) / 'training_history.json' with open(path, 'w') as f: json.dump(history, f, indent=4) self.logger.info(f"Training history saved to {path}") # Visualization functions def plot_training_curves(history_path, save_path): with open(history_path, 'r') as f: history = json.load(f) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) epochs_range = range(1, len(history['train_losses']) + 1) ax1.plot(epochs_range, history['train_losses'], 'b-', label='Train Loss', linewidth=2) ax1.plot(epochs_range, history['val_losses'], 'r-', label='Val Loss', linewidth=2) ax1.set_xlabel('Epoch', fontsize=12) ax1.set_ylabel('Loss', fontsize=12) ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold') ax1.legend() ax1.grid(True, alpha=0.3) ax2.plot(epochs_range, history['train_accs'], 'b-', label='Train Acc', linewidth=2) ax2.plot(epochs_range, history['val_accs'], 'r-', label='Val Acc', linewidth=2) ax2.set_xlabel('Epoch', fontsize=12) ax2.set_ylabel('Accuracy (%)', fontsize=12) ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold') ax2.legend() ax2.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=150) plt.close() def plot_confusion_matrix(y_true, y_pred, save_path): cm = confusion_matrix(y_true, y_pred) plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=range(10), yticklabels=range(10)) plt.xlabel('Predicted Label', fontsize=12) plt.ylabel('True Label', fontsize=12) plt.title('Confusion Matrix', fontsize=14, fontweight='bold') plt.tight_layout() plt.savefig(save_path, dpi=150) plt.close() def plot_predictions(model, test_loader, device, save_path, num_samples=20): model.eval() dataiter = iter(test_loader) images, labels = next(dataiter) images, labels = images.to(device), labels.to(device) rows = 4 cols = num_samples // rows fig, axes = plt.subplots(rows, cols, figsize=(15, 8)) axes = axes.ravel() with torch.no_grad(): outputs = model(images[:num_samples]) _, predicted = torch.max(outputs, 1) probs = torch.softmax(outputs, dim=1) for i in range(num_samples): img = images[i].cpu().squeeze().numpy() # Denormalize img = img * 0.3081 + 0.1307 img = np.clip(img, 0, 1) axes[i].imshow(img, cmap='gray') color = 'green' if predicted[i] == labels[i] else 'red' confidence = probs[i][predicted[i]].item() * 100 axes[i].set_title(f"Pred: {predicted[i].item()} ({confidence:.1f}%)\nTrue: {labels[i].item()}", color=color, fontweight='bold', fontsize=9) axes[i].axis('off') plt.tight_layout() plt.savefig(save_path, dpi=150) plt.close() def evaluate_model(model, test_loader, device, logger, save_dir): model.eval() all_preds = [] all_labels = [] with torch.no_grad(): for images, labels in tqdm(test_loader, desc="Evaluating"): images = images.to(device) outputs = model(images) _, predicted = torch.max(outputs, 1) all_preds.extend(predicted.cpu().numpy()) all_labels.extend(labels.numpy()) all_preds = np.array(all_preds) all_labels = np.array(all_labels) # Overall accuracy accuracy = 100. * (all_preds == all_labels).sum() / len(all_labels) logger.info(f"Test Accuracy: {accuracy:.2f}%") # Classification report report = classification_report(all_labels, all_preds, target_names=[str(i) for i in range(10)]) logger.info(f"\nClassification Report:\n{report}") # Save report report_path = Path(save_dir) / 'classification_report.txt' with open(report_path, 'w') as f: f.write(report) # Plot confusion matrix cm_path = Path(save_dir) / 'confusion_matrix.png' plot_confusion_matrix(all_labels, all_preds, cm_path) logger.info(f"Confusion matrix saved to {cm_path}") return accuracy, all_preds, all_labels def parse_args(): parser = argparse.ArgumentParser(description='Enhanced MNIST Classifier with Advanced Features') # Model settings parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'], help='Model architecture type') parser.add_argument('--dropout-rate', type=float, default=0.3, help='Dropout rate') # Training settings parser.add_argument('--epochs', type=int, default=20, help='Number of epochs') parser.add_argument('--batch-size', type=int, default=128, help='Batch size') parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate') parser.add_argument('--optimizer', type=str, default='adamw', choices=['adam', 'sgd', 'adamw'], help='Optimizer choice') parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay') parser.add_argument('--scheduler', type=str, default='onecycle', choices=['cosine', 'onecycle', 'step'], help='Learning rate scheduler') parser.add_argument('--warmup-epochs', type=int, default=2, help='Number of warmup epochs') # Data settings parser.add_argument('--data-dir', type=str, default='./data', help='Data directory') parser.add_argument('--val-split', type=float, default=0.1, help='Validation split ratio') parser.add_argument('--num-workers', type=int, default=4, help='Number of data loading workers') # Regularization parser.add_argument('--early-stop-patience', type=int, default=7, help='Early stopping patience') parser.add_argument('--use-amp', action='store_true', help='Use automatic mixed precision') # Saving and logging parser.add_argument('--save-dir', type=str, default='./checkpoints', help='Save directory') parser.add_argument('--log-dir', type=str, default='./runs', help='TensorBoard log directory') parser.add_argument('--save-freq', type=int, default=5, help='Save checkpoint every N epochs') parser.add_argument('--seed', type=int, default=42, help='Random seed') # Hardware parser.add_argument('--use-gpu', action='store_true', help='Use GPU if available') return parser.parse_args() def main(): args = parse_args() # Set random seed set_seed(args.seed) # Create directories Path(args.save_dir).mkdir(parents=True, exist_ok=True) Path(args.log_dir).mkdir(parents=True, exist_ok=True) # Setup logging logger = setup_logging(args.save_dir) logger.info(f"Arguments: {vars(args)}") # Device handling device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') logger.info(f"Using device: {device}") if device.type == 'cuda': logger.info(f"GPU: {torch.cuda.get_device_name(0)}") # Enhanced data preparation with augmentation os.makedirs(args.data_dir, exist_ok=True) train_transform = transforms.Compose([ transforms.RandomRotation(10), transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # Load datasets full_train_dataset = datasets.MNIST(root=args.data_dir, train=True, download=True, transform=train_transform) test_dataset = datasets.MNIST(root=args.data_dir, train=False, download=True, transform=test_transform) # Split train into train and validation val_size = int(len(full_train_dataset) * args.val_split) train_size = len(full_train_dataset) - val_size train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size]) logger.info(f"Train size: {train_size}, Val size: {val_size}, Test size: {len(test_dataset)}") # Create data loaders train_loader = DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True if device.type == 'cuda' else False, persistent_workers=True if args.num_workers > 0 else False ) val_loader = DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True if device.type == 'cuda' else False, persistent_workers=True if args.num_workers > 0 else False ) test_loader = DataLoader( test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True if device.type == 'cuda' else False, persistent_workers=True if args.num_workers > 0 else False ) # Create model if args.model_type == 'cnn': model = ConvNet(dropout_rate=args.dropout_rate).to(device) else: model = ImprovedNN(dropout_rate=args.dropout_rate).to(device) logger.info(f"Model: {args.model_type}") logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") # Loss and Optimizer criterion = nn.CrossEntropyLoss() if args.optimizer == 'adam': optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) elif args.optimizer == 'adamw': optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) else: optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay, nesterov=True) # Learning rate scheduler if args.scheduler == 'cosine': scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs) elif args.scheduler == 'onecycle': scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=args.lr * 10, epochs=args.epochs - args.warmup_epochs, steps_per_epoch=len(train_loader) ) else: scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # Create trainer trainer = Trainer(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, device, args, logger) # Train model best_val_acc = trainer.train() # Load best model best_model_path = Path(args.save_dir) / 'best_model.pth' checkpoint = torch.load(best_model_path, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) logger.info(f"Loaded best model from epoch {checkpoint['epoch']+1}") # Final evaluation on test set logger.info("\n" + "="*70) logger.info("Final Evaluation on Test Set") logger.info("="*70) test_acc, test_preds, test_labels = evaluate_model(model, test_loader, device, logger, args.save_dir) # Plot training curves history_path = Path(args.save_dir) / 'training_history.json' curves_path = Path(args.save_dir) / 'training_curves.png' plot_training_curves(history_path, curves_path) logger.info(f"Training curves saved to {curves_path}") # Plot predictions pred_path = Path(args.save_dir) / 'predictions.png' plot_predictions(model, test_loader, device, pred_path) logger.info(f"Predictions saved to {pred_path}") # Print usage instructions logger.info("\n" + "="*70) logger.info("Model Loading Instructions:") logger.info(f"from improved_mnist_classifier import {model.__class__.__name__}") logger.info(f"model = {model.__class__.__name__}().to(device)") logger.info(f"checkpoint = torch.load('{best_model_path}')") logger.info(f"model.load_state_dict(checkpoint['model_state_dict'])") logger.info(f"model.eval()") logger.info("="*70) logger.info(f"\nTraining complete! Best Val Acc: {best_val_acc:.2f}%, Test Acc: {test_acc:.2f}%") if __name__ == '__main__': main()