File size: 5,294 Bytes
984cdba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import torch
import numpy as np
from .utils import SaveBestModel, EarlyStopping, ModelMetrics, ModelLoss, validate_mae
import datetime

def train(model, train_loader, val_loader, optimizer, epochs=50, save_folder="swin_mae_models", task='mae', loss='default', focal_alpha=0.25, focal_gamma=2.0, patience=10):
    """Train ``model`` on CUDA, log per-epoch metrics to CSV, and checkpoint.

    Args:
        model: CUDA-resident ``torch.nn.Module``; called as ``model(sources)``.
        train_loader: iterable yielding ``(sources, targets)`` batch tensors
            (sources are the masked images, targets the originals).
        val_loader: validation iterable with the same ``(source, target)`` shape.
        optimizer: torch optimizer over ``model``'s parameters.
        epochs: maximum number of training epochs.
        save_folder: directory for checkpoints and CSV logs (created on demand).
        task: ``'mae'`` monitors the first validation metric as minimize;
            any other task monitors it as maximize (via ``SaveBestModel``).
        loss: loss-selector string forwarded to ``ModelLoss``.
        focal_alpha: focal-loss alpha forwarded to ``ModelLoss``.
        focal_gamma: focal-loss gamma forwarded to ``ModelLoss``.
        patience: epochs without improvement before ``EarlyStopping`` fires.

    Side effects:
        Appends to ``{save_folder}/training_log.csv`` and
        ``{save_folder}/training_time_log.csv``; writes ``.pth`` checkpoints
        (best model, every 5th epoch, and on early stop).
    """
    # Wall-clock training time, measured with CUDA events.
    train_start_evt = torch.cuda.Event(enable_timing=True)
    train_end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    train_start_evt.record()

    # NOTE(review): GradScaler is used below, but the forward pass is never
    # wrapped in torch.amp.autocast, so no mixed precision actually occurs —
    # confirm whether autocast was intended here.
    scaler = torch.amp.GradScaler()
    early_stopping = EarlyStopping(patience=patience, min_delta=0.0)
    mode = 'min' if task == 'mae' else 'max'
    save_best_model = SaveBestModel(folder=save_folder, mode=mode)
    # Named `criterion` so the `loss` selector argument is not shadowed.
    criterion = ModelLoss(task=task, loss=loss, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    metrics = ModelMetrics(task=task, device='cuda')
    metrics_labels = metrics.getLabels()

    early_stopping_metrics = []
    training_metrics = []
    validation_metrics = []

    # Peek one training batch for the input shape used by the timing
    # benchmark at the end; stays (0, 0, 0, 0) if the loader is empty.
    B, C, H, W = 0, 0, 0, 0
    for sources, _ in train_loader:
        B, C, H, W = sources.shape
        break

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} - System start time {datetime.datetime.now().strftime('%H:%M:%S')}")

        # ---- Training ----
        model.train()
        total_loss = 0
        training_accumulator = []
        for sources, targets in train_loader:
            sources = sources.cuda()   # masked images
            targets = targets.cuda()   # original images

            pred_imgs = model(sources)
            batch_loss = criterion(pred_imgs, targets)
            training_accumulator.append(metrics(pred_imgs, targets))

            # Backward + optimizer step through the grad scaler.
            optimizer.zero_grad()
            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += batch_loss.item()

        training_metrics.append(np.mean(training_accumulator, axis=0))

        # ---- Validation ----
        # FIX: run validation in eval mode and under no_grad. The original
        # kept dropout/batch-norm in training mode and built an autograd
        # graph for every validation batch, wasting GPU memory.
        model.eval()
        validation_metrics.append(validate_mae(model, val_loader, metrics))
        val_loss = 0
        with torch.no_grad():
            for source, target in val_loader:
                source = source.cuda()
                target = target.cuda()
                # FIX: do not overwrite B/C/H/W here — the original clobbered
                # the training-batch shape with the (possibly partial) last
                # validation batch, skewing the inference benchmark below.
                val_loss += criterion(model(source), target).item()
        validation_loss = val_loss / len(val_loader)

        # Monitored quantity is the first validation metric.
        # NOTE(review): for mode='max' tasks this assumes EarlyStopping
        # handles maximization of this series — confirm in .utils.
        early_stopping_metrics.append(validation_metrics[-1][0])

        # ---- CSV logging (header only on the first epoch of this run) ----
        os.makedirs(save_folder, exist_ok=True)
        with open(f"{save_folder}/training_log.csv", "a") as f:
            if epoch == 0:
                f.write("epoch,")
                f.write("loss,")
                f.write("model_size,")
                f.write(",".join([f"train_{label}" for label in metrics_labels]))
                f.write(",val_loss,")
                f.write(",".join([f"val_{label}" for label in metrics_labels]))
                f.write("\n")
            f.write(f"{epoch+1},")
            f.write(f"{total_loss / len(train_loader):.4f},")
            f.write(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M,")
            f.write(",".join([f"{value:.4f}" for value in training_metrics[-1]]))
            f.write(",")
            f.write(f"{validation_loss:.4f},")
            f.write(",".join([f"{value:.4f}" for value in validation_metrics[-1]]))
            f.write("\n")

        # ---- Checkpointing / early stop ----
        if early_stopping(early_stopping_metrics):
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}_earlystop.pth"))
            print("Early stopping triggered. Model saved.")
            break

        save_best_model(model, early_stopping_metrics[-1], "epoch_best.pth")

        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_folder, f"epoch_{epoch+1}.pth"))

    train_end_evt.record()
    torch.cuda.synchronize()

    # ---- Inference-time benchmark (average over 10 timed runs) ----
    inference_time = 0.0
    model.eval()
    with torch.no_grad():
        sample_input = torch.randn(B, C, H, W).cuda()
        # FIX: one untimed warm-up pass so cold-start overhead (kernel
        # selection, allocator growth) does not inflate the first sample.
        _ = model(sample_input)
        torch.cuda.synchronize()
        for _ in range(10):
            start_evt = torch.cuda.Event(enable_timing=True)
            end_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
            _ = model(sample_input)
            end_evt.record()
            torch.cuda.synchronize()
            inference_time += start_evt.elapsed_time(end_evt)
    inference_time /= 10.0

    with open(f"{save_folder}/training_time_log.csv", "a") as f:
        f.write(f"Total training time (ms) for {task},{train_start_evt.elapsed_time(train_end_evt)}\n")
        f.write(f"Average inference time (ms) for {task},{inference_time}\n")