motionReFit / src /trainer.py
Yzy00518's picture
Upload src/trainer.py with huggingface_hub
871d19b
import argparse
import torch
import os
import logging
from omegaconf import OmegaConf
from train import train_model
# Pin both NCCL "disable" switches to '0' (i.e. not disabled) before any
# worker process is spawned. Note that this overwrites whatever value the
# launching shell may have exported for these two variables.
os.environ.update({
    'NCCL_P2P_DISABLE': '0',
    'NCCL_IB_DISABLE': '0',
})
def build_log_filename(task, start, now):
    """Return the log file name for one training run.

    Args:
        task: Task name given on the command line (``--task``).
        start: Epoch the run starts from; ``0`` means training from scratch.
        now: ``datetime`` object used for the ``MM-DD_HH-MM`` timestamp.

    Returns:
        ``<task>_<timestamp>.log`` for a fresh run, or
        ``<task>_continue_from_epoch_<start>_<timestamp>.log`` when resuming.
    """
    resume_tag = f'continue_from_epoch_{start}_' if start != 0 else ''
    return f'{task}_{resume_tag}{now.strftime("%m-%d_%H-%M")}.log'


def resolve_checkpoint_path(start_from_folder, start):
    """Locate the checkpoint file to resume from.

    Args:
        start_from_folder: Directory holding ``model_h3d_epoch<N>.pth`` files,
            or ``None`` when ``--start_from_folder`` was not given.
        start: Epoch to resume from; ``0`` means training from scratch.

    Returns:
        ``None`` when ``start`` is ``0``, otherwise the path of the
        checkpoint file for epoch ``start``.

    Raises:
        ValueError: If resuming is requested but no folder was given.
        FileNotFoundError: If the expected checkpoint file does not exist.
    """
    if start == 0:
        return None
    if start_from_folder is None:
        # Without this guard os.path.join(None, ...) fails with an opaque
        # TypeError.
        raise ValueError(
            '--start_from_folder is required when --start is not 0')
    checkpoint_path = os.path.join(start_from_folder,
                                   f'model_h3d_epoch{start}.pth')
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f'Checkpoint file {checkpoint_path} not found!')
    return checkpoint_path


def main(argv=None):
    r"""Parse the command line and launch one training worker per GPU.

    Example (run from the directory that contains ``src/``, because the
    config files are loaded from ``src/configs/train/...``)::

        python src/trainer.py \
            --task regen \
            --start 0 \
            --end 4000 \
            --start_from_folder ../models/regen \
            --save_folder ../models/regen

    Options:
        --task               task name; selects
                             ``src/configs/train/tasks/<task>.yaml``
                             (e.g. regen / style_transfer / adjustment)
        --start              0 to train from scratch, n to resume from the
                             checkpoint of epoch n
        --end                total epochs, default 4000
        --start_from_folder  folder containing the checkpoint to resume from
        --save_folder        folder to save the model to

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--start`` is non-zero without ``--start_from_folder``.
        FileNotFoundError: If the checkpoint to resume from does not exist.
        RuntimeError: If no CUDA device is visible.
    """
    import datetime

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='regen')
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=4000)
    parser.add_argument('--start_from_folder', type=str, default=None)
    # NOTE(review): --save_folder is parsed but never forwarded to
    # train_model below -- confirm where checkpoints are actually written.
    parser.add_argument('--save_folder', type=str, default=None)
    args = parser.parse_args(argv)

    # Validate the resume request first so a bad invocation fails before any
    # config is read or any log directory is created.
    checkpoint_path = resolve_checkpoint_path(args.start_from_folder,
                                              args.start)
    start_epoch = args.start

    world_size = torch.cuda.device_count()
    if world_size < 1:
        # With nprocs=0 no worker would be started and the script would
        # exit without training and without any error message.
        raise RuntimeError(
            'No CUDA device available: at least one GPU is required.')

    # Task-specific settings override the shared base configuration.
    base_config = OmegaConf.load("src/configs/train/base_config.yaml")
    task_config = OmegaConf.load(f"src/configs/train/tasks/{args.task}.yaml")
    config = OmegaConf.merge(base_config, task_config)

    log_dir = config.train.logger_pth
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(
        log_dir,
        build_log_filename(args.task, args.start, datetime.datetime.now()))
    logging.basicConfig(filename=log_file,
                        level=logging.INFO,
                        format='%(asctime)s:%(levelname)s:%(message)s')

    # spawn() calls train_model(rank, *args) once per process.
    # NOTE(review): the root logger is handed to the workers as an argument;
    # confirm that train_model sets up logging inside each worker process.
    torch.multiprocessing.spawn(train_model,
                                args=(world_size,
                                      start_epoch,
                                      args.end,
                                      checkpoint_path,
                                      config,
                                      logging.getLogger(),),
                                nprocs=world_size,
                                join=True)


if __name__ == "__main__":
    main()