# Source: Hugging Face Hub upload (commit 984cdba, "Upload folder using huggingface_hub", author Timerns)
import os
import torch
import numpy as np
from .utils import SaveBestModel, EarlyStopping, ModelMetrics, ModelLoss, validate_mae
import datetime
def train(model, train_loader, val_loader, optimizer, epochs=50, save_folder="swin_mae_models", task='mae', loss='default', focal_alpha=0.25, focal_gamma=2.0, patience=10):
    """Train *model* with CUDA mixed precision, logging per-epoch metrics to CSV.

    Args:
        model: CUDA-resident ``nn.Module``.
        train_loader, val_loader: iterables yielding ``(sources, targets)``
            batch pairs (masked images / original images for the MAE task).
        optimizer: torch optimizer over ``model``'s parameters.
        epochs: maximum number of training epochs.
        save_folder: directory for checkpoints and CSV logs (created if absent).
        task: ``'mae'`` minimises the first validation metric; any other task
            maximises it. Forwarded to ``ModelLoss`` / ``ModelMetrics``.
        loss: loss-selector string forwarded to ``ModelLoss``.
        focal_alpha, focal_gamma: focal-loss hyperparameters (used only when
            the selected loss is focal — see ``ModelLoss``).
        patience: epochs without improvement before early stopping.

    Side effects: appends to ``training_log.csv`` and
    ``training_time_log.csv`` and writes ``.pth`` checkpoints in *save_folder*.
    """
    # Wall-clock timing of the whole run via CUDA events.
    start_time_t = torch.cuda.Event(enable_timing=True)
    end_time_t = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_time_t.record()

    scaler = torch.amp.GradScaler()
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0)
    # MAE reconstruction error is minimised; other tasks' lead metric is maximised.
    mode = 'min' if task == 'mae' else 'max'
    save_best_model = SaveBestModel(folder=save_folder, mode=mode)
    # Keep the criterion under a distinct name: the original rebound the
    # `loss` parameter (a selector string) to the ModelLoss instance, making
    # the two easy to confuse.
    loss_fn = ModelLoss(task=task, loss=loss, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    metrics = ModelMetrics(task=task, device='cuda')
    metrics_labels = metrics.getLabels()

    early_stopping_metrics = []
    training_metrics = []
    validation_metrics = []

    # Probe one batch for the input shape; reused for the dummy
    # inference-timing input after training.
    B, C, H, W = 0, 0, 0, 0
    for sources, targets in train_loader:
        B, C, H, W = sources.shape
        break

    os.makedirs(save_folder, exist_ok=True)  # hoisted: was re-run every epoch

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} - System start time {datetime.datetime.now().strftime('%H:%M:%S')}")
        model.train()
        total_loss = 0
        training_accumulator = []
        # ---- Training pass ----
        for sources, targets in train_loader:
            sources = sources.cuda()  # masked images
            targets = targets.cuda()  # original images
            # Fix: run the forward pass under autocast — the original used
            # GradScaler without autocast, so everything stayed fp32 and the
            # scaler accomplished nothing.
            with torch.amp.autocast('cuda'):
                pred_imgs = model(sources)
                l = loss_fn(pred_imgs, targets)
            training_accumulator.append(metrics(pred_imgs, targets))
            # ---- Backward ----
            optimizer.zero_grad()
            scaler.scale(l).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += l.item()
        training_metrics.append(np.mean(training_accumulator, axis=0))

        # ---- Validation pass ----
        validation_metrics.append(validate_mae(model, val_loader, metrics))
        # Fix: evaluate in eval mode and without gradients — the original left
        # the model in train mode (dropout/BN active) and built autograd
        # graphs for every validation batch.
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for source, target in val_loader:
                source = source.cuda()
                target = target.cuda()
                val_loss += loss_fn(model(source), target).item()
        validation_loss = val_loss / len(val_loader)

        # Early stopping tracks the first validation metric.
        early_stopping_metrics.append(validation_metrics[-1][0])

        # Append this epoch's row to the CSV log (header only on epoch 0).
        with open(f"{save_folder}/training_log.csv", "a") as f:
            if epoch == 0:
                f.write("epoch,")
                f.write("loss,")
                f.write("model_size,")
                f.write(",".join([f"train_{label}" for label in metrics_labels]))
                f.write(",val_loss,")
                f.write(",".join([f"val_{label}" for label in metrics_labels]))
                f.write("\n")
            f.write(f"{epoch+1},")
            f.write(f"{total_loss / len(train_loader):.4f},")
            f.write(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M,")
            f.write(",".join([f"{value:.4f}" for value in training_metrics[-1]]))
            f.write(",")
            f.write(f"{validation_loss:.4f},")
            f.write(",".join([f"{value:.4f}" for value in validation_metrics[-1]]))
            f.write("\n")

        if early_stopping(early_stopping_metrics):
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}_earlystop.pth"))
            print("Early stopping triggered. Model saved.")
            break
        save_best_model(model, early_stopping_metrics[-1], "epoch_best.pth")
        # Periodic checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}.pth"))

    torch.cuda.synchronize()
    end_time_t.record()
    torch.cuda.synchronize()

    # Average inference latency over 10 runs on a dummy batch. Uses the
    # training-batch shape probed above (the original reused the shape of the
    # last — possibly partial — validation batch, which skewed the timing).
    inference_time = 0.0
    model.eval()
    with torch.no_grad():
        sample_input = torch.randn(B, C, H, W).cuda()
        for _ in range(10):
            torch.cuda.synchronize()
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()
            _ = model(sample_input)
            end_time.record()
            torch.cuda.synchronize()
            inference_time += start_time.elapsed_time(end_time)
    inference_time /= 10.0  # average over 10 runs

    with open(f"{save_folder}/training_time_log.csv", "a") as f:
        f.write(f"Total training time (ms) for {task},{start_time_t.elapsed_time(end_time_t)}\n")
        f.write(f"Average inference time (ms) for {task},{inference_time}\n")