Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import gc | |
| from utils.Evaluator import ClassificationEvaluator | |
| from utils.Callback import EarlyStopping | |
def train_model(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    early_stopping: EarlyStopping,
    epochs: int = 15,
    use_ddp: bool = False,
) -> tuple:
    """
    Train the model and validate it after every epoch.

    With two or more visible GPUs the model is wrapped in either
    DistributedDataParallel (``use_ddp=True``) or DataParallel
    (``use_ddp=False``). Otherwise it runs on ``cuda:0`` if CUDA is
    available, else on the CPU.

    Args:
        model: Model to train.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer for training.
        scheduler: Learning rate scheduler. ``scheduler.step`` is called with
            the epoch validation loss, so it must accept a metric argument
            (as ``ReduceLROnPlateau`` does).
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        early_stopping: Callable invoked with the epoch validation loss; its
            ``early_stop`` attribute is read afterwards to decide whether to
            stop training.
        epochs: Maximum number of epochs to train.
        use_ddp: Whether to use DistributedDataParallel (True) or
            DataParallel (False) when at least two GPUs are available.

    Returns:
        A tuple ``(model, train_losses, val_losses, train_accs, val_accs,
        val_labels, val_preds, val_scores)``. ``model`` is the (possibly
        DP/DDP-wrapped) model; the four history lists hold one value per
        executed epoch; the last three hold the validation labels, predicted
        classes and softmax scores of the last executed epoch
        (``val_scores`` is a stacked NumPy array, or an empty array when the
        validation loader produced no batches).
    """
    # Check available GPUs
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        print(
            f"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. Continuing with available resources."
        )
    else:
        print(f"Using {num_gpus} GPUs for training")

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # Validation predictions and labels of the most recent epoch
    all_val_labels = []
    all_val_preds = []
    all_val_scores = []

    # True only when this call created the process group, so that a group
    # created by the caller is never torn down behind its back.
    owns_process_group = False

    try:
        # Setup device and model
        if num_gpus >= 2 and use_ddp:
            # Function-scope imports: only needed for the DDP path.
            import os
            import torch.distributed as dist
            from torch.nn.parallel import DistributedDataParallel as DDP

            # Initializing the default group twice raises, so reuse an
            # existing one if the caller already set it up.
            if not dist.is_initialized():
                dist.init_process_group(backend="nccl")
                owns_process_group = True

            # The CUDA device index must be the rank *within this node*.
            # Launchers such as torchrun export it as LOCAL_RANK; the global
            # rank only coincides with it on single-node jobs, so fall back
            # to the global rank modulo the per-node GPU count.
            fallback_rank = dist.get_rank() % num_gpus
            local_rank = int(os.environ.get("LOCAL_RANK", fallback_rank))

            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")
            model = model.to(device)
            model = DDP(model, device_ids=[local_rank])
        elif num_gpus >= 2:
            # DataParallel (simpler to use): single process, replicas on all GPUs
            device = torch.device("cuda:0")
            model = model.to(device)
            model = torch.nn.DataParallel(model)
        else:
            # Single GPU or CPU
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            # Distributed samplers derive their shuffle order from the epoch
            # number; without this call every epoch sees the same order.
            train_sampler = getattr(train_loader, "sampler", None)
            if hasattr(train_sampler, "set_epoch"):
                train_sampler.set_epoch(epoch)

            # Training phase
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in tqdm(train_loader, desc="Training"):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Weight the batch loss by batch size so the epoch loss is a
                # per-sample average even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            if total == 0:
                print("Warning: No training samples found. Skipping training.")
                epoch_train_loss = 0.0
                epoch_train_acc = 0.0
            else:
                epoch_train_loss = running_loss / total
                epoch_train_acc = correct / total
            train_losses.append(epoch_train_loss)
            train_accs.append(epoch_train_acc)

            # Validation phase
            model.eval()
            running_loss = 0.0
            correct = 0
            total = 0
            all_labels = []
            all_preds = []
            all_scores = []
            with torch.no_grad():
                for inputs, labels in tqdm(val_loader, desc="Validation"):
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    running_loss += loss.item() * inputs.size(0)
                    probs = F.softmax(outputs, dim=1)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    all_labels.extend(labels.cpu().numpy().tolist())
                    all_preds.extend(predicted.cpu().numpy().tolist())
                    all_scores.append(probs.cpu().numpy())

            # Avoid dividing by zero when the validation loader is empty
            if total == 0:
                print("Warning: No validation samples found. Skipping validation.")
                epoch_val_loss = 0.0
                epoch_val_acc = 0.0
            else:
                epoch_val_loss = running_loss / total
                epoch_val_acc = correct / total
            val_losses.append(epoch_val_loss)
            val_accs.append(epoch_val_acc)
            all_scores = np.vstack(all_scores) if all_scores else np.array([])

            # Keep the results of the most recent epoch for final evaluation
            all_val_labels = all_labels
            all_val_preds = all_preds
            all_val_scores = all_scores

            # Update learning rate scheduler with the monitored metric
            scheduler.step(epoch_val_loss)

            print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
            print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
            print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Check early stopping
            early_stopping(epoch_val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered!")
                break

            # Free up memory before the next epoch
            del all_labels, all_preds, all_scores
            gc.collect()
            torch.cuda.empty_cache()
    finally:
        # Tear down the process group even when training raised, but only
        # if this call was the one that created it.
        if owns_process_group:
            dist.destroy_process_group()

    return (
        model,
        train_losses,
        val_losses,
        train_accs,
        val_accs,
        all_val_labels,
        all_val_preds,
        all_val_scores,
    )
def model_train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    dataset,
    epochs: int = 20,
    learning_rate: float = 0.001,
) -> dict:
    """
    Train a classifier with Adam and evaluate it on the validation set.

    Builds an Adam optimizer, a ``ReduceLROnPlateau`` scheduler
    (``factor=0.1``, ``patience=3``) and an ``EarlyStopping(patience=5)``
    handler, runs ``train_model`` with cross-entropy loss, then plots the
    training history, confusion matrix and per-class accuracy and computes
    metrics through ``ClassificationEvaluator``.

    Args:
        model: Model to train. If it exposes a ``pretrained_cfg`` mapping
            with a ``"name"`` entry, that name is used in log output;
            otherwise the class name is used.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        dataset: Object whose ``classes`` attribute provides the class names.
        epochs: Maximum number of epochs to train.
        learning_rate: Initial learning rate passed to Adam.

    Returns:
        A dict with keys ``"accuracy"``, ``"report"``, ``"roc_auc"``,
        ``"pr_auc"`` and ``"kappa"`` on success. If training or evaluation
        raises, the error and traceback are printed and
        ``{"accuracy": None}`` is returned instead.
    """
    import traceback

    model_name = type(model).__name__
    # ``pretrained_cfg`` may be absent or None; only use it when it is a
    # non-empty mapping that actually carries a name.
    pretrained_cfg = getattr(model, "pretrained_cfg", None)
    if pretrained_cfg and "name" in pretrained_cfg:
        model_name = pretrained_cfg["name"]
    print(f"\n{'='*20} Training {model_name} {'='*20}\n")

    class_names = dataset.classes

    # Pre-bind everything the cleanup below releases so the ``finally``
    # clause is valid no matter where the ``try`` body failed.
    optimizer = scheduler = early_stopping = None
    train_losses = val_losses = train_accs = val_accs = None
    val_labels = val_preds = val_scores = None

    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=3
        )
        early_stopping = EarlyStopping(patience=5)

        (
            model,
            train_losses,
            val_losses,
            train_accs,
            val_accs,
            val_labels,
            val_preds,
            val_scores,
        ) = train_model(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_loader,
            val_loader,
            early_stopping,
            epochs=epochs,
            use_ddp=False,
        )

        print(f"\n{'='*20} Evaluation for {model_name} {'='*20}\n")
        evaluator = ClassificationEvaluator(
            class_names=class_names,
        )
        evaluator.plot_training_history(train_losses, val_losses, train_accs, val_accs)

        # Plots and metrics on the validation predictions are handled
        # separately so their failures are reported as evaluation errors.
        try:
            evaluator.plot_confusion_matrix(val_labels, val_preds)
            evaluator.plot_per_class_accuracy(val_labels, val_preds)
            accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = (
                evaluator.compute_metrics(
                    val_labels,
                    val_preds,
                    val_scores,
                    model_name,
                )
            )
            return {
                "accuracy": accuracy,
                "report": report_dict,
                "roc_auc": roc_auc_dict,
                "pr_auc": pr_auc_dict,
                "kappa": kappa,
            }
        except Exception as viz_error:
            print(f"Error in visualization: {viz_error}")
            traceback.print_exc()
            return {"accuracy": None}
    except Exception as e:
        print(f"Error occurred when training {model_name}: {e}")
        traceback.print_exc()
        return {"accuracy": None}
    finally:
        # Drop the references *before* collecting so optimizer state and the
        # stored histories can be freed ahead of emptying the CUDA cache.
        del optimizer, scheduler, early_stopping
        del train_losses, val_losses, train_accs, val_accs
        del val_labels, val_preds, val_scores
        gc.collect()
        torch.cuda.empty_cache()