File size: 4,114 Bytes
8d5039c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import glob
import tqdm
import os
import torch
import numpy as np
from test_util import test_model
import wandb

def train_one_epoch(model, optim, data_loader, accumulated_iter, tbar, leave_pbar=False):
    """Run ``len(data_loader)`` training iterations, one optimizer step per batch.

    Args:
        model: Callable module; ``model(batch)`` must return
            ``(loss, loss_dict, disp_dict)`` where ``loss`` supports
            ``.backward()`` / ``.item()`` and the two dicts are merged for display.
        optim: Optimizer, zeroed and stepped once per batch.
        data_loader: Sized iterable of batch dicts; its ``len()`` defines the
            number of iterations in this epoch.
        accumulated_iter: Global iteration counter before this epoch.
        tbar: Outer (per-epoch) tqdm bar whose postfix shows the latest
            losses and learning rate.
        leave_pbar: Keep the inner per-iteration bar on screen when done.

    Returns:
        The global iteration counter after this epoch.
    """
    total_it_each_epoch = len(data_loader)
    dataloader_iter = iter(data_loader)

    # Context manager guarantees the inner bar is closed even if a step raises.
    with tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True) as pbar:
        for _ in range(total_it_each_epoch):
            try:
                batch = next(dataloader_iter)
            except StopIteration:
                # The iterator ran out before len() iterations; restart it.
                dataloader_iter = iter(data_loader)
                batch = next(dataloader_iter)
                print('new iters')

            # Some optimizer wrappers expose ``.lr``; plain torch optimizers do
            # not, so fall back to the first param group. Only lookup/convert
            # errors are caught so e.g. KeyboardInterrupt still propagates.
            try:
                cur_lr = float(optim.lr)
            except (AttributeError, TypeError, ValueError):
                cur_lr = optim.param_groups[0]['lr']

            model.train()
            optim.zero_grad()
            load_data_to_gpu(batch)
            loss, loss_dict, disp_dict = model(batch)
            loss.backward()
            optim.step()

            accumulated_iter += 1
            # ``.item()`` synchronizes with the device; read it once per step.
            loss_value = loss.item()
            disp_dict.update(loss_dict)
            disp_dict.update({'loss': loss_value, 'lr': cur_lr})

            # Log to wandb
            wandb.log({"loss": loss_value, "lr": cur_lr, "iter": accumulated_iter})

            # Log to the console progress bars.
            pbar.update()
            pbar.set_postfix(dict(total_it=accumulated_iter))
            tbar.set_postfix(disp_dict)
            tbar.refresh()

    return accumulated_iter

def train_model(model, optim, data_loader, lr_sch, start_it, start_epoch, total_epochs, ckpt_save_dir, sampler=None,
                max_ckpt_save_num=5, min_lr=1e-6, use_edge_from_epoch=6):
    """Train for epochs ``[start_epoch, total_epochs)`` and checkpoint after each one.

    Args:
        model: Model (possibly wrapped in DataParallel / DistributedDataParallel).
        optim: Optimizer passed to ``train_one_epoch``.
        data_loader: Training loader.
        lr_sch: Scheduler stepped once per epoch.
        start_it: Global iteration counter to resume from.
        start_epoch: First epoch index (0-based).
        total_epochs: Exclusive end epoch index.
        ckpt_save_dir: Directory (``str`` or ``pathlib.Path``) for
            ``checkpoint_epoch_<n>.pth`` files.
        sampler: Optional sampler whose ``set_epoch`` is called each epoch.
        max_ckpt_save_num: Maximum number of checkpoints kept after saving;
            the oldest (by mtime) are deleted first.
        min_lr: Floor applied to every param group's lr after each scheduler step.
        use_edge_from_epoch: First epoch index at which ``use_edge`` is set to
            True on the model (default 6, i.e. the previous ``e > 5``).
    """
    from pathlib import Path

    ckpt_save_dir = Path(ckpt_save_dir)
    # Attributes set on a (Distributed)DataParallel wrapper are not visible on
    # the wrapped module, so the flag is also set on the unwrapped model.
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    base_model = model.module if isinstance(model, wrappers) else model

    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True) as tbar:
        accumulated_iter = start_it
        for e in tbar:
            if sampler is not None:
                sampler.set_epoch(e)
            if e >= use_edge_from_epoch:
                model.use_edge = True
                base_model.use_edge = True
            accumulated_iter = train_one_epoch(model, optim, data_loader, accumulated_iter, tbar,
                                               leave_pbar=(e + 1 == total_epochs))
            lr_sch.step()
            # Clamp each group independently so per-group learning rates are
            # preserved (previously every group was overwritten with group 0's lr).
            for param_group in optim.param_groups:
                param_group['lr'] = max(param_group['lr'], min_lr)

            # Prune the oldest checkpoints so that, after this epoch's save,
            # at most ``max_ckpt_save_num`` remain. Slicing avoids IndexError
            # when max_ckpt_save_num <= 0.
            ckpt_list = sorted(glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth')), key=os.path.getmtime)
            num_to_remove = len(ckpt_list) - max_ckpt_save_num + 1
            for stale_ckpt in ckpt_list[:max(num_to_remove, 0)]:
                os.remove(stale_ckpt)

            ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % (e + 1))
            save_checkpoint(
                checkpoint_state(model, optim, e + 1, accumulated_iter), filename=ckpt_name,
            )

