File size: 5,294 Bytes
984cdba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import torch
import numpy as np
from .utils import SaveBestModel, EarlyStopping, ModelMetrics, ModelLoss, validate_mae
import datetime
def train(model, train_loader, val_loader, optimizer, epochs=50, save_folder="swin_mae_models", task='mae', loss='default', focal_alpha=0.25, focal_gamma=2.0, patience=10):
    """Train `model` with AMP, per-epoch validation, CSV logging, early
    stopping, and periodic/best checkpointing; finish with a small
    inference-latency benchmark.

    Args:
        model: CUDA-resident torch.nn.Module, called as ``model(batch)``.
        train_loader, val_loader: iterables yielding ``(sources, targets)``
            tensor pairs (sources are the masked images, targets the originals).
        optimizer: torch optimizer over ``model``'s parameters.
        epochs: maximum number of training epochs.
        save_folder: directory for checkpoints and CSV logs (created if absent).
        task: ``'mae'`` selects min-mode best-model selection; any other value
            uses max mode.
        loss: loss-name string forwarded to ``ModelLoss``.
        focal_alpha, focal_gamma: focal-loss hyperparameters forwarded to
            ``ModelLoss``.
        patience: epochs without improvement before early stopping fires.

    Side effects: appends to ``<save_folder>/training_log.csv`` and
    ``<save_folder>/training_time_log.csv`` and writes ``.pth`` checkpoints.
    """
    # Wall-clock training time, measured with CUDA events.
    train_start_evt = torch.cuda.Event(enable_timing=True)
    train_end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    train_start_evt.record()

    scaler = torch.amp.GradScaler()
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0)
    mode = 'min' if task == 'mae' else 'max'
    save_best_model = SaveBestModel(folder=save_folder, mode=mode)
    # Keep the loss *object* under a distinct name instead of shadowing the
    # `loss` (string) parameter, as the original code did.
    loss_fn = ModelLoss(task=task, loss=loss, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    metrics = ModelMetrics(task=task, device='cuda')
    metrics_labels = metrics.getLabels()

    early_stopping_metrics = []
    training_metrics = []
    validation_metrics = []

    # Remember one batch's shape for the synthetic benchmark input below.
    # NOTE(review): assumes train_loader is non-empty — TODO confirm upstream.
    B, C, H, W = 0, 0, 0, 0
    for sources, _ in train_loader:
        B, C, H, W = sources.shape
        break

    os.makedirs(save_folder, exist_ok=True)  # hoisted out of the epoch loop

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} - System start time {datetime.datetime.now().strftime('%H:%M:%S')}")
        model.train()
        total_loss = 0
        training_accumulator = []
        # ---- Training pass over every batch of this epoch ----
        for sources, targets in train_loader:
            sources = sources.cuda()  # masked images
            targets = targets.cuda()  # original images
            # BUGFIX: the original created a GradScaler but never opened an
            # autocast region, so no mixed precision actually happened.
            with torch.amp.autocast('cuda'):
                pred_imgs = model(sources)
                l = loss_fn(pred_imgs, targets)
            training_accumulator.append(metrics(pred_imgs, targets))
            # ---- Backward ----
            optimizer.zero_grad()
            scaler.scale(l).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += l.item()
        training_metrics.append(np.mean(training_accumulator, axis=0))

        # ---- Validation ----
        validation_metrics.append(validate_mae(model, val_loader, metrics))
        # BUGFIX: validate in eval mode and without autograd; the original kept
        # the model in train mode (dropout/batchnorm active) and built a graph
        # for every validation batch, wasting memory for no reason.
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for source, target in val_loader:
                source = source.cuda()
                target = target.cuda()
                val_loss += loss_fn(model(source), target).item()
        validation_loss = val_loss / len(val_loader)
        # The first validation metric drives early stopping / best-model choice.
        early_stopping_metrics.append(validation_metrics[-1][0])

        # ---- CSV log (append mode; header only on the first epoch) ----
        with open(f"{save_folder}/training_log.csv", "a") as f:
            if epoch == 0:
                f.write("epoch,loss,model_size,")
                f.write(",".join(f"train_{label}" for label in metrics_labels))
                f.write(",val_loss,")
                f.write(",".join(f"val_{label}" for label in metrics_labels))
                f.write("\n")
            f.write(f"{epoch+1},")
            f.write(f"{total_loss / len(train_loader):.4f},")
            f.write(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M,")
            f.write(",".join(f"{value:.4f}" for value in training_metrics[-1]))
            f.write(f",{validation_loss:.4f},")
            f.write(",".join(f"{value:.4f}" for value in validation_metrics[-1]))
            f.write("\n")

        # ---- Early stopping / checkpointing ----
        if early_stopping(early_stopping_metrics):
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}_earlystop.pth"))
            print("Early stopping triggered. Model saved.")
            break
        save_best_model(model, early_stopping_metrics[-1], "epoch_best.pth")
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}.pth"))

    torch.cuda.synchronize()
    train_end_evt.record()
    torch.cuda.synchronize()

    # ---- Inference-latency benchmark: average of 10 timed forward passes ----
    inference_time = 0.0
    model.eval()
    with torch.no_grad():
        sample_input = torch.randn(B, C, H, W).cuda()
        for _ in range(10):
            torch.cuda.synchronize()
            start_evt = torch.cuda.Event(enable_timing=True)
            end_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
            _ = model(sample_input)
            end_evt.record()
            torch.cuda.synchronize()
            inference_time += start_evt.elapsed_time(end_evt)
    inference_time /= 10.0  # average over 10 runs

    with open(f"{save_folder}/training_time_log.csv", "a") as f:
        f.write(f"Total training time (ms) for {task},{train_start_evt.elapsed_time(train_end_evt)}\n")
        f.write(f"Average inference time (ms) for {task},{inference_time}\n")
|