import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc
from utils.Evaluator import ClassificationEvaluator
from utils.Callback import EarlyStopping
def train_model(
model: nn.Module,
criterion: nn.Module,
optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
train_loader: DataLoader,
val_loader: DataLoader,
early_stopping: EarlyStopping,
epochs: int = 15,
use_ddp: bool = False,
) -> tuple:
"""
Train the model and perform validation using multiple GPUs.
Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.
Args:
model: Model to train
criterion: Loss function
optimizer: Optimizer for training
scheduler: Learning rate scheduler
train_loader: DataLoader for training data
val_loader: DataLoader for validation data
early_stopping: Early stopping handler
epochs: Maximum number of epochs to train
use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)
"""
# Check available GPUs
num_gpus = torch.cuda.device_count()
if num_gpus < 2:
print(
f"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. Continuing with available resources."
)
else:
print(f"Using {num_gpus} GPUs for training")
# Setup device and model
    if num_gpus >= 2:
        if use_ddp:
            # For DistributedDataParallel
            import os

            import torch.distributed as dist
            from torch.nn.parallel import DistributedDataParallel as DDP

            # Initialize the process group (NCCL is the standard backend for GPUs)
            dist.init_process_group(backend="nccl")
            # torchrun exports LOCAL_RANK; dist.get_rank() is the global rank and
            # only matches the device index on a single node.
            local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")
            model = model.to(device)
            model = DDP(model, device_ids=[local_rank])
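            # NOTE: for correct sharding under DDP, train_loader is assumed to be
            # built with a DistributedSampler, and sampler.set_epoch(epoch) should
            # be called at the start of every epoch so shuffling differs per epoch.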
else:
# For DataParallel (simpler to use)
device = torch.device("cuda:0")
model = model.to(device)
model = torch.nn.DataParallel(model)
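            # DataParallel replicates the model per batch and scatters inputs
            # across GPUs; PyTorch recommends DDP over it for performance.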
else:
# Single GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
train_losses = []
val_losses = []
train_accs = []
val_accs = []
# Store validation predictions and labels for final evaluation
all_val_labels = []
all_val_preds = []
all_val_scores = []
for epoch in range(epochs):
print(f"Epoch {epoch+1}/{epochs}")
# Training phase
model.train()
running_loss = 0.0
correct = 0
total = 0
for inputs, labels in tqdm(train_loader, desc="Training"):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if total == 0:
print("Warning: No training samples found. Skipping training.")
epoch_train_loss = 0.0
epoch_train_acc = 0.0
else:
epoch_train_loss = running_loss / total
epoch_train_acc = correct / total
train_losses.append(epoch_train_loss)
train_accs.append(epoch_train_acc)
# Validation phase
model.eval()
running_loss = 0.0
correct = 0
total = 0
all_labels = []
all_preds = []
all_scores = []
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="Validation"):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0)
probs = F.softmax(outputs, dim=1)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
all_labels.extend(labels.cpu().numpy().tolist())
all_preds.extend(predicted.cpu().numpy().tolist())
all_scores.append(probs.cpu().numpy())
        # Guard against ZeroDivisionError when the validation loader is empty
        if total == 0:
            print("Warning: No validation samples found. Skipping validation metrics.")
epoch_val_loss = 0.0
epoch_val_acc = 0.0
else:
epoch_val_loss = running_loss / total
epoch_val_acc = correct / total
val_losses.append(epoch_val_loss)
val_accs.append(epoch_val_acc)
all_scores = np.vstack(all_scores) if all_scores else np.array([])
# Store validation results for the final epoch
all_val_labels = all_labels
all_val_preds = all_preds
all_val_scores = all_scores
# Update learning rate scheduler
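        # ReduceLROnPlateau is stepped with the monitored metric; other
        # schedulers would be stepped as scheduler.step() with no argument.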
scheduler.step(epoch_val_loss)
print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")
# Check early stopping
early_stopping(epoch_val_loss)
if early_stopping.early_stop:
print("Early stopping triggered!")
break
# Free up memory
del all_labels, all_preds, all_scores
gc.collect()
torch.cuda.empty_cache()
# Clean up DDP if used
if num_gpus >= 2 and use_ddp:
dist.destroy_process_group()
return (
model,
train_losses,
val_losses,
train_accs,
val_accs,
all_val_labels,
all_val_preds,
all_val_scores,
)
def model_train(
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
dataset,
epochs: int = 20,
) -> dict:
    model_name = type(model).__name__
    # timm models expose a pretrained_cfg mapping; prefer its registered name
    if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
        model_name = model.pretrained_cfg["name"]
print(f"\n{'='*20} Training {model_name} {'='*20}\n")
class_names = dataset.classes
num_classes = len(class_names)
learning_rate = 0.001
try:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.1, patience=3
)
early_stopping = EarlyStopping(patience=5)
(
model,
train_losses,
val_losses,
train_accs,
val_accs,
val_labels,
val_preds,
val_scores,
) = train_model(
model,
nn.CrossEntropyLoss(),
optimizer,
scheduler,
train_loader,
val_loader,
early_stopping,
epochs=epochs,
use_ddp=False,
)
print(f"\n{'='*20} Evaluation for {model_name} {'='*20}\n")
evaluator = ClassificationEvaluator(
class_names=class_names,
)
evaluator.plot_training_history(train_losses, val_losses, train_accs, val_accs)
# Process validation predictions and labels
try:
evaluator.plot_confusion_matrix(val_labels, val_preds)
evaluator.plot_per_class_accuracy(val_labels, val_preds)
# Get metrics from the updated function including kappa
accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = (
evaluator.compute_metrics(
val_labels,
val_preds,
val_scores,
model_name,
)
)
# Build a results dictionary including kappa
results = {
"accuracy": accuracy,
"report": report_dict,
"roc_auc": roc_auc_dict,
"pr_auc": pr_auc_dict,
"kappa": kappa,
}
return results
except Exception as viz_error:
print(f"Error in visualization: {viz_error}")
import traceback
traceback.print_exc()
return {"accuracy": None}
except Exception as e:
print(f"Error occurred when training {model_name}: {e}")
import traceback
traceback.print_exc()
return {"accuracy": None}
    finally:
        # Clean up memory; guard each del because an exception inside
        # train_model leaves the later names unbound.
        if "optimizer" in locals():
            del optimizer
        if "scheduler" in locals():
            del scheduler
        if "early_stopping" in locals():
            del early_stopping
        if "train_losses" in locals():
            # The eight outputs of train_model are unpacked together, so they
            # are either all bound or all unbound.
            del train_losses, val_losses, train_accs, val_accs
            del val_labels, val_preds, val_scores
        gc.collect()
        torch.cuda.empty_cache()
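

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training utilities above. The
    # dataset paths, transform, batch size, and timm model name below are
    # illustrative assumptions, not values from this project.
    import timm
    from torchvision import datasets, transforms

    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )
    train_set = datasets.ImageFolder("data/train", transform=transform)  # assumed path
    val_set = datasets.ImageFolder("data/val", transform=transform)  # assumed path
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

    # ImageFolder exposes .classes, which model_train reads via dataset.classes
    model = timm.create_model(
        "resnet18", pretrained=True, num_classes=len(train_set.classes)
    )
    results = model_train(model, train_loader, val_loader, train_set, epochs=20)
    print(results)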