|
|
import os
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from .utils import SaveBestModel, EarlyStopping, ModelMetrics, ModelLoss, validate_mae
|
|
|
import datetime
|
|
|
|
|
|
def _peek_batch_shape(loader):
    """Return the (B, C, H, W) shape of the first batch in *loader*, or zeros if it is empty."""
    for sources, _ in loader:
        return tuple(sources.shape)
    return 0, 0, 0, 0


def _validation_loss(model, val_loader, criterion):
    """Mean *criterion* over *val_loader*, computed gradient-free in eval mode.

    The original inline loop ran with autograd enabled (building a graph per
    validation batch) and in whatever train/eval mode the model was left in;
    here we disable gradients and switch to eval explicitly. The caller's
    epoch loop calls ``model.train()`` again before the next training pass.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for source, target in val_loader:
            source = source.cuda()
            target = target.cuda()
            total += criterion(model(source), target).item()
    # NOTE(review): raises ZeroDivisionError on an empty val_loader, matching
    # the original behavior.
    return total / len(val_loader)


def _write_epoch_log(save_folder, epoch, train_loss, model, metrics_labels,
                     train_values, validation_loss, val_values):
    """Append one CSV row to ``<save_folder>/training_log.csv`` (header on epoch 0).

    Column layout: epoch, loss, model_size, train_<label>..., val_loss, val_<label>...
    The file is opened in append mode so repeated runs accumulate rows.
    """
    os.makedirs(save_folder, exist_ok=True)
    with open(f"{save_folder}/training_log.csv", "a") as f:
        if epoch == 0:
            header = ["epoch", "loss", "model_size"]
            header += [f"train_{label}" for label in metrics_labels]
            header.append("val_loss")
            header += [f"val_{label}" for label in metrics_labels]
            f.write(",".join(header) + "\n")
        row = [f"{epoch+1}", f"{train_loss:.4f}",
               f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M"]
        row += [f"{value:.4f}" for value in train_values]
        row.append(f"{validation_loss:.4f}")
        row += [f"{value:.4f}" for value in val_values]
        f.write(",".join(row) + "\n")


def _measure_inference_ms(model, sample_input, iters=10, warmup=3):
    """Average forward-pass wall time in milliseconds over *iters* timed runs.

    *warmup* untimed runs are executed first so one-time CUDA kernel-launch /
    initialization overhead does not inflate the average (the original loop
    timed the very first run). Uses CUDA events, which time GPU work accurately
    only with the surrounding ``torch.cuda.synchronize()`` calls.
    """
    model.eval()
    total_ms = 0.0
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(sample_input)
        for _ in range(iters):
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            _ = model(sample_input)
            end.record()
            torch.cuda.synchronize()
            total_ms += start.elapsed_time(end)
    return total_ms / iters


def train(model, train_loader, val_loader, optimizer, epochs=50, save_folder="swin_mae_models", task='mae', loss='default', focal_alpha=0.25, focal_gamma=2.0, patience=10):
    """Train *model* with mixed precision, per-epoch CSV logging, checkpointing,
    early stopping, and a final inference-time benchmark.

    Args:
        model: CUDA-resident network mapping a source batch to predictions.
        train_loader: iterable of (sources, targets) training batches.
        val_loader: iterable of (sources, targets) validation batches.
        optimizer: torch optimizer over ``model.parameters()``.
        epochs: maximum number of epochs.
        save_folder: directory for checkpoints and CSV logs (created if absent).
        task: passed to ModelLoss/ModelMetrics; 'mae' selects min-mode model
            selection, anything else max-mode.
        loss: loss selector forwarded to ModelLoss (e.g. 'default').
        focal_alpha, focal_gamma: focal-loss hyperparameters forwarded to ModelLoss.
        patience: early-stopping patience in epochs.

    Side effects: writes ``training_log.csv`` and ``training_time_log.csv`` in
    *save_folder*, saves ``epoch_best.pth``, every-5-epoch checkpoints, and an
    ``epoch_<n>_earlystop.pth`` snapshot when early stopping triggers.
    Requires a CUDA device.
    """
    total_timer_start = torch.cuda.Event(enable_timing=True)
    total_timer_end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    total_timer_start.record()

    scaler = torch.amp.GradScaler()
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0)
    # 'mae' (reconstruction) metrics improve downward; classification-style
    # metrics improve upward.
    mode = 'min' if task == 'mae' else 'max'
    save_best_model = SaveBestModel(folder=save_folder, mode=mode)
    # Bind the criterion to a new name instead of shadowing the `loss` parameter.
    criterion = ModelLoss(task=task, loss=loss, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    metrics = ModelMetrics(task=task, device='cuda')
    metrics_labels = metrics.getLabels()

    early_stopping_metrics = []
    training_metrics = []
    validation_metrics = []

    # Batch shape for the post-training inference benchmark. The original
    # peeked here too, but then let the validation loop silently overwrite it
    # with the (possibly truncated) last val batch; we keep the train shape.
    B, C, H, W = _peek_batch_shape(train_loader)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} - System start time {datetime.datetime.now().strftime('%H:%M:%S')}")
        model.train()
        total_loss = 0
        training_accumulator = []

        for sources, targets in train_loader:
            sources = sources.cuda()
            targets = targets.cuda()

            # Run forward + loss under autocast so GradScaler actually pairs
            # with mixed precision; previously the scaler wrapped a pure-fp32
            # forward pass and AMP never engaged.
            with torch.amp.autocast('cuda'):
                pred_imgs = model(sources)
                l = criterion(pred_imgs, targets)

            training_accumulator.append(metrics(pred_imgs, targets))

            optimizer.zero_grad()
            scaler.scale(l).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += l.item()

        training_metrics.append(np.mean(training_accumulator, axis=0))
        # validate_mae presumably returns the per-metric means over val_loader
        # — contract defined in .utils; TODO confirm ordering matches
        # metrics_labels.
        validation_metrics.append(validate_mae(model, val_loader, metrics))
        validation_loss = _validation_loss(model, val_loader, criterion)

        # Early stopping and best-model selection track the first validation
        # metric only.
        early_stopping_metrics.append(validation_metrics[-1][0])

        _write_epoch_log(save_folder, epoch, total_loss / len(train_loader),
                         model, metrics_labels, training_metrics[-1],
                         validation_loss, validation_metrics[-1])

        if early_stopping(early_stopping_metrics):
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}_earlystop.pth"))
            print("Early stopping triggered. Model saved.")
            break

        save_best_model(model, early_stopping_metrics[-1], "epoch_best.pth")

        # Periodic checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}.pth"))

    torch.cuda.synchronize()
    total_timer_end.record()
    torch.cuda.synchronize()

    sample_input = torch.randn(B, C, H, W).cuda()
    inference_time = _measure_inference_ms(model, sample_input)

    with open(f"{save_folder}/training_time_log.csv", "a") as f:
        f.write(f"Total training time (ms) for {task},{total_timer_start.elapsed_time(total_timer_end)}\n")
        f.write(f"Average inference time (ms) for {task},{inference_time}\n")
|
|
|
|