|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import numpy as np |
|
|
import argparse |
|
|
import os |
|
|
import logging |
|
|
from tqdm import tqdm |
|
|
from datetime import datetime |
|
|
import json |
|
|
import random |
|
|
from sklearn.metrics import confusion_matrix, classification_report |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
def setup_logging(log_dir):
    """Configure root logging to both a file and the console.

    Creates *log_dir* (and parents) if needed, writes to ``training.log``
    inside it, and returns a logger for this module.
    """
    target = Path(log_dir)
    target.mkdir(parents=True, exist_ok=True)

    handlers = [
        logging.FileHandler(target / 'training.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
    return logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def set_seed(seed=42):
    """Seed every RNG source (Python, NumPy, Torch CPU and CUDA) and put
    cuDNN into deterministic, non-benchmarking mode for reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
class ConvNet(nn.Module):
    """Convolutional Neural Network for MNIST.

    Two conv stages (each: two conv->BN->ReLU pairs, a 2x2 max-pool and 2D
    dropout) followed by a three-layer fully connected head with batch norm
    and dropout. Takes (batch, 1, 28, 28) inputs and returns raw logits of
    shape (batch, num_classes).
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super(ConvNet, self).__init__()

        # Stage 1: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Shared pooling; conv-feature dropout runs at half the head rate.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(dropout_rate * 0.5)

        # Head: two pools on a 28x28 input leave 7x7 spatial maps.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout_rate * 0.5)

        self.fc3 = nn.Linear(128, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv/linear layers; unit weight and zero
        bias for every batch-norm layer."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # First conv stage.
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.dropout_conv(self.pool(out))

        # Second conv stage.
        out = torch.relu(self.bn3(self.conv3(out)))
        out = torch.relu(self.bn4(self.conv4(out)))
        out = self.dropout_conv(self.pool(out))

        # Flatten to (batch, 128*7*7) and classify.
        out = out.view(out.size(0), -1)
        out = self.dropout1(torch.relu(self.bn5(self.fc1(out))))
        out = self.dropout2(torch.relu(self.bn6(self.fc2(out))))
        return self.fc3(out)
|
|
|
|
|
|
|
|
class ImprovedNN(nn.Module):
    """Enhanced fully connected network with configurable architecture.

    Flattens the input and feeds it through a stack of
    Linear -> BatchNorm1d -> ReLU -> Dropout blocks, then a final linear
    classifier. Returns raw logits of shape (batch, num_classes).

    Args:
        input_size: Flattened input dimensionality (784 for 28x28 MNIST).
        hidden_sizes: Width of each hidden layer, in order.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability; the last hidden block uses half
            of it so the classifier sees less-noisy features.
    """

    # Fix: the default was a mutable list ([512, 256, 128]); a tuple avoids
    # the shared-mutable-default pitfall while keeping identical behavior.
    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.3):
        super(ImprovedNN, self).__init__()

        layers = []
        prev_size = input_size

        for i, hidden_size in enumerate(hidden_sizes):
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                # Last hidden block gets the lighter dropout rate.
                nn.Dropout(dropout_rate if i < len(hidden_sizes) - 1 else dropout_rate * 0.5)
            ])
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for linear layers; unit weight and zero bias
        for batch-norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Flatten (batch, ...) -> (batch, input_size) before the MLP stack.
        x = x.view(x.size(0), -1)
        return self.network(x)
