File size: 3,560 Bytes
e170a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

from datasets import dataset_dict, RandomConcatSampler
from model import PL_RelPose
from utils import seed_torch
from configs.default import get_cfg_defaults


def main(args):
    """Train (or resume training of) a PL_RelPose model from a YAML config.

    Args:
        args: parsed CLI namespace with
            - config:  path to a .yaml file merged over the default config
            - weights: optional checkpoint to initialise model weights from
            - resume:  optional checkpoint to resume full trainer state from
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    batch_size = config.TRAINER.BATCH_SIZE
    num_workers = config.TRAINER.NUM_WORKERS
    pin_memory = config.TRAINER.PIN_MEMORY
    n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
    lr = config.TRAINER.LEARNING_RATE
    epochs = config.TRAINER.EPOCHS
    pct_start = config.TRAINER.PCT_START

    num_keypoints = config.MODEL.NUM_KEYPOINTS
    n_layers = config.MODEL.N_LAYERS
    num_heads = config.MODEL.NUM_HEADS
    features = config.MODEL.FEATURES

    seed = config.RANDOM_SEED
    seed_torch(seed)

    # dataset_dict maps task -> data source -> dataset builder.
    build_fn = dataset_dict[task][dataset]
    trainset = build_fn('train', config)
    validset = build_fn('val', config)

    # Concatenated multi-scene datasets need per-subset random sampling;
    # everything else gets plain shuffling.
    sampler_datasets = {'scannet', 'megadepth', 'linemod', 'ho3d', 'mapfree'}
    if dataset in sampler_datasets:
        sampler = RandomConcatSampler(
            trainset,
            n_samples_per_subset=n_samples_per_subset,
            subset_replacement=True,
            shuffle=True,
            seed=seed
        )
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, sampler=sampler)
    else:
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    validloader = DataLoader(validset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    if args.weights is None:
        pl_relpose = PL_RelPose(
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
            features=features,
        )
    else:
        # Override the checkpoint's saved hyperparameters with the current
        # config so a resumed run honours the YAML file.
        pl_relpose = PL_RelPose.load_from_checkpoint(
            checkpoint_path=args.weights,
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
            # FIX: was missing here (but passed in the fresh-construction
            # branch above), so config.MODEL.FEATURES was silently ignored
            # when loading from --weights.
            features=features,
        )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # Two checkpoint callbacks: one always keeps the latest epoch, one keeps
    # the best validation AUC@20.
    latest_checkpoint_callback = ModelCheckpoint()
    best_checkpoint_callback = ModelCheckpoint(monitor='valid/auc@20', mode='max')
    trainer = L.Trainer(
        devices=[0],
        # devices=[0, 1],
        # accelerator='gpu', strategy='ddp_find_unused_parameters_true', 
        max_epochs=epochs, 
        callbacks=[lr_monitor, latest_checkpoint_callback, best_checkpoint_callback],
        precision="bf16-mixed",
        # fast_dev_run=1,
    )

    # ckpt_path resumes optimizer/scheduler/epoch state, distinct from
    # --weights which only initialises model parameters.
    trainer.fit(pl_relpose, trainloader, validloader, ckpt_path=args.resume)


def get_parser():
    """Build the CLI parser: a positional YAML config path plus optional
    ``--resume`` (full trainer-state checkpoint) and ``--weights``
    (model-weights checkpoint) arguments, both defaulting to None."""
    p = argparse.ArgumentParser()
    p.add_argument('config', type=str, help='.yaml configure file path')
    p.add_argument('--resume', type=str, default=None)
    p.add_argument('--weights', type=str, default=None)
    return p


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)