plutosss commited on
Commit
9608d23
·
verified ·
1 Parent(s): caa4e03

Delete teed

Browse files
Files changed (1) hide show
  1. teed/main.py +0 -530
teed/main.py DELETED
@@ -1,530 +0,0 @@
1
- """
2
- Hello, welcome on board,
3
- """
4
- from __future__ import print_function
5
-
6
- import argparse
7
- import os
8
- import time, platform
9
- import cv2
10
- import numpy as np
11
- os.environ['CUDA_LAUNCH_BLOCKING']="0"
12
- import torch
13
- import torch.optim as optim
14
- from torch.utils.data import DataLoader
15
- from thop import profile
16
-
17
- from TEED.dataset import DATASET_NAMES, BipedDataset, TestDataset, dataset_info
18
- from TEED.loss2 import *
19
-
20
- from TEED.ted import TED # TEED architecture
21
-
22
- from TEED.utils.img_processing import (image_normalization, save_image_batch_to_disk,
23
- visualize_result, count_parameters)
24
-
25
- is_testing =True # set False to train with TEED model
26
- IS_LINUX = True if platform.system()=="Linux" else False
27
-
28
- def train_one_epoch(epoch, dataloader, model, criterions, optimizer, device,
29
- log_interval_vis, tb_writer, args=None):
30
-
31
- imgs_res_folder = os.path.join(args.output_dir, 'current_res')
32
- os.makedirs(imgs_res_folder,exist_ok=True)
33
- show_log = args.show_log
34
- if isinstance(criterions, list):
35
- criterion1, criterion2 = criterions
36
- else:
37
- criterion1 = criterions
38
-
39
- # Put model in training mode
40
- model.train()
41
-
42
- l_weight0 = [1.1,0.7,1.1,1.3] # for bdcn loss2-B4
43
- l_weight = [[0.05, 2.], [0.05, 2.], [0.01, 1.],
44
- [0.01, 3.]] # for cats loss [0.01, 4.]
45
- loss_avg =[]
46
- for batch_id, sample_batched in enumerate(dataloader):
47
- images = sample_batched['images'].to(device) # BxCxHxW
48
- labels = sample_batched['labels'].to(device) # BxHxW
49
- preds_list = model(images)
50
- loss1 = sum([criterion2(preds, labels,l_w) for preds, l_w in zip(preds_list[:-1],l_weight0)]) # bdcn_loss2 [1,2,3] TEED
51
- loss2 = criterion1(preds_list[-1], labels, l_weight[-1], device) # cats_loss [dfuse] TEED
52
- tLoss = loss2+loss1 # TEED
53
-
54
- optimizer.zero_grad()
55
- tLoss.backward()
56
- optimizer.step()
57
- loss_avg.append(tLoss.item())
58
- if epoch==0 and (batch_id==100 and tb_writer is not None):
59
- tmp_loss = np.array(loss_avg).mean()
60
- tb_writer.add_scalar('loss', tmp_loss,epoch)
61
-
62
- if batch_id % (show_log) == 0:
63
- print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}'
64
- .format(epoch, batch_id, len(dataloader), format(tLoss.item(),'.4f')))
65
- if batch_id % log_interval_vis == 0:
66
- res_data = []
67
-
68
- img = images.cpu().numpy()
69
- res_data.append(img[2])
70
-
71
- ed_gt = labels.cpu().numpy()
72
- res_data.append(ed_gt[2])
73
-
74
- # tmp_pred = tmp_preds[2,...]
75
- for i in range(len(preds_list)):
76
- tmp = preds_list[i]
77
- tmp = tmp[2]
78
- # print(tmp.shape)
79
- tmp = torch.sigmoid(tmp).unsqueeze(dim=0)
80
- tmp = tmp.cpu().detach().numpy()
81
- res_data.append(tmp)
82
-
83
- vis_imgs = visualize_result(res_data, arg=args)
84
- del tmp, res_data
85
-
86
- vis_imgs = cv2.resize(vis_imgs,
87
- (int(vis_imgs.shape[1]*0.8), int(vis_imgs.shape[0]*0.8)))
88
- img_test = 'Epoch: {0} Iter: {1}/{2} Loss: {3}' \
89
- .format(epoch, batch_id, len(dataloader), round(tLoss.item(),4))
90
-
91
- BLACK = (0, 0, 255)
92
- font = cv2.FONT_HERSHEY_SIMPLEX
93
- font_size = 0.9
94
- font_color = BLACK
95
- font_thickness = 2
96
- x, y = 30, 30
97
- vis_imgs = cv2.putText(vis_imgs,
98
- img_test,
99
- (x, y),
100
- font, font_size, font_color, font_thickness, cv2.LINE_AA)
101
- # tmp_vis_name = str(batch_id)+'-results.png'
102
- # cv2.imwrite(os.path.join(imgs_res_folder, tmp_vis_name), vis_imgs)
103
- cv2.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs)
104
- loss_avg = np.array(loss_avg).mean()
105
- return loss_avg
106
-
107
- def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None,test_resize=False):
108
- # XXX This is not really validation, but testing
109
-
110
- # Put model in eval mode
111
- model.eval()
112
-
113
- with torch.no_grad():
114
- for _, sample_batched in enumerate(dataloader):
115
- images = sample_batched['images'].to(device)
116
- # labels = sample_batched['labels'].to(device)
117
- file_names = sample_batched['file_names']
118
- image_shape = sample_batched['image_shape']
119
- preds = model(images,single_test=test_resize)
120
- # print('pred shape', preds[0].shape)
121
- save_image_batch_to_disk(preds[-1],
122
- output_dir,
123
- file_names,img_shape=image_shape,
124
- arg=arg)
125
-
126
-
127
- def test(checkpoint_path, dataloader, model, device, output_dir, args,resize_input=False):
128
- if not os.path.isfile(checkpoint_path):
129
- raise FileNotFoundError(
130
- f"Checkpoint filte note found: {checkpoint_path}")
131
- print(f"Restoring weights from: {checkpoint_path}")
132
- model.load_state_dict(torch.load(checkpoint_path,
133
- map_location=device))
134
-
135
- model.eval()
136
- # just for the new dataset
137
- # os.makedirs(os.path.join(output_dir,"healthy"), exist_ok=True)
138
- # os.makedirs(os.path.join(output_dir,"infected"), exist_ok=True)
139
-
140
- with torch.no_grad():
141
- total_duration = []
142
- for batch_id, sample_batched in enumerate(dataloader):
143
- images = sample_batched['images'].to(device)
144
- # if not args.test_data == "CLASSIC":
145
- labels = sample_batched['labels'].to(device)
146
- file_names = sample_batched['file_names']
147
- image_shape = sample_batched['image_shape']
148
-
149
-
150
- print(f"{file_names}: {images.shape}")
151
- end = time.perf_counter()
152
- if device.type == 'cuda':
153
- torch.cuda.synchronize()
154
- preds = model(images, single_test=resize_input)
155
- if device.type == 'cuda':
156
- torch.cuda.synchronize()
157
- tmp_duration = time.perf_counter() - end
158
- total_duration.append(tmp_duration)
159
- save_image_batch_to_disk(preds,
160
- output_dir, # output_dir
161
- file_names,
162
- image_shape,
163
- arg=args)
164
- torch.cuda.empty_cache()
165
- total_duration = np.sum(np.array(total_duration))
166
- print("******** Testing finished in", args.test_data, "dataset. *****")
167
- print("FPS: %f.4" % (len(dataloader)/total_duration))
168
- # print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
169
-
170
- def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False):
171
- # a test model plus the interganged channels
172
- if not os.path.isfile(checkpoint_path):
173
- raise FileNotFoundError(
174
- f"Checkpoint filte note found: {checkpoint_path}")
175
- print(f"Restoring weights from: {checkpoint_path}")
176
- model.load_state_dict(torch.load(checkpoint_path,
177
- map_location=device))
178
-
179
- model.eval()
180
-
181
- with torch.no_grad():
182
- total_duration = []
183
- for batch_id, sample_batched in enumerate(dataloader):
184
- images = sample_batched['images'].to(device)
185
- if not args.test_data == "CLASSIC":
186
- labels = sample_batched['labels'].to(device)
187
- file_names = sample_batched['file_names']
188
- image_shape = sample_batched['image_shape']
189
- print(f"input tensor shape: {images.shape}")
190
- start_time = time.time()
191
- images2 = images[:, [1, 0, 2], :, :] #GBR
192
- # images2 = images[:, [2, 1, 0], :, :] # RGB
193
- preds = model(images,single_test=resize_input)
194
- preds2 = model(images2,single_test=resize_input)
195
- tmp_duration = time.time() - start_time
196
- total_duration.append(tmp_duration)
197
- save_image_batch_to_disk([preds,preds2],
198
- output_dir,
199
- file_names,
200
- image_shape,
201
- arg=args, is_inchannel=True)
202
- torch.cuda.empty_cache()
203
-
204
- total_duration = np.array(total_duration)
205
- print("******** Testing finished in", args.test_data, "dataset. *****")
206
- print("Average time per image: %f.4" % total_duration.mean(), "seconds")
207
- print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
208
-
209
- def parse_args(is_testing=True, pl_opt_dir = 'output/teed'):
210
- """Parse command line arguments."""
211
- parser = argparse.ArgumentParser(description='TEED model')
212
- parser.add_argument('--choose_test_data',
213
- type=int,
214
- default=-1, # UDED=15
215
- help='Choose a dataset for testing: 0 - 15')
216
-
217
- # ----------- test -------0--
218
- TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
219
- test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
220
-
221
- # Training settings
222
- # BIPED-B2=1, BIPDE-B3=2, just for evaluation, using LDC trained with 2 or 3 bloacks
223
- TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
224
- train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
225
- train_dir = train_inf['data_dir']
226
-
227
- # Data parameters
228
- parser.add_argument('--input_dir',
229
- type=str,
230
- default=train_dir,
231
- help='the path to the directory with the input data.')
232
- parser.add_argument('--input_val_dir',
233
- type=str,
234
- default=test_inf['data_dir'],
235
- help='the path to the directory with the input data for validation.')
236
- parser.add_argument('--output_dir',
237
- type=str,
238
- default='checkpoints',
239
- help='the path to output the results.')
240
- parser.add_argument('--train_data',
241
- type=str,
242
- choices=DATASET_NAMES,
243
- default=TRAIN_DATA,
244
- help='Name of the dataset.')# TRAIN_DATA,BIPED-B3
245
- parser.add_argument('--test_data',
246
- type=str,
247
- choices=DATASET_NAMES,
248
- default=TEST_DATA,
249
- help='Name of the dataset.')
250
- parser.add_argument('--test_list',
251
- type=str,
252
- default=test_inf['test_list'],
253
- help='Dataset sample indices list.')
254
- parser.add_argument('--train_list',
255
- type=str,
256
- default=train_inf['train_list'],
257
- help='Dataset sample indices list.')
258
- parser.add_argument('--is_testing',type=bool,
259
- default=is_testing,
260
- help='Script in testing mode.')
261
- parser.add_argument('--predict_all',
262
- type=bool,
263
- default=False,
264
- help='True: Generate all TEED outputs in all_edges ')
265
- parser.add_argument('--up_scale',
266
- type=bool,
267
- default=False, # for Upsale test set in 30%
268
- help='True: up scale x1.5 test image') # Just for test
269
-
270
- parser.add_argument('--resume',
271
- type=bool,
272
- default=False,
273
- help='use previous trained data') # Just for test
274
- parser.add_argument('--checkpoint_data',
275
- type=str,
276
- default='5/5_model.pth',# 37 for biped 60 MDBD
277
- help='Checkpoint path.')
278
- parser.add_argument('--test_img_width',
279
- type=int,
280
- default=test_inf['img_width'],
281
- help='Image width for testing.')
282
- parser.add_argument('--test_img_height',
283
- type=int,
284
- default=test_inf['img_height'],
285
- help='Image height for testing.')
286
- parser.add_argument('--res_dir',
287
- type=str,
288
- default='result',
289
- help='Result directory')
290
- parser.add_argument('--use_gpu',type=int,
291
- default=0, help='use GPU')
292
- parser.add_argument('--log_interval_vis',
293
- type=int,
294
- default=200,# 100
295
- help='Interval to visualize predictions. 200')
296
- parser.add_argument('--show_log', type=int, default=20, help='display logs')
297
- parser.add_argument('--epochs',
298
- type=int,
299
- default=8,
300
- metavar='N',
301
- help='Number of training epochs (default: 25).')
302
- parser.add_argument('--lr', default=8e-4, type=float,
303
- help='Initial learning rate. =1e-3') # 1e-3
304
- parser.add_argument('--lrs', default=[8e-5], type=float,
305
- help='LR for epochs') # [7e-5]
306
- parser.add_argument('--wd', type=float, default=2e-4, metavar='WD',
307
- help='weight decay (Good 5e-4/1e-4 )') # good 12e-5
308
- parser.add_argument('--adjust_lr', default=[4], type=int,
309
- help='Learning rate step size.') # [4] [6,9,19]
310
- parser.add_argument('--version_notes',
311
- default='TEED BIPED+BRIND-trainingdataLoader BRIND light AF -USNet--noBN xav init normal bdcnLoss2+cats2loss +DoubleFusion-3AF, AF sum',
312
- type=str,
313
- help='version notes')
314
- parser.add_argument('--batch_size',
315
- type=int,
316
- default=8,
317
- metavar='B',
318
- help='the mini-batch size (default: 8)')
319
- parser.add_argument('--workers',
320
- default=8,
321
- type=int,
322
- help='The number of workers for the dataloaders.')
323
- parser.add_argument('--tensorboard',type=bool,
324
- default=True,
325
- help='Use Tensorboard for logging.'),
326
- parser.add_argument('--img_width',
327
- type=int,
328
- default=300,
329
- help='Image width for training.') # BIPED 352/300 BRIND 256 MDBD 480
330
- parser.add_argument('--img_height',
331
- type=int,
332
- default=300,
333
- help='Image height for training.') # BIPED 352/300 BSDS 352/320
334
- parser.add_argument('--channel_swap',
335
- default=[2, 1, 0],
336
- type=int)
337
- parser.add_argument('--resume_chpt',
338
- default='result/resume/',
339
- type=str,
340
- help='resume training')
341
- parser.add_argument('--pl_opt_dir',
342
- default=pl_opt_dir,
343
- type=str,
344
- help='pl output directory')
345
- parser.add_argument('--crop_img',
346
- default=True,
347
- type=bool,
348
- help='If true crop training images, else resize images to match image width and height.')
349
- parser.add_argument('--mean_test',
350
- default=test_inf['mean'],
351
- type=float)
352
- parser.add_argument('--mean_train',
353
- default=train_inf['mean'],
354
- type=float) # [103.939,116.779,123.68,137.86] [104.00699, 116.66877, 122.67892]
355
-
356
- args = parser.parse_args()
357
- return args, train_inf
358
-
359
-
360
- def main(args, train_inf):
361
-
362
- # Tensorboard summary writer
363
-
364
- # torch.autograd.set_detect_anomaly(True)
365
- tb_writer = None
366
- training_dir = os.path.join(args.output_dir,args.train_data)
367
- os.makedirs(training_dir,exist_ok=True)
368
- checkpoint_path = os.path.join('./teed',args.output_dir)
369
- checkpoint_path = os.path.join(checkpoint_path, args.train_data,args.checkpoint_data)
370
- if args.tensorboard and not args.is_testing:
371
- # from tensorboardX import SummaryWriter # previous torch version
372
- from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather
373
- tb_writer = SummaryWriter(log_dir=training_dir)
374
- # saving training settings
375
- training_notes =[args.version_notes+ ' RL= ' + str(args.lr) + ' WD= '
376
- + str(args.wd) + ' image size = ' + str(args.img_width)
377
- + ' adjust LR=' + str(args.adjust_lr) +' LRs= '
378
- + str(args.lrs)+' Loss Function= BDCNloss2 + CAST-loss2.py '
379
- + str(time.asctime())+' trained on '+args.train_data]
380
- info_txt = open(os.path.join(training_dir, 'training_settings.txt'), 'w')
381
- info_txt.write(str(training_notes))
382
- info_txt.close()
383
- print("Training details> ",training_notes)
384
-
385
- # Get computing device
386
- device = torch.device('cpu' if torch.cuda.device_count() == 0
387
- else 'cuda')
388
- # torch.cuda.set_device(args.use_gpu) # set a desired gpu
389
-
390
- print(f"Number of GPU's available: {torch.cuda.device_count()}")
391
- print(f"Pytorch version: {torch.__version__}")
392
- # print(f'GPU: {torch.cuda.get_device_name()}')
393
- print(f'Trainimage mean: {args.mean_train}')
394
- print(f'Test image mean: {args.mean_test}')
395
-
396
-
397
- # Instantiate model and move it to the computing device
398
- model = TED().to(device)
399
- # model = nn.DataParallel(model)
400
- ini_epoch =0
401
- if not args.is_testing:
402
- if args.resume:
403
- checkpoint_path2= os.path.join(args.output_dir, 'BIPED-54-B4',args.checkpoint_data)
404
- ini_epoch=8
405
- model.load_state_dict(torch.load(checkpoint_path2,
406
- map_location=device))
407
-
408
- # Training dataset loading...
409
- dataset_train = BipedDataset(args.input_dir,
410
- img_width=args.img_width,
411
- img_height=args.img_height,
412
- train_mode='train',
413
- arg=args
414
- )
415
- dataloader_train = DataLoader(dataset_train,
416
- batch_size=args.batch_size,
417
- shuffle=True,
418
- num_workers=args.workers)
419
- # Test dataset loading...
420
- dataset_val = TestDataset(args.input_val_dir,
421
- test_data=args.test_data,
422
- img_width=args.test_img_width,
423
- img_height=args.test_img_height,
424
- test_list=args.test_list, arg=args
425
- )
426
- dataloader_val = DataLoader(dataset_val,
427
- batch_size=1,
428
- shuffle=False,
429
- num_workers=args.workers)
430
- # Testing
431
- if_resize_img = False if args.test_data in ['BIPED', 'CID', 'MDBD'] else True
432
- if args.is_testing:
433
-
434
- # output_dir = os.path.join(args.res_dir, args.train_data+"2"+ args.test_data)
435
- output_dir = args.pl_opt_dir
436
- print(f"output_dir: {output_dir}")
437
-
438
- test(checkpoint_path, dataloader_val, model, device,
439
- output_dir, args,if_resize_img)
440
-
441
- # Count parameters:
442
- num_param = count_parameters(model)
443
- print('-------------------------------------------------------')
444
- print('TED parameters:')
445
- print(num_param)
446
- print('-------------------------------------------------------')
447
- return
448
-
449
- criterion1 = cats_loss #bdcn_loss2
450
- criterion2 = bdcn_loss2#cats_loss#f1_accuracy2
451
- criterion = [criterion1,criterion2]
452
- optimizer = optim.Adam(model.parameters(),
453
- lr=args.lr,
454
- weight_decay=args.wd)
455
-
456
- # Count parameters:
457
- num_param = count_parameters(model)
458
- print('-------------------------------------------------------')
459
- print('TEED parameters:')
460
- print(num_param)
461
- print('-------------------------------------------------------')
462
-
463
- # Main training loop
464
- seed=1021
465
- adjust_lr = args.adjust_lr
466
- k=0
467
- set_lr = args.lrs#[25e-4, 5e-6]
468
- for epoch in range(ini_epoch,args.epochs):
469
- if epoch%5==0: # before 7
470
-
471
- seed = seed+1000
472
- np.random.seed(seed)
473
- torch.manual_seed(seed)
474
- torch.cuda.manual_seed(seed)
475
- print("------ Random seed applied-------------")
476
- # adjust learning rate
477
- if adjust_lr is not None:
478
- if epoch in adjust_lr:
479
- lr2 = set_lr[k]
480
- for param_group in optimizer.param_groups:
481
- param_group['lr'] = lr2
482
- k+=1
483
- # Create output directories
484
-
485
- output_dir_epoch = os.path.join(args.output_dir,args.train_data, str(epoch))
486
- img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
487
- os.makedirs(output_dir_epoch,exist_ok=True)
488
- os.makedirs(img_test_dir,exist_ok=True)
489
- print("**************** Validating the training from the scratch **********")
490
- # validate_one_epoch(epoch,
491
- # dataloader_val,
492
- # model,
493
- # device,
494
- # img_test_dir,
495
- # arg=args,test_resize=if_resize_img)
496
-
497
- avg_loss =train_one_epoch(epoch,dataloader_train,
498
- model, criterion,
499
- optimizer,
500
- device,
501
- args.log_interval_vis,
502
- tb_writer=tb_writer,
503
- args=args)
504
- validate_one_epoch(epoch,
505
- dataloader_val,
506
- model,
507
- device,
508
- img_test_dir,
509
- arg=args, test_resize=if_resize_img)
510
-
511
- # Save model after end of every epoch
512
- torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
513
- os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
514
- if tb_writer is not None:
515
- tb_writer.add_scalar('loss',
516
- avg_loss,
517
- epoch+1)
518
- print('Last learning rate> ', optimizer.param_groups[0]['lr'])
519
-
520
- num_param = count_parameters(model)
521
- print('-------------------------------------------------------')
522
- print('TEED parameters:')
523
- print(num_param)
524
- print('-------------------------------------------------------')
525
-
526
- if __name__ == '__main__':
527
- # os.system(" ".join(command))
528
- is_testing =True # True to use TEED for testing
529
- args, train_info = parse_args(is_testing=is_testing)
530
- main(args, train_info)