File size: 2,505 Bytes
871d19b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import torch
import os
import logging
from omegaconf import OmegaConf
from train import train_model

# Pin both NCCL *_DISABLE switches to '0' so the peer-to-peer and InfiniBand
# transports stay enabled (per NCCL's environment-variable docs). This runs at
# import time and overwrites any value already exported in the shell; child
# processes started afterwards inherit the updated environment.
os.environ.update({
    'NCCL_P2P_DISABLE': '0',
    'NCCL_IB_DISABLE': '0',
})


if __name__ == "__main__":
    # Entry point: parse CLI options, build the merged OmegaConf config, set up
    # file logging, and launch one train_model worker per visible CUDA device.
    #
    # Example (run this launcher script; train_model is imported from the
    # `train` module):
    #
    #   python <this_script>.py \
    #       --task regen \
    #       --start 0 \
    #       --end 4000 \
    #       --start_from_folder ../models/regen \
    #       --save_folder ../models/regen
    #
    #   --task               task config name, loaded from
    #                        src/configs/train/tasks/<task>.yaml
    #                        (e.g. regen, style_transfer, adjustment)
    #   --start              0 trains from scratch; n > 0 resumes from
    #                        <start_from_folder>/model_h3d_epoch<n>.pth
    #   --end                total epochs (default 4000)
    #   --start_from_folder  directory containing the checkpoint to resume
    #   --save_folder        intended output directory for saved models.
    #                        NOTE(review): parsed but never used below --
    #                        confirm where train_model takes its save path.
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='regen')
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=4000)
    parser.add_argument('--start_from_folder', type=str, default=None)
    parser.add_argument('--save_folder', type=str, default=None)

    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # spawn(nprocs=0) starts no workers and returns immediately, which would
        # look like a successful (empty) run; fail loudly instead.
        raise SystemExit('No CUDA devices found; at least one GPU is required.')

    # Validate resume options up front. parser.error is used instead of
    # `assert`, which is stripped when Python runs with -O.
    if args.start < 0:
        parser.error('--start must be >= 0')

    logger_name = f'{args.task}_'
    checkpoint_path = None
    start_epoch = args.start  # 0 means "train from scratch"
    if args.start > 0:
        if args.start_from_folder is None:
            parser.error('--start_from_folder is required when --start > 0')
        checkpoint_path = os.path.join(args.start_from_folder,
                                       f'model_h3d_epoch{args.start}.pth')
        if not os.path.exists(checkpoint_path):
            parser.error(f'Checkpoint file {checkpoint_path} not found!')
        logger_name += f'continue_from_epoch_{args.start}_'

    import datetime
    # Log file name: <task>_[continue_from_epoch_<n>_]<MM-DD_HH-MM>.log
    logger_name += datetime.datetime.now().strftime("%m-%d_%H-%M") + '.log'

    base_config = OmegaConf.load("src/configs/train/base_config.yaml")
    task_config = OmegaConf.load(f"src/configs/train/tasks/{args.task}.yaml")
    # Task-specific values override the shared base config.
    config = OmegaConf.merge(base_config, task_config)

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(config.train.logger_pth, exist_ok=True)
    logger_name = os.path.join(config.train.logger_pth, logger_name)
    logging.basicConfig(filename=logger_name,
                        level=logging.INFO,
                        format='%(asctime)s:%(levelname)s:%(message)s')

    # NOTE(review): spawn() pickles the root logger by name only, so each
    # worker receives its own unconfigured root logger; the basicConfig file
    # handler above is not carried over. Unless train_model configures logging
    # itself, worker INFO records likely never reach this file -- confirm.
    torch.multiprocessing.spawn(train_model,
                                args=(world_size,
                                      start_epoch,
                                      args.end,
                                      checkpoint_path,
                                      config,
                                      logging.getLogger(),),
                                nprocs=world_size,
                                join=True)