|
|
|
|
|
|
|
|
class Trainer:
    """Drives the full training run for a model.

    Responsibilities: per-epoch training with optional AMP, validation,
    linear LR warmup, TensorBoard scalar logging, best/periodic
    checkpointing, early stopping on stalled validation accuracy, and
    JSON export of the per-epoch metric history.
    """

    def __init__(self, model, train_loader, val_loader, test_loader,
                 criterion, optimizer, scheduler, device, args, logger):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Stored for completeness; train() itself only uses train/val loaders.
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.args = args
        self.logger = logger

        # TensorBoard event writer; events land under args.log_dir.
        self.writer = SummaryWriter(log_dir=args.log_dir)

        # Per-epoch metric history, serialized by save_training_history().
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.best_val_acc = 0.0
        # Consecutive epochs without a validation-accuracy improvement.
        self.patience_counter = 0

        # AMP gradient scaler, only when requested AND running on CUDA.
        # NOTE(review): torch.cuda.amp.GradScaler/autocast are deprecated in
        # recent torch in favor of torch.amp — confirm the targeted version.
        self.scaler = torch.cuda.amp.GradScaler() if args.use_amp and device.type == 'cuda' else None

    def train_epoch(self, epoch):
        """Run one training epoch; return (mean batch loss, accuracy %)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            self.optimizer.zero_grad(set_to_none=True)

            if self.scaler:
                # Mixed-precision path: forward/loss under autocast, scaled
                # backward, unscale so clipping sees true gradient norms,
                # then step and update the scale factor.
                with torch.cuda.amp.autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Full-precision path with the same gradient clipping.
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Batch-level scalars every 50 batches to limit event volume.
            global_step = epoch * len(self.train_loader) + batch_idx
            if batch_idx % 50 == 0:
                self.writer.add_scalar('Train/BatchLoss', loss.item(), global_step)
                self.writer.add_scalar('Train/BatchAcc', 100. * correct / total, global_step)

            progress_bar.set_postfix({
                'Loss': f"{loss.item():.4f}",
                'Acc': f"{100.*correct/total:.2f}%"
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate(self, loader, phase="Val"):
        """Evaluate on *loader* without gradients.

        Returns (mean batch loss, accuracy %, predictions, labels), the last
        two as flat numpy arrays over the whole loader.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        all_preds = []
        all_labels = []

        with torch.no_grad():
            progress_bar = tqdm(loader, desc=f"[{phase}]")
            for images, labels in progress_bar:
                images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

                # Mirror training precision so losses stay comparable.
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'Loss': f"{loss.item():.4f}",
                    'Acc': f"{100.*correct/total:.2f}%"
                })

        epoch_loss = running_loss / len(loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc, np.array(all_preds), np.array(all_labels)

    def train(self):
        """Run the full training loop; return the best validation accuracy."""
        self.logger.info(f"Starting training for {self.args.epochs} epochs")
        self.logger.info(f"Model: {self.args.model_type}, Optimizer: {self.args.optimizer}")
        self.logger.info(f"Learning rate: {self.args.lr}, Batch size: {self.args.batch_size}")

        start_time = datetime.now()

        for epoch in range(self.args.epochs):
            # Linear LR warmup: directly override the optimizer LR for the
            # first warmup_epochs; the scheduler takes over afterwards.
            if epoch < self.args.warmup_epochs:
                warmup_lr = self.args.lr * (epoch + 1) / self.args.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc, val_preds, val_labels = self.validate(self.val_loader, "Val")

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accs.append(train_acc)
            self.val_accs.append(val_acc)

            # Step the scheduler once per epoch after warmup.
            # NOTE(review): main() may build OneCycleLR with
            # steps_per_epoch=len(train_loader), which expects one step per
            # BATCH; stepping per epoch means that schedule never completes.
            # Confirm intent before relying on --scheduler onecycle.
            if epoch >= self.args.warmup_epochs:
                self.scheduler.step()

            current_lr = self.optimizer.param_groups[0]['lr']

            # Epoch-level TensorBoard scalars.
            self.writer.add_scalar('Epoch/TrainLoss', train_loss, epoch)
            self.writer.add_scalar('Epoch/ValLoss', val_loss, epoch)
            self.writer.add_scalar('Epoch/TrainAcc', train_acc, epoch)
            self.writer.add_scalar('Epoch/ValAcc', val_acc, epoch)
            self.writer.add_scalar('Epoch/LearningRate', current_lr, epoch)

            # Per-class validation accuracy (digits 0-9).
            per_class_acc = self._compute_per_class_accuracy(val_preds, val_labels)
            for class_idx, acc in enumerate(per_class_acc):
                self.writer.add_scalar(f'PerClass/Val_Class_{class_idx}', acc, epoch)

            self.logger.info(f"Epoch {epoch+1}/{self.args.epochs} | LR: {current_lr:.6f}")
            self.logger.info(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
            self.logger.info(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
            self.logger.info(f"Per-class Val Acc: {[f'{acc:.1f}%' for acc in per_class_acc]}")

            # Best-model tracking: save on improvement and reset patience,
            # otherwise count towards early stopping.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=True)
                self.logger.info(f"✓ New best model saved! Val Acc: {val_acc:.2f}%")
            else:
                self.patience_counter += 1
                self.logger.info(f"No improvement. Patience: {self.patience_counter}/{self.args.early_stop_patience}")

            # Periodic checkpoint regardless of improvement.
            if (epoch + 1) % self.args.save_freq == 0:
                self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=False)

            if self.patience_counter >= self.args.early_stop_patience:
                self.logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

            print("-" * 70)

        training_time = datetime.now() - start_time
        self.logger.info(f"Training complete! Time: {training_time}")
        self.logger.info(f"Best Val Acc: {self.best_val_acc:.2f}%")

        self.save_training_history()

        return self.best_val_acc

    def _compute_per_class_accuracy(self, preds, labels):
        """Return accuracy (%) for each of the 10 digit classes.

        Classes absent from *labels* report 0.0 rather than dividing by zero.
        """
        per_class_acc = []
        for class_idx in range(10):
            mask = labels == class_idx
            if mask.sum() > 0:
                class_acc = 100. * (preds[mask] == labels[mask]).sum() / mask.sum()
                per_class_acc.append(class_acc)
            else:
                per_class_acc.append(0.0)
        return per_class_acc

    def save_checkpoint(self, epoch, val_acc, val_loss, train_acc, train_loss, is_best=False):
        """Serialize model/optimizer/scheduler state plus metrics and args.

        A best checkpoint overwrites 'best_model.pth'; periodic checkpoints
        are named by epoch number.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'train_loss': train_loss,
            'best_val_acc': self.best_val_acc,
            'args': vars(self.args)
        }

        if is_best:
            path = Path(self.args.save_dir) / 'best_model.pth'
        else:
            path = Path(self.args.save_dir) / f'checkpoint_epoch_{epoch+1}.pth'

        torch.save(checkpoint, path)

    def save_training_history(self):
        """Dump per-epoch metric lists and the best val accuracy to JSON."""
        history = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs,
            'best_val_acc': self.best_val_acc
        }

        path = Path(self.args.save_dir) / 'training_history.json'
        with open(path, 'w') as f:
            json.dump(history, f, indent=4)

        self.logger.info(f"Training history saved to {path}")
|
|
|
|
|
|
|
|
def plot_training_curves(history_path, save_path):
    """Render side-by-side loss and accuracy curves from a saved
    training-history JSON file and write the figure to *save_path*."""
    with open(history_path, 'r') as f:
        history = json.load(f)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    epochs_range = range(1, len(history['train_losses']) + 1)

    # One panel per (train key, val key) pair, configured data-driven.
    panels = [
        (loss_ax, 'train_losses', 'val_losses', 'Train Loss', 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (acc_ax, 'train_accs', 'val_accs', 'Train Acc', 'Val Acc',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    ]
    for ax, train_key, val_key, train_label, val_label, y_label, title in panels:
        ax.plot(epochs_range, history[train_key], 'b-', label=train_label, linewidth=2)
        ax.plot(epochs_range, history[val_key], 'r-', label=val_label, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
|
|
|
def plot_confusion_matrix(y_true, y_pred, save_path):
    """Save a labeled 10x10 confusion-matrix heatmap to *save_path*."""
    matrix = confusion_matrix(y_true, y_pred)

    digit_labels = range(10)
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=digit_labels, yticklabels=digit_labels)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
|
|
|
def plot_predictions(model, test_loader, device, save_path, num_samples=20):
    """Plot a 4-row grid of test images annotated with the model's
    prediction, confidence, and the true label (green = correct, red =
    wrong), and save the figure to *save_path*."""
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)

    rows = 4
    cols = num_samples // rows
    fig, axes = plt.subplots(rows, cols, figsize=(15, 8))
    axes = axes.ravel()

    with torch.no_grad():
        outputs = model(images[:num_samples])
        _, predicted = torch.max(outputs, 1)
        probs = torch.softmax(outputs, dim=1)

    for idx in range(num_samples):
        # Undo MNIST normalization (mean 0.1307, std 0.3081) for display.
        pixels = images[idx].cpu().squeeze().numpy() * 0.3081 + 0.1307
        pixels = np.clip(pixels, 0, 1)

        ax = axes[idx]
        ax.imshow(pixels, cmap='gray')
        title_color = 'green' if predicted[idx] == labels[idx] else 'red'
        confidence = probs[idx][predicted[idx]].item() * 100
        ax.set_title(f"Pred: {predicted[idx].item()} ({confidence:.1f}%)\nTrue: {labels[idx].item()}",
                     color=title_color, fontweight='bold', fontsize=9)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
|
|
|
def evaluate_model(model, test_loader, device, logger, save_dir):
    """Run the model over the test set and report results.

    Logs overall accuracy and a per-class classification report, writes
    'classification_report.txt' and 'confusion_matrix.png' under *save_dir*,
    and returns (accuracy %, predictions, labels) with the arrays as numpy.
    """
    model.eval()
    predictions = []
    targets = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            outputs = model(images.to(device))
            _, batch_preds = torch.max(outputs, 1)
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(labels.numpy())

    predictions = np.array(predictions)
    targets = np.array(targets)

    accuracy = 100. * (predictions == targets).sum() / len(targets)
    logger.info(f"Test Accuracy: {accuracy:.2f}%")

    report = classification_report(targets, predictions, target_names=[str(i) for i in range(10)])
    logger.info(f"\nClassification Report:\n{report}")

    report_path = Path(save_dir) / 'classification_report.txt'
    with open(report_path, 'w') as f:
        f.write(report)

    cm_path = Path(save_dir) / 'confusion_matrix.png'
    plot_confusion_matrix(targets, predictions, cm_path)
    logger.info(f"Confusion matrix saved to {cm_path}")

    return accuracy, predictions, targets
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the MNIST trainer.

    Args:
        argv: Optional list of argument strings. When None (the default,
            matching the original behavior), argparse reads sys.argv[1:].
            Passing an explicit list makes the function testable without
            touching global state.

    Returns:
        argparse.Namespace with all model/training/data/output settings.
    """
    parser = argparse.ArgumentParser(description='Enhanced MNIST Classifier with Advanced Features')

    # Model architecture.
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type')
    parser.add_argument('--dropout-rate', type=float, default=0.3, help='Dropout rate')

    # Optimization.
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=128, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--optimizer', type=str, default='adamw',
                        choices=['adam', 'sgd', 'adamw'], help='Optimizer choice')
    parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay')
    parser.add_argument('--scheduler', type=str, default='onecycle',
                        choices=['cosine', 'onecycle', 'step'], help='Learning rate scheduler')
    parser.add_argument('--warmup-epochs', type=int, default=2, help='Number of warmup epochs')

    # Data loading.
    parser.add_argument('--data-dir', type=str, default='./data', help='Data directory')
    parser.add_argument('--val-split', type=float, default=0.1, help='Validation split ratio')
    parser.add_argument('--num-workers', type=int, default=4, help='Number of data loading workers')

    # Training control.
    parser.add_argument('--early-stop-patience', type=int, default=7,
                        help='Early stopping patience')
    parser.add_argument('--use-amp', action='store_true', help='Use automatic mixed precision')

    # Output / bookkeeping.
    parser.add_argument('--save-dir', type=str, default='./checkpoints', help='Save directory')
    parser.add_argument('--log-dir', type=str, default='./runs', help='TensorBoard log directory')
    parser.add_argument('--save-freq', type=int, default=5, help='Save checkpoint every N epochs')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    # Hardware.
    parser.add_argument('--use-gpu', action='store_true', help='Use GPU if available')

    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """End-to-end pipeline: seed, data, model, training, evaluation, plots."""
    args = parse_args()

    # Reproducibility across python/numpy/torch RNGs.
    set_seed(args.seed)

    # Output directories for checkpoints and TensorBoard runs.
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    logger = setup_logging(args.save_dir)
    logger.info(f"Arguments: {vars(args)}")

    # Device selection: GPU only when both available and requested.
    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
    logger.info(f"Using device: {device}")
    if device.type == 'cuda':
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

    os.makedirs(args.data_dir, exist_ok=True)

    # Training augmentation: small rotations and affine jitter before
    # tensor conversion, MNIST mean/std normalization, then random erasing.
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.1))
    ])

    # Test/eval path: normalization only, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # MNIST is downloaded into args.data_dir on first run.
    full_train_dataset = datasets.MNIST(root=args.data_dir, train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root=args.data_dir, train=False, download=True, transform=test_transform)

    # Carve a validation split off the training set.
    # NOTE(review): the val subset inherits train_transform, so validation
    # images are augmented — confirm that is intended.
    val_size = int(len(full_train_dataset) * args.val_split)
    train_size = len(full_train_dataset) - val_size
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    logger.info(f"Train size: {train_size}, Val size: {val_size}, Test size: {len(test_dataset)}")

    # Loaders: only the training loader shuffles; pinned memory and
    # persistent workers are enabled when they can actually help.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if args.num_workers > 0 else False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if args.num_workers > 0 else False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if args.num_workers > 0 else False
    )

    # Model selection per --model-type.
    if args.model_type == 'cnn':
        model = ConvNet(dropout_rate=args.dropout_rate).to(device)
    else:
        model = ImprovedNN(dropout_rate=args.dropout_rate).to(device)

    logger.info(f"Model: {args.model_type}")
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    criterion = nn.CrossEntropyLoss()

    # Optimizer per --optimizer (SGD is the fallback branch).
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay, nesterov=True)

    # LR schedule; horizons exclude the warmup epochs handled in Trainer.
    # NOTE(review): OneCycleLR is configured with steps_per_epoch, implying
    # a scheduler step per BATCH, but Trainer.train steps it per EPOCH —
    # the onecycle schedule will not run to completion as configured.
    if args.scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs)
    elif args.scheduler == 'onecycle':
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=args.lr * 10,
            epochs=args.epochs - args.warmup_epochs,
            steps_per_epoch=len(train_loader)
        )
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    trainer = Trainer(model, train_loader, val_loader, test_loader,
                      criterion, optimizer, scheduler, device, args, logger)

    best_val_acc = trainer.train()

    # Reload the best checkpoint (written during training whenever val
    # accuracy improved) before the final test-set evaluation.
    best_model_path = Path(args.save_dir) / 'best_model.pth'
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    logger.info(f"Loaded best model from epoch {checkpoint['epoch']+1}")

    logger.info("\n" + "="*70)
    logger.info("Final Evaluation on Test Set")
    logger.info("="*70)
    test_acc, test_preds, test_labels = evaluate_model(model, test_loader, device, logger, args.save_dir)

    # Diagnostic artifacts: loss/accuracy curves from the saved history.
    history_path = Path(args.save_dir) / 'training_history.json'
    curves_path = Path(args.save_dir) / 'training_curves.png'
    plot_training_curves(history_path, curves_path)
    logger.info(f"Training curves saved to {curves_path}")

    # Sample-prediction grid on a test batch.
    pred_path = Path(args.save_dir) / 'predictions.png'
    plot_predictions(model, test_loader, device, pred_path)
    logger.info(f"Predictions saved to {pred_path}")

    # Copy-paste instructions for reloading the trained model elsewhere.
    logger.info("\n" + "="*70)
    logger.info("Model Loading Instructions:")
    logger.info(f"from improved_mnist_classifier import {model.__class__.__name__}")
    logger.info(f"model = {model.__class__.__name__}().to(device)")
    logger.info(f"checkpoint = torch.load('{best_model_path}')")
    logger.info(f"model.load_state_dict(checkpoint['model_state_dict'])")
    logger.info(f"model.eval()")
    logger.info("="*70)

    logger.info(f"\nTraining complete! Best Val Acc: {best_val_acc:.2f}%, Test Acc: {test_acc:.2f}%")
|
|
|
|
|
# CLI entry point: run the full train/evaluate pipeline.
if __name__ == '__main__':
    main()