henryu commited on
Commit
91bb10d
·
1 Parent(s): dbb22f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -0
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A main training script.
3
+ """
4
+
5
+
6
+ # Copyright (c) Facebook, Inc. and its affiliates.
7
+ import warnings
8
+ warnings.filterwarnings('ignore') # never print matching warnings
9
+ import logging
10
+ import os
11
+ from collections import OrderedDict
12
+ import torch
13
+ import uniperceiver.utils.comm as comm
14
+ from uniperceiver.config import get_cfg, CfgNode
15
+ from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments
16
+
17
+ #!TODO re-implement hooks
18
+ from uniperceiver.engine import hooks
19
+ from uniperceiver.modeling import add_config
20
+ from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
21
+ try:
22
+ import deepspeed
23
+ DEEPSPEED_INSTALLED = True
24
+ except:
25
+ DEEPSPEED_INSTALLED = False
26
+
27
+ import copy
28
+
29
+ def add_data_prefix(cfg):
30
+ # TODO: more flexible method
31
+ data_dir = os.getenv("DATA_PATH", None)
32
+ mapping_list = [
33
+ [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER',]],
34
+ [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
35
+ [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
36
+ [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
37
+ [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
38
+ [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE',]],
39
+ [cfg.MODEL, 'WEIGHTS', ['MODEL',]],
40
+ ]
41
+ whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]
42
+ if data_dir:
43
+ for node, attr ,_ in mapping_list:
44
+ if node[attr] != '' and not node[attr].startswith('.') and not node[attr].startswith('/') and not node[attr].startswith('work_dirs') and not node[attr].startswith('cluster') and not node[attr].startswith('s3://') and node[attr] not in whitelist:
45
+ setattr(node, attr, os.path.join(data_dir, node[attr]))
46
+ for task in cfg.TASKS:
47
+ for _, item, key_list in mapping_list:
48
+ config_tmp = task
49
+ for key in key_list:
50
+ if key in config_tmp:
51
+ config_tmp = config_tmp[key]
52
+ if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith('/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith('cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
53
+ config_tmp[item] = os.path.join(data_dir, config_tmp[item])
54
+
55
+ mapping_list = [
56
+ ['', 'FILE_PATH', ['SHARED_TARGETS_CFG',]],
57
+ ]
58
+ if cfg.SHARED_TARGETS is None:
59
+ cfg.SHARED_TARGETS = []
60
+ for share_targets in cfg.SHARED_TARGETS:
61
+ for _, item, key_list in mapping_list:
62
+ config_tmp = share_targets
63
+ for key in key_list:
64
+ config_tmp = config_tmp[key]
65
+ if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith(
66
+ '/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith(
67
+ 'cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
68
+ config_tmp[item] = os.path.join(data_dir, config_tmp[item])
69
+
70
+
71
+
72
+ def add_default_setting_for_multitask_config(cfg):
73
+ # merge some default config in (CfgNode) uniperceiver/config/defaults.py to each task config (dict)
74
+
75
+ tasks_config_temp = cfg.TASKS
76
+ num_tasks = len(tasks_config_temp)
77
+ cfg.pop('TASKS', None)
78
+
79
+ cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]
80
+
81
+ for i, task_config in enumerate(tasks_config_temp):
82
+ cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
83
+ cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
84
+ pass
85
+
86
+
87
+ def setup(args):
88
+ """
89
+ Create configs and perform basic setups.
90
+ """
91
+ cfg = get_cfg()
92
+ tmp_cfg = cfg.load_from_file_tmp(args.config_file)
93
+ add_config(cfg, tmp_cfg)
94
+
95
+ cfg.merge_from_file(args.config_file)
96
+ add_data_prefix(cfg)
97
+
98
+ cfg.merge_from_list(args.opts)
99
+ #
100
+ add_default_setting_for_multitask_config(cfg)
101
+ cfg.freeze()
102
+ default_setup(cfg, args)
103
+ return cfg
104
+
105
+ def main(args):
106
+ cfg = setup(args)
107
+
108
+ """
109
+ If you'd like to do anything fancier than the standard training logic,
110
+ consider writing your own training loop (see plain_train_net.py) or
111
+ subclassing the trainer.
112
+ """
113
+ trainer = build_engine(cfg)
114
+ trainer.resume_or_load(resume=args.resume)
115
+ trainer.cast_layers()
116
+
117
+ if args.eval_only:
118
+ print('---------------------------')
119
+ print('eval model only')
120
+ print('---------------------------\n')
121
+ res = None
122
+ if trainer.val_data_loader is not None:
123
+
124
+ if trainer.model_ema is not None and args.eval_ema:
125
+ if comm.is_main_process():
126
+ print('using ema model for evaluation')
127
+ res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
128
+ else:
129
+ if args.eval_ema and comm.is_main_process():
130
+ print('no ema model exists! using master model for evaluation')
131
+ res = trainer.test(trainer.cfg, trainer.model, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
132
+
133
+ if comm.is_main_process():
134
+ print(res)
135
+
136
+ if trainer.test_data_loader is not None:
137
+ if trainer.model_ema is not None and args.eval_ema:
138
+ if comm.is_main_process():
139
+ print('using ema model for evaluation')
140
+ res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
141
+ else:
142
+ if args.eval_ema and comm.is_main_process():
143
+ print('no ema model exists! using master model for evaluation')
144
+ res = trainer.test(trainer.cfg, trainer.model, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
145
+ if comm.is_main_process():
146
+ print(res)
147
+ return res
148
+
149
+ return trainer.train()
150
+
151
+ def get_args_parser():
152
+ parser = default_argument_parser()
153
+ if DEEPSPEED_INSTALLED:
154
+ parser = deepspeed.add_config_arguments(parser)
155
+ parser = add_moe_arguments(parser)
156
+
157
+ parser.add_argument('--init_method', default='slurm', type=str)
158
+ parser.add_argument('--local_rank', default=0, type=int)
159
+ parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
160
+ args = parser.parse_args()
161
+
162
+ return args
163
+
164
+ if __name__ == "__main__":
165
+ args = get_args_parser()
166
+ print("Command Line Args:", args)
167
+ if args.init_method == 'slurm':
168
+ # slurm init
169
+ check_dist_portfile()
170
+ init_distributed_mode(args)
171
+ main(args)
172
+ elif args.init_method == 'pytorch':
173
+ main(args)
174
+ else:
175
+ # follow 'd2' use default `mp.spawn` to init dist training
176
+ print('using \'mp.spawn\' for dist init! ')
177
+ launch(
178
+ main,
179
+ args.num_gpus,
180
+ num_machines=args.num_machines,
181
+ machine_rank=args.machine_rank,
182
+ dist_url=args.dist_url,
183
+ args=(args,),
184
+ )