File size: 3,653 Bytes
0788e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import os
import random
import sys
import time

import numpy as np
import torch
from data import create_dataloader
from networks.trainer import Trainer
from options.train_options import TrainOptions
from tensorboardX import SummaryWriter
from validate import validate


def seed_torch(seed=1029):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and torch (CPU, current CUDA device and
    all CUDA devices), exports the seed as ``PYTHONHASHSEED``, and configures
    the cuDNN backend flags.

    Args:
        seed: Integer seed applied to all generators. Defaults to 1029.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The generators are independent of each other, so seed them in one pass.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # No autotuning, deterministic mode requested, and cuDNN switched off.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


def get_val_opt():
    """Build the options object used for the validation data loader.

    Parses the training options a second time (without printing them) and
    overrides the flags listed below before returning the result.

    Returns:
        The parsed options object with ``isTrain``, ``no_resize`` and
        ``no_crop`` set to ``False`` and ``serial_batches`` set to ``True``.
    """
    validation_overrides = (
        ('isTrain', False),
        ('no_resize', False),
        ('no_crop', False),
        ('serial_batches', True),
    )
    options = TrainOptions().parse(print_options=False)
    for attribute, value in validation_overrides:
        setattr(options, attribute, value)
    return options


if __name__ == '__main__':
    # Training entry point: parse options, build the train/val data loaders,
    # then run `opt.niter` epochs with a validation pass after every epoch.

    def _timestamp():
        """Return the current local time formatted for console messages."""
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

    opt = TrainOptions().parse()
    seed_torch(100)
    opt.dataroot = f'{opt.dataroot}/{opt.train_split}/'
    print('  '.join(sys.argv))
    val_opt = get_val_opt()
    val_opt.dataroot = f'{val_opt.dataroot}/{val_opt.val_split}/'

    data_loader = create_dataloader(opt)
    val_loader = create_dataloader(val_opt)

    train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'train'))
    val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'val'))

    # Everything below runs under try/finally so both TensorBoard writers are
    # closed even when training raises or is interrupted.
    try:
        model = Trainer(opt)
        model.train()

        # Configure the root logger: INFO level, second-resolution timestamps,
        # written to `log.log` (truncated on every run) and to the console.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.FileHandler('log.log', mode='w'),  # Log file output
                logging.StreamHandler(),  # Console output
            ],
        )

        print(f'cwd: {os.getcwd()}')
        print(f'{_timestamp()} Length of data loader: {len(data_loader)}')

        for epoch in range(opt.niter):
            for data in data_loader:
                model.total_steps += 1

                # Hand the batch to the trainer, then run one optimization step.
                model.set_input(data)
                model.optimize_parameters()

                # Report the training loss every `opt.loss_freq` steps.
                if model.total_steps % opt.loss_freq == 0:
                    print(
                        f'{_timestamp()} Train loss: {model.loss} at step: {model.total_steps} lr {model.lr}'
                    )
                    train_writer.add_scalar('loss', model.loss, model.total_steps)

            # Checkpoint every `opt.save_epoch_freq` epochs, skipping epoch 0.
            if epoch % opt.save_epoch_freq == 0 and epoch != 0:
                print(f'saving the model at the end of epoch {epoch}')
                model.save_networks(epoch)

            # Ask the trainer to adjust its learning rate every 10 epochs,
            # skipping epoch 0.
            if epoch % 10 == 0 and epoch != 0:
                print(
                    f'{_timestamp()} changing lr at the end of epoch {epoch}, iters {model.total_steps}'
                )
                model.adjust_learning_rate()

            # Validation: switch to eval mode, score the validation loader,
            # log the metrics, then switch back to train mode.
            model.eval()
            acc, ap, r_acc, f_acc = validate(model.model, val_loader)
            val_writer.add_scalar('accuracy', acc, model.total_steps)
            val_writer.add_scalar('ap', ap, model.total_steps)
            print(
                f'{_timestamp()} (Val @ epoch {epoch}) acc: {acc}; ap: {ap} r_acc: {r_acc}; f_acc: {f_acc}'
            )

            model.train()

        # Save the final weights under the 'last' tag.
        model.eval()
        model.save_networks('last')
    finally:
        train_writer.close()
        val_writer.close()