Yzy00518 committed on
Commit
871d19b
·
1 Parent(s): 7cc572d

Upload src/trainer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/trainer.py +67 -0
src/trainer.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import argparse
import datetime
import logging
import os

import torch
from omegaconf import OmegaConf

from train import train_model

# Explicitly set both NCCL "disable" switches to '0' before any worker starts.
os.environ['NCCL_P2P_DISABLE'] = '0'
os.environ['NCCL_IB_DISABLE'] = '0'


def parse_args(argv=None):
    """Parse the command-line options of the training launcher.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace`` with ``task``, ``start``, ``end``,
        ``start_from_folder`` and ``save_folder``.
    """
    parser = argparse.ArgumentParser()
    # Selects src/configs/train/tasks/<task>.yaml and prefixes the log name.
    parser.add_argument('--task', type=str, default='regen')
    # 0 trains from scratch; any other value resumes from that epoch.
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=4000)
    # Folder holding model_h3d_epoch<start>.pth; required when --start != 0.
    parser.add_argument('--start_from_folder', type=str, default=None)
    # Accepted for command-line compatibility; this script does not read it.
    parser.add_argument('--save_folder', type=str, default=None)
    return parser.parse_args(argv)


def resolve_checkpoint(start, start_from_folder):
    """Return the checkpoint file to resume from, or ``None`` for a fresh run.

    Args:
        start: Epoch to resume from; ``0`` means train from scratch.
        start_from_folder: Folder expected to contain
            ``model_h3d_epoch<start>.pth``.

    Returns:
        The checkpoint path, or ``None`` when ``start`` is ``0``.

    Raises:
        ValueError: ``start`` is non-zero but no folder was supplied.
        FileNotFoundError: the expected checkpoint file does not exist.
    """
    if start == 0:
        return None
    if start_from_folder is None:
        # Without this guard os.path.join(None, ...) fails with a TypeError
        # that does not tell the user which option is missing.
        raise ValueError(
            '--start_from_folder is required when --start is non-zero')
    checkpoint_path = os.path.join(
        start_from_folder, f'model_h3d_epoch{start}.pth')
    # Raise explicitly instead of asserting: asserts vanish under `python -O`.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f'Checkpoint file {checkpoint_path} not found!')
    return checkpoint_path


def build_log_name(task, start, now):
    """Build the log file name for a run.

    Args:
        task: Task name used as the file-name prefix.
        start: Start epoch; a non-zero value adds a ``continue_from_epoch``
            marker to the name.
        now: ``datetime.datetime`` used for the ``MM-DD_HH-MM`` timestamp.

    Returns:
        The file name (no directory), ending in ``.log``.
    """
    prefix = f'{task}_'
    if start != 0:
        prefix += f'continue_from_epoch_{start}_'
    return f'{prefix}{now.strftime("%m-%d_%H-%M")}.log'


def main(argv=None):
    """Launch one training worker per visible CUDA device.

    Example invocation (paths are illustrative)::

        python trainer.py --task regen --start 0 --end 4000
            --start_from_folder ../models/regen --save_folder ../models/regen

    ``--task`` picks the task config (the original usage note listed regen,
    style_transfer and adjustment), ``--start`` is 0 for training from scratch
    or ``n`` to resume from checkpoint ``n``, and ``--end`` is the total number
    of epochs. Config files are loaded from ``src/configs/train/...`` relative
    to the current working directory.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.

    Raises:
        RuntimeError: no CUDA device is visible.
        ValueError, FileNotFoundError: see :func:`resolve_checkpoint`.
    """
    args = parse_args(argv)

    world_size = torch.cuda.device_count()
    if world_size < 1:
        # spawn(nprocs=0) would start no workers and return, so the script
        # would exit "successfully" without training anything.
        raise RuntimeError(
            'No CUDA devices visible; at least one GPU is required to train.')

    checkpoint_path = resolve_checkpoint(args.start, args.start_from_folder)
    start_epoch = args.start
    log_name = build_log_name(args.task, args.start, datetime.datetime.now())

    # Task-specific settings override the shared base configuration.
    base_config = OmegaConf.load("src/configs/train/base_config.yaml")
    task_config = OmegaConf.load(f"src/configs/train/tasks/{args.task}.yaml")
    config = OmegaConf.merge(base_config, task_config)

    log_dir = config.train.logger_pth
    os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(filename=os.path.join(log_dir, log_name),
                        level=logging.INFO,
                        format='%(asctime)s:%(levelname)s:%(message)s')

    # spawn() calls train_model(rank, *args) once per process.
    # NOTE(review): the root logger is handed to the workers; whether the file
    # handler configured above is active inside spawned processes depends on
    # how train_model sets up logging - confirm.
    torch.multiprocessing.spawn(train_model,
                                args=(world_size,
                                      start_epoch,
                                      args.end,
                                      checkpoint_path,
                                      config,
                                      logging.getLogger(),),
                                nprocs=world_size,
                                join=True)


if __name__ == "__main__":
    main()