def model_state_to_cpu(model_state):
    """Return a copy of a state mapping with every value moved to the CPU.

    The copy has the same mapping type as ``model_state`` (an ``OrderedDict``
    stays an ``OrderedDict``) and keeps its key order. Every value must provide
    a ``.cpu()`` method, as tensors do.
    """
    result = type(model_state)()
    for name in model_state.keys():
        result[name] = model_state[name].cpu()
    return result

def checkpoint_state(model=None, optimizer=None, epoch=None, it=None):
    """Assemble a serializable checkpoint dict.

    Returns a dict with keys ``'epoch'``, ``'it'``, ``'model_state'`` and
    ``'optimizer_state'``. For a ``DistributedDataParallel`` model the wrapped
    module's ``state_dict()`` is copied to CPU; any other model's
    ``state_dict()`` is stored as-is. Missing model/optimizer map to ``None``.
    """
    optim_state = None if optimizer is None else optimizer.state_dict()

    if model is None:
        model_state = None
    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_state = model_state_to_cpu(model.module.state_dict())
    else:
        model_state = model.state_dict()

    return dict(epoch=epoch, it=it, model_state=model_state, optimizer_state=optim_state)

def save_checkpoint(state, filename='checkpoint', save_optimizer_separately=False):
    """Write a checkpoint dict to ``<filename>.pth`` with ``torch.save``.

    Args:
        state: Checkpoint dict, e.g. as produced by ``checkpoint_state``.
            The caller's dict is never modified.
        filename: Output path without extension (str or path-like);
            ``.pth`` is appended.
        save_optimizer_separately: When true and ``state`` contains an
            ``'optimizer_state'`` key, that entry is written to
            ``<filename>_optim.pth`` and left out of the main file. Defaults to
            False, which matches the previous behaviour (previously this branch
            was disabled with a hard-coded ``if False``).
    """
    if save_optimizer_separately and 'optimizer_state' in state:
        # Shallow copy so popping does not strip the caller's dict.
        state = dict(state)
        optimizer_state = state.pop('optimizer_state')
        optimizer_filename = '{}_optim.pth'.format(filename)
        torch.save({'optimizer_state': optimizer_state}, optimizer_filename)

    torch.save(state, '{}.pth'.format(filename))

def load_data_to_gpu(batch_dict):
    """Convert numpy-array entries of ``batch_dict`` to float32 tensors, in place.

    Arrays are moved to ``cuda:0`` when CUDA is available and otherwise stay
    on the CPU. Every array is cast to float32, integer and bool arrays
    included. Values that are not ``np.ndarray`` are left untouched.
    Returns ``None``.
    """
    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    array_keys = [name for name, value in batch_dict.items() if isinstance(value, np.ndarray)]
    for name in array_keys:
        batch_dict[name] = torch.from_numpy(batch_dict[name]).float().to(target)