| import argparse
|
| from torch.utils.data import DataLoader
|
| import lightning as L
|
| from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
|
|
| from datasets import dataset_dict, RandomConcatSampler
|
| from model import PL_RelPose
|
| from utils import seed_torch
|
| from configs.default import get_cfg_defaults
|
|
|
|
|
# Data sources whose training split is drawn through RandomConcatSampler
# (a fixed number of samples per subset) instead of a plain shuffled loader.
_CONCAT_SAMPLER_SOURCES = ('scannet', 'megadepth', 'linemod', 'ho3d', 'mapfree')


def main(args):
    """Train a PL_RelPose model as described by a YAML config file.

    Args:
        args: Parsed command-line namespace with
            ``config``  -- path of a YAML file merged over ``get_cfg_defaults()``;
            ``weights`` -- optional checkpoint whose weights initialize the model;
            ``resume``  -- optional checkpoint forwarded to ``Trainer.fit`` as
            ``ckpt_path``.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    batch_size = config.TRAINER.BATCH_SIZE
    num_workers = config.TRAINER.NUM_WORKERS
    pin_memory = config.TRAINER.PIN_MEMORY
    n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
    lr = config.TRAINER.LEARNING_RATE
    epochs = config.TRAINER.EPOCHS
    pct_start = config.TRAINER.PCT_START

    num_keypoints = config.MODEL.NUM_KEYPOINTS
    n_layers = config.MODEL.N_LAYERS
    num_heads = config.MODEL.NUM_HEADS
    features = config.MODEL.FEATURES

    # Seed before building datasets and the model so both are reproducible.
    seed = config.RANDOM_SEED
    seed_torch(seed)

    build_fn = dataset_dict[task][dataset]
    trainset = build_fn('train', config)
    validset = build_fn('val', config)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    if dataset in _CONCAT_SAMPLER_SOURCES:
        sampler = RandomConcatSampler(
            trainset,
            n_samples_per_subset=n_samples_per_subset,
            subset_replacement=True,
            shuffle=True,
            seed=seed,
        )
        # A custom sampler replaces DataLoader's own shuffling (the two are mutually exclusive).
        trainloader = DataLoader(trainset, sampler=sampler, **loader_kwargs)
    else:
        trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
    validloader = DataLoader(validset, **loader_kwargs)

    # One set of hyperparameters for both construction paths, so a model loaded
    # from --weights is configured exactly like a freshly built one (incl. features).
    hparams = dict(
        task=task,
        lr=lr,
        epochs=epochs,
        pct_start=pct_start,
        n_layers=n_layers,
        num_heads=num_heads,
        num_keypoints=num_keypoints,
        features=features,
    )
    if args.weights is None:
        pl_relpose = PL_RelPose(**hparams)
    else:
        # Extra kwargs override the hyperparameters saved inside the checkpoint;
        # only the model is loaded here, the trainer state starts fresh.
        pl_relpose = PL_RelPose.load_from_checkpoint(checkpoint_path=args.weights, **hparams)

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # No monitor: keeps the most recent checkpoint.
    latest_checkpoint_callback = ModelCheckpoint()
    # Keeps the checkpoint with the highest 'valid/auc@20' (presumably logged by PL_RelPose).
    best_checkpoint_callback = ModelCheckpoint(monitor='valid/auc@20', mode='max')
    trainer = L.Trainer(
        devices=[0],  # single device, index 0
        max_epochs=epochs,
        callbacks=[lr_monitor, latest_checkpoint_callback, best_checkpoint_callback],
        precision="bf16-mixed",
    )

    # ckpt_path=None starts a new run; otherwise the full training state is restored.
    trainer.fit(pl_relpose, trainloader, validloader, ckpt_path=args.resume)
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser accepting a positional ``config`` path and the
        optional ``--resume`` / ``--weights`` checkpoint paths, both of which
        default to ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='path to the .yaml configuration file')
    parser.add_argument(
        '--resume', type=str, default=None,
        help='checkpoint to resume the full training state from',
    )
    parser.add_argument(
        '--weights', type=str, default=None,
        help='checkpoint to initialize model weights from (training state is not restored)',
    )
    return parser
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line and start training.
    main(get_parser().parse_args())
|
|
|