VanNguyen1214 commited on
Commit
7a1ee60
·
verified ·
1 Parent(s): 75e661f

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -401
main.py DELETED
@@ -1,401 +0,0 @@
1
- """
2
- @Date: 2021/07/17
3
- @description:
4
- """
5
- import sys
6
- import os
7
- import shutil
8
- import argparse
9
- import numpy as np
10
- import json
11
- import torch
12
- import torch.nn.parallel
13
- import torch.optim
14
- import torch.multiprocessing as mp
15
- import torch.utils.data
16
- import torch.utils.data.distributed
17
- import torch.cuda
18
-
19
- from PIL import Image
20
- from tqdm import tqdm
21
- from torch.utils.tensorboard import SummaryWriter
22
- from config.defaults import get_config, get_rank_config
23
- from models.other.criterion import calc_criterion
24
- from models.build import build_model
25
- from models.other.init_env import init_env
26
- from utils.logger import build_logger
27
- from utils.misc import tensor2np_d, tensor2np
28
- from dataset.build import build_loader
29
- from evaluation.accuracy import calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1, \
30
- show_depth_normal_grad, calc_f1_score
31
- from postprocessing.post_process import post_process
32
-
33
- try:
34
- from apex import amp
35
- except ImportError:
36
- amp = None
37
-
38
-
39
- def parse_option():
40
- debug = True if sys.gettrace() else False
41
- parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
42
- parser.add_argument('--cfg',
43
- type=str,
44
- metavar='FILE',
45
- help='path to config file')
46
-
47
- parser.add_argument('--mode',
48
- type=str,
49
- default='train',
50
- choices=['train', 'val', 'test'],
51
- help='train/val/test mode')
52
-
53
- parser.add_argument('--val_name',
54
- type=str,
55
- choices=['val', 'test'],
56
- help='val name')
57
-
58
- parser.add_argument('--bs', type=int,
59
- help='batch size')
60
-
61
- parser.add_argument('--save_eval', action='store_true',
62
- help='save eval result')
63
-
64
- parser.add_argument('--post_processing', type=str,
65
- choices=['manhattan', 'atalanta', 'manhattan_old'],
66
- help='type of postprocessing ')
67
-
68
- parser.add_argument('--need_cpe', action='store_true',
69
- help='need to evaluate corner error and pixel error')
70
-
71
- parser.add_argument('--need_f1', action='store_true',
72
- help='need to evaluate f1-score of corners')
73
-
74
- parser.add_argument('--need_rmse', action='store_true',
75
- help='need to evaluate root mean squared error and delta error')
76
-
77
- parser.add_argument('--force_cube', action='store_true',
78
- help='force cube shape when eval')
79
-
80
- parser.add_argument('--wall_num', type=int,
81
- help='wall number')
82
-
83
- args = parser.parse_args()
84
- args.debug = debug
85
- print("arguments:")
86
- for arg in vars(args):
87
- print(arg, ":", getattr(args, arg))
88
- print("-" * 50)
89
- return args
90
-
91
-
92
- def main():
93
- args = parse_option()
94
- config = get_config(args)
95
-
96
- if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
97
- print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
98
- f = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(config.CKPT.DIR) if 'pkl' in f]
99
- if len(f) > 0:
100
- last_epoch = np.array(f).max()
101
- if last_epoch > 10:
102
- c = input(f"delete it (last_epoch: {last_epoch})?(Y/N)\n")
103
- if c != 'y' and c != 'Y':
104
- exit(0)
105
-
106
- shutil.rmtree(config.CKPT.DIR, ignore_errors=True)
107
-
108
- os.makedirs(config.CKPT.DIR, exist_ok=True)
109
- os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
110
- os.makedirs(config.LOGGER.DIR, exist_ok=True)
111
-
112
- if ':' in config.TRAIN.DEVICE:
113
- nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
114
- if 'cuda' in config.TRAIN.DEVICE:
115
- if not torch.cuda.is_available():
116
- print(f"Cuda is not available(config is: {config.TRAIN.DEVICE}), will use cpu ...")
117
- config.defrost()
118
- config.TRAIN.DEVICE = "cpu"
119
- config.freeze()
120
- nprocs = 1
121
-
122
- if config.MODE == 'train':
123
- with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
124
- f.write(config.dump(allow_unicode=True))
125
-
126
- if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
127
- print(f"Use single process, device:{config.TRAIN.DEVICE}")
128
- main_worker(0, config, 1)
129
- else:
130
- print(f"Use {nprocs} processes ...")
131
- mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)
132
-
133
-
134
- def main_worker(local_rank, cfg, world_size):
135
- config = get_rank_config(cfg, local_rank, world_size)
136
- logger = build_logger(config)
137
- writer = SummaryWriter(config.CKPT.DIR)
138
- logger.info(f"Comment: {config.COMMENT}")
139
- cur_pid = os.getpid()
140
- logger.info(f"Current process id: {cur_pid}")
141
- torch.hub._hub_dir = config.CKPT.PYTORCH
142
- logger.info(f"Pytorch hub dir: {torch.hub._hub_dir}")
143
- init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)
144
-
145
- model, optimizer, criterion, scheduler = build_model(config, logger)
146
- train_data_loader, val_data_loader = build_loader(config, logger)
147
-
148
- if 'cuda' in config.TRAIN.DEVICE:
149
- torch.cuda.set_device(config.TRAIN.DEVICE)
150
-
151
- if config.MODE == 'train':
152
- train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
153
- else:
154
- iou_results, other_results = val_an_epoch(model, val_data_loader,
155
- criterion, config, logger, writer=None,
156
- epoch=config.TRAIN.START_EPOCH)
157
- results = dict(iou_results, **other_results)
158
- if config.SAVE_EVAL:
159
- save_path = os.path.join(config.CKPT.RESULT_DIR, f"result.json")
160
- with open(save_path, 'w+') as f:
161
- json.dump(results, f, indent=4)
162
-
163
-
164
- def save(model, optimizer, epoch, iou_d, logger, writer, config):
165
- model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
166
- for k in model.acc_d:
167
- writer.add_scalar(f"BestACC/{k}", model.acc_d[k]['acc'], epoch)
168
-
169
-
170
- def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
171
- for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
172
- logger.info("=" * 200)
173
- train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
174
- epoch_iou_d, _ = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)
175
-
176
- if config.LOCAL_RANK == 0:
177
- ddp = config.WORLD_SIZE > 1
178
- save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)
179
-
180
- if scheduler is not None:
181
- if scheduler.min_lr is not None and optimizer.param_groups[0]['lr'] <= scheduler.min_lr:
182
- continue
183
- scheduler.step()
184
- writer.close()
185
-
186
-
187
- def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
188
- logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
189
- model.train()
190
-
191
- if len(config.MODEL.FINE_TUNE) > 0:
192
- model.feature_extractor.eval()
193
-
194
- optimizer.zero_grad()
195
-
196
- data_len = len(train_data_loader)
197
- start_i = data_len * epoch * config.WORLD_SIZE
198
- bar = enumerate(train_data_loader)
199
- if config.LOCAL_RANK == 0 and config.SHOW_BAR:
200
- bar = tqdm(bar, total=data_len, ncols=200)
201
-
202
- device = config.TRAIN.DEVICE
203
- epoch_loss_d = {}
204
- for i, gt in bar:
205
- imgs = gt['image'].to(device, non_blocking=True)
206
- gt['depth'] = gt['depth'].to(device, non_blocking=True)
207
- gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
208
- if 'corner_heat_map' in gt:
209
- gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
210
- if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
211
- imgs = imgs.type(torch.float16)
212
- gt['depth'] = gt['depth'].type(torch.float16)
213
- gt['ratio'] = gt['ratio'].type(torch.float16)
214
- dt = model(imgs)
215
- loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
216
- if config.LOCAL_RANK == 0 and config.SHOW_BAR:
217
- bar.set_postfix(batch_loss_d)
218
-
219
- optimizer.zero_grad()
220
- if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
221
- with amp.scale_loss(loss, optimizer) as scaled_loss:
222
- scaled_loss.backward()
223
- else:
224
- loss.backward()
225
- optimizer.step()
226
-
227
- global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
228
- for key, val in batch_loss_d.items():
229
- writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)
230
-
231
- if config.LOCAL_RANK != 0:
232
- return
233
-
234
- epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
235
- s = 'TrainEpochLoss: '
236
- for key, val in epoch_loss_d.items():
237
- writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
238
- s += f" {key}={val}"
239
- logger.info(s)
240
- writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)
241
- logger.info(f"LearningRate: {optimizer.param_groups[0]['lr']}")
242
-
243
-
244
- @torch.no_grad()
245
- def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
246
- model.eval()
247
- logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
248
- data_len = len(val_data_loader)
249
- start_i = data_len * epoch * config.WORLD_SIZE
250
- bar = enumerate(val_data_loader)
251
- if config.LOCAL_RANK == 0 and config.SHOW_BAR:
252
- bar = tqdm(bar, total=data_len, ncols=200)
253
- device = config.TRAIN.DEVICE
254
- epoch_loss_d = {}
255
- epoch_iou_d = {
256
- 'visible_2d': [],
257
- 'visible_3d': [],
258
- 'full_2d': [],
259
- 'full_3d': [],
260
- 'height': []
261
- }
262
-
263
- epoch_other_d = {
264
- 'ce': [],
265
- 'pe': [],
266
- 'f1': [],
267
- 'precision': [],
268
- 'recall': [],
269
- 'rmse': [],
270
- 'delta_1': []
271
- }
272
-
273
- show_index = np.random.randint(0, data_len)
274
- for i, gt in bar:
275
- imgs = gt['image'].to(device, non_blocking=True)
276
- gt['depth'] = gt['depth'].to(device, non_blocking=True)
277
- gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
278
- if 'corner_heat_map' in gt:
279
- gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
280
- dt = model(imgs)
281
-
282
- vis_w = config.TRAIN.VIS_WEIGHT
283
- visualization = False # (config.LOCAL_RANK == 0 and i == show_index) or config.SAVE_EVAL
284
-
285
- loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
286
-
287
- if config.EVAL.POST_PROCESSING is not None:
288
- depth = tensor2np(dt['depth'])
289
- dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
290
- need_cube=config.EVAL.FORCE_CUBE)
291
-
292
- if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
293
- ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
294
- pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))
295
-
296
- epoch_other_d['ce'].append(ce)
297
- epoch_other_d['pe'].append(pe)
298
-
299
- if config.EVAL.NEED_F1:
300
- f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
301
- epoch_other_d['f1'].append(f1)
302
- epoch_other_d['precision'].append(precision)
303
- epoch_other_d['recall'].append(recall)
304
-
305
- if config.EVAL.NEED_RMSE:
306
- rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
307
- epoch_other_d['rmse'].append(rmse)
308
- epoch_other_d['delta_1'].append(delta_1)
309
-
310
- visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
311
- visualization, h=vis_w // 2)
312
- epoch_iou_d['visible_2d'].append(visb_iou[0])
313
- epoch_iou_d['visible_3d'].append(visb_iou[1])
314
- epoch_iou_d['full_2d'].append(full_iou[0])
315
- epoch_iou_d['full_3d'].append(full_iou[1])
316
- epoch_iou_d['height'].append(iou_height)
317
-
318
- if config.LOCAL_RANK == 0 and config.SHOW_BAR:
319
- bar.set_postfix(batch_loss_d)
320
-
321
- global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
322
-
323
- if writer:
324
- for key, val in batch_loss_d.items():
325
- writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)
326
-
327
- if not visualization:
328
- continue
329
-
330
- gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)
331
-
332
- dt_heat_map_imgs = None
333
- gt_heat_map_imgs = None
334
- if 'corner_heat_map' in gt:
335
- dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)
336
-
337
- if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
338
- imgs = []
339
- for j in range(len(pano_bds)):
340
- # floorplan = np.concatenate([visb_iou[2][j], full_iou[2][j]], axis=-1)
341
- floorplan = full_iou[2][j]
342
- margin_w = int(floorplan.shape[-1] * (60/512))
343
- floorplan = floorplan[:, :, margin_w:-margin_w]
344
-
345
- grad_h = dt_grad_imgs[0].shape[1]
346
- vis_merge = [
347
- gt_grad_imgs[j],
348
- pano_bds[j][:, grad_h:-grad_h],
349
- dt_grad_imgs[j]
350
- ]
351
- if 'corner_heat_map' in gt:
352
- vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
353
- img = np.concatenate(vis_merge, axis=-2)
354
-
355
- img = np.concatenate([img, ], axis=-1)
356
- # img = gt_grad_imgs[j]
357
- imgs.append(img)
358
- if writer:
359
- writer.add_images('VIS/Merge', np.array(imgs), global_step)
360
-
361
- if config.SAVE_EVAL:
362
- for k in range(len(imgs)):
363
- img = imgs[k] * 255.0
364
- save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
365
- Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)
366
-
367
- elif writer:
368
- writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
369
- writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
370
- writer.add_images('IoU/Boundary', pano_bds, global_step)
371
- writer.add_images('Grad/gt', gt_grad_imgs, global_step)
372
- writer.add_images('Grad/dt', dt_grad_imgs, global_step)
373
-
374
- if config.LOCAL_RANK != 0:
375
- return
376
-
377
- epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
378
- s = 'ValEpochLoss: '
379
- for key, val in epoch_loss_d.items():
380
- if writer:
381
- writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
382
- s += f" {key}={val}"
383
- logger.info(s)
384
-
385
- epoch_iou_d = dict(zip(epoch_iou_d.keys(), [np.array(epoch_iou_d[k]).mean() for k in epoch_iou_d.keys()]))
386
- s = 'ValEpochIoU: '
387
- for key, val in epoch_iou_d.items():
388
- if writer:
389
- writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
390
- s += f" {key}={val}"
391
- logger.info(s)
392
- epoch_other_d = dict(zip(epoch_other_d.keys(),
393
- [np.array(epoch_other_d[k]).mean() if len(epoch_other_d[k]) > 0 else 0 for k in
394
- epoch_other_d.keys()]))
395
-
396
- logger.info(f'other acc: {epoch_other_d}')
397
- return epoch_iou_d, epoch_other_d
398
-
399
-
400
- if __name__ == '__main__':
401
- main()