import os
import datetime

import numpy as np
import torch

from .utils import SaveBestModel, EarlyStopping, ModelMetrics, ModelLoss, validate_mae


def train(model, train_loader, val_loader, optimizer, epochs=50,
          save_folder="swin_mae_models", task='mae', loss='default',
          focal_alpha=0.25, focal_gamma=2.0, patience=10):
    """Train `model` with AMP mixed precision and per-epoch CSV logging.

    Each epoch: one pass over `train_loader` (scaled backward via GradScaler),
    metric accumulation, a validation pass, early-stopping bookkeeping, and a
    row appended to `<save_folder>/training_log.csv`. After training, average
    forward-pass latency is benchmarked and written to
    `<save_folder>/training_time_log.csv`.

    Args:
        model: CUDA-resident network mapping masked images -> reconstructions.
        train_loader: yields (masked sources, original targets) batches.
        val_loader: validation loader with the same batch structure.
        optimizer: torch optimizer over `model.parameters()`.
        epochs: maximum number of epochs.
        save_folder: directory for checkpoints and CSV logs (created if absent).
        task: 'mae' minimizes the tracked metric; anything else maximizes it.
        loss: loss-spec string forwarded to ModelLoss.
        focal_alpha, focal_gamma: focal-loss hyperparameters for ModelLoss.
        patience: EarlyStopping patience in epochs.
    """
    # CUDA events time the whole training run on the GPU clock.
    start_time_t = torch.cuda.Event(enable_timing=True)
    end_time_t = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_time_t.record()

    scaler = torch.amp.GradScaler()
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0)
    # Reconstruction error ('mae') is minimized; other tasks track a score to maximize.
    mode = 'min' if task == 'mae' else 'max'
    save_best_model = SaveBestModel(folder=save_folder, mode=mode)

    # Keep the public `loss` parameter name for callers, but bind the callable
    # to a distinct local so the spec string is not shadowed.
    loss_fn = ModelLoss(task=task, loss=loss, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    metrics = ModelMetrics(task=task, device='cuda')
    metrics_labels = metrics.getLabels()

    early_stopping_metrics = []  # first validation metric per epoch
    training_metrics = []        # per-epoch mean of per-batch training metrics
    validation_metrics = []      # per-epoch validate_mae() results

    # Capture one batch's shape for the post-training inference benchmark.
    # Assumes every loader batch has the same (B, C, H, W) — TODO confirm.
    sources, _ = next(iter(train_loader))
    B, C, H, W = sources.shape

    os.makedirs(save_folder, exist_ok=True)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} - System start time {datetime.datetime.now().strftime('%H:%M:%S')}")
        model.train()
        total_loss = 0.0
        training_accumulator = []

        for sources, targets in train_loader:
            sources = sources.cuda()  # masked images
            targets = targets.cuda()  # original images

            # Run the forward pass under autocast so GradScaler actually
            # provides mixed precision (scaling alone does nothing useful).
            with torch.amp.autocast('cuda'):
                pred_imgs = model(sources)
                batch_loss = loss_fn(pred_imgs, targets)
            training_accumulator.append(metrics(pred_imgs, targets))

            # ---- Backward ----
            optimizer.zero_grad()
            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += batch_loss.item()

        training_metrics.append(np.mean(training_accumulator, axis=0))
        validation_metrics.append(validate_mae(model, val_loader, metrics))

        # Validation loss in eval mode with autograd disabled: avoids building
        # graphs and keeps Dropout/BatchNorm in inference behavior.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for source, target in val_loader:
                source = source.cuda()
                target = target.cuda()
                val_loss += loss_fn(model(source), target).item()
        validation_loss = val_loss / len(val_loader)

        # The first validation metric drives early stopping and best-model saves.
        early_stopping_metrics.append(validation_metrics[-1][0])

        # Append one CSV row per epoch; header only on the first epoch.
        with open(f"{save_folder}/training_log.csv", "a") as f:
            if epoch == 0:
                header = ["epoch", "loss", "model_size"]
                header += [f"train_{label}" for label in metrics_labels]
                header.append("val_loss")
                header += [f"val_{label}" for label in metrics_labels]
                f.write(",".join(header) + "\n")
            row = [f"{epoch+1}",
                   f"{total_loss / len(train_loader):.4f}",
                   f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M"]
            row += [f"{value:.4f}" for value in training_metrics[-1]]
            row.append(f"{validation_loss:.4f}")
            row += [f"{value:.4f}" for value in validation_metrics[-1]]
            f.write(",".join(row) + "\n")

        if early_stopping(early_stopping_metrics):
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}_earlystop.pth"))
            print("Early stopping triggered. Model saved.")
            break

        save_best_model(model, early_stopping_metrics[-1], "epoch_best.pth")
        # Periodic checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}.pth"))

    torch.cuda.synchronize()
    end_time_t.record()
    torch.cuda.synchronize()

    # Benchmark: average forward latency over 10 runs on a synthetic batch.
    inference_time = 0.0
    model.eval()
    with torch.no_grad():
        sample_input = torch.randn(B, C, H, W).cuda()
        for _ in range(10):
            torch.cuda.synchronize()
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()
            _ = model(sample_input)
            end_time.record()
            torch.cuda.synchronize()
            inference_time += start_time.elapsed_time(end_time)
    inference_time /= 10.0  # average over 10 runs

    with open(f"{save_folder}/training_time_log.csv", "a") as f:
        f.write(f"Total training time (ms) for {task},{start_time_t.elapsed_time(end_time_t)}\n")
        f.write(f"Average inference time (ms) for {task},{inference_time}\n")