plutosss committed on
Commit
687bed1
·
verified ·
1 Parent(s): 6e63578

Upload main.py

Browse files
Files changed (1) hide show
  1. TEED/main.py +530 -0
TEED/main.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hello, welcome on board,
3
+ """
4
+ from __future__ import print_function
5
+
6
+ import argparse
7
+ import os
8
+ import time, platform
9
+ import cv2
10
+ import numpy as np
11
+ os.environ['CUDA_LAUNCH_BLOCKING']="0"
12
+ import torch
13
+ import torch.optim as optim
14
+ from torch.utils.data import DataLoader
15
+ from thop import profile
16
+
17
+ from TEED.dataset import DATASET_NAMES, BipedDataset, TestDataset, dataset_info
18
+ from TEED.loss2 import *
19
+
20
+ from TEED.ted import TED # TEED architecture
21
+
22
+ from TEED.utils.img_processing import (image_normalization, save_image_batch_to_disk,
23
+ visualize_result, count_parameters)
24
+
25
# Default run mode for the script: True runs inference, False trains the TEED model.
is_testing = True
# Whether the host platform reports itself as Linux; forwarded to dataset_info().
IS_LINUX = platform.system() == "Linux"
27
+
28
def _save_training_preview(images, labels, preds_list, caption, out_path, args):
    """Write one captioned preview image (input, label, every prediction) to ``out_path``."""
    # The preview historically shows the third sample of the batch; clamp the
    # index so batches with fewer than three samples do not raise IndexError.
    sample_idx = min(2, images.shape[0] - 1)

    res_data = [images[sample_idx].cpu().numpy(),
                labels[sample_idx].cpu().numpy()]
    for pred in preds_list:
        prob = torch.sigmoid(pred[sample_idx]).unsqueeze(dim=0)
        res_data.append(prob.cpu().detach().numpy())

    vis_imgs = visualize_result(res_data, arg=args)
    vis_imgs = cv2.resize(
        vis_imgs,
        (int(vis_imgs.shape[1] * 0.8), int(vis_imgs.shape[0] * 0.8)))

    # (0, 0, 255) renders as red under OpenCV's default BGR channel order.
    font_color = (0, 0, 255)
    vis_imgs = cv2.putText(vis_imgs, caption, (30, 30),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.9, font_color, 2,
                           cv2.LINE_AA)
    cv2.imwrite(out_path, vis_imgs)


def train_one_epoch(epoch, dataloader, model, criterions, optimizer, device,
                    log_interval_vis, tb_writer, args=None):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Args:
        epoch: Index of the current epoch (used for logging only).
        dataloader: Iterable of dict batches providing 'images' and 'labels'.
        model: Network called as ``model(images)``; it must return a sequence
            of predictions whose last element is the fused output.
        criterions: Pair ``[criterion1, criterion2]``. ``criterion1`` is
            applied to the last prediction as
            ``criterion1(pred, labels, weights, device)`` and ``criterion2``
            to every other prediction as ``criterion2(pred, labels, weight)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        log_interval_vis: Save a preview image every this many batches;
            a falsy value disables previews.
        tb_writer: Optional TensorBoard writer (may be ``None``).
        args: Parsed options; ``output_dir`` and ``show_log`` are read here and
            the object is forwarded to ``visualize_result``.

    Returns:
        Mean of the per-batch total losses (NaN if the loader was empty).

    Raises:
        ValueError: If ``criterions`` is not a two-element list or tuple.
    """
    imgs_res_folder = os.path.join(args.output_dir, 'current_res')
    os.makedirs(imgs_res_folder, exist_ok=True)
    show_log = args.show_log

    # Both losses are required; a single callable used to leave criterion2
    # unbound and fail later with a confusing NameError.
    if not isinstance(criterions, (list, tuple)) or len(criterions) != 2:
        raise ValueError(
            "criterions must be a pair [criterion1, criterion2]")
    criterion1, criterion2 = criterions

    # Put model in training mode
    model.train()

    l_weight0 = [1.1, 0.7, 1.1, 1.3]  # for bdcn loss2-B4
    l_weight = [[0.05, 2.], [0.05, 2.], [0.01, 1.],
                [0.01, 3.]]  # for cats loss [0.01, 4.]; only the last pair is used
    loss_avg = []
    for batch_id, sample_batched in enumerate(dataloader):
        images = sample_batched['images'].to(device)  # BxCxHxW
        labels = sample_batched['labels'].to(device)  # BxHxW
        preds_list = model(images)
        # bdcn_loss2 on the intermediate outputs, cats_loss on the fused one.
        loss1 = sum(criterion2(preds, labels, l_w)
                    for preds, l_w in zip(preds_list[:-1], l_weight0))
        loss2 = criterion1(preds_list[-1], labels, l_weight[-1], device)
        tLoss = loss2 + loss1

        optimizer.zero_grad()
        tLoss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = tLoss.item()
        loss_avg.append(loss_value)
        if epoch == 0 and (batch_id == 100 and tb_writer is not None):
            tmp_loss = np.array(loss_avg).mean()
            tb_writer.add_scalar('loss', tmp_loss, epoch)

        # A zero interval disables the corresponding output instead of
        # raising ZeroDivisionError.
        if show_log and batch_id % show_log == 0:
            print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}'
                  .format(epoch, batch_id, len(dataloader),
                          format(loss_value, '.4f')))
        if log_interval_vis and batch_id % log_interval_vis == 0:
            caption = 'Epoch: {0} Iter: {1}/{2} Loss: {3}' \
                .format(epoch, batch_id, len(dataloader), round(loss_value, 4))
            # The same file is overwritten each time, so only the latest
            # preview is kept on disk.
            _save_training_preview(
                images, labels, preds_list, caption,
                os.path.join(imgs_res_folder, 'results.png'), args)

    loss_avg = np.array(loss_avg).mean()
    return loss_avg
106
+
107
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None, test_resize=False):
    """Run the model over ``dataloader`` and save the last prediction of each batch.

    Despite the name this is a plain inference pass: no labels are read and no
    metric is computed. ``epoch`` is accepted for interface symmetry but unused.

    Args:
        epoch: Current epoch index (unused).
        dataloader: Iterable of dict batches with 'images', 'file_names' and
            'image_shape' entries.
        model: Network called as ``model(images, single_test=test_resize)``.
        device: Device the images are moved to.
        output_dir: Destination passed to ``save_image_batch_to_disk``.
        arg: Options object forwarded to ``save_image_batch_to_disk``.
        test_resize: Forwarded to the model as ``single_test``.
    """
    # Inference only: evaluation mode and no gradient tracking.
    model.eval()

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['images'].to(device)
            names = batch['file_names']
            shapes = batch['image_shape']
            outputs = model(inputs, single_test=test_resize)
            # Only the final element of the model output is written to disk.
            save_image_batch_to_disk(outputs[-1],
                                     output_dir,
                                     names,
                                     img_shape=shapes,
                                     arg=arg)
125
+
126
+
127
def test(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False):
    """Load a checkpoint, run inference over ``dataloader`` and report throughput.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` file.
        dataloader: Iterable of dict batches with 'images', 'file_names' and
            'image_shape' entries. Labels are not read; inference needs only
            the images, so label-less datasets work too.
        model: Network called as ``model(images, single_test=resize_input)``.
        device: ``torch.device`` used for the weights and the input tensors.
        output_dir: Destination passed to ``save_image_batch_to_disk``.
        args: Options object; ``test_data`` is printed and the object is
            forwarded to ``save_image_batch_to_disk``.
        resize_input: Forwarded to the model as ``single_test``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))

    model.eval()

    use_cuda = device.type == 'cuda'
    total_duration = 0.0
    num_images = 0
    with torch.no_grad():
        for sample_batched in dataloader:
            images = sample_batched['images'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']

            print(f"{file_names}: {images.shape}")
            # Drain pending GPU work before starting the clock so that only
            # the forward pass is measured.
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            preds = model(images, single_test=resize_input)
            if use_cuda:
                torch.cuda.synchronize()
            total_duration += time.perf_counter() - start
            num_images += images.shape[0]

            save_image_batch_to_disk(preds,
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args)
            if use_cuda:
                torch.cuda.empty_cache()

    print("******** Testing finished in", args.test_data, "dataset. *****")
    if total_duration > 0:
        # Frames per second counts images, not batches.
        print("FPS: %.4f" % (num_images / total_duration))
    else:
        print("FPS: n/a (no samples were processed)")
169
+
170
def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False):
    """Test the model on each batch twice: as loaded and with two channels swapped.

    Both predictions are handed together to ``save_image_batch_to_disk`` with
    ``is_inchannel=True``.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` file.
        dataloader: Iterable of dict batches with 'images', 'file_names' and
            'image_shape' entries. Labels are not read.
        model: Network called as ``model(images, single_test=resize_input)``.
        device: Device used for the weights and the input tensors.
        output_dir: Destination passed to ``save_image_batch_to_disk``.
        args: Options object; ``test_data`` is printed and the object is
            forwarded to ``save_image_batch_to_disk``.
        resize_input: Forwarded to the model as ``single_test``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))

    model.eval()

    durations = []
    with torch.no_grad():
        for sample_batched in dataloader:
            images = sample_batched['images'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            print(f"input tensor shape: {images.shape}")

            # Synchronise around the timed region so that asynchronous CUDA
            # kernels are actually included in the measurement.
            if images.is_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            # Swap the first two channels (original note: GBR).
            images2 = images[:, [1, 0, 2], :, :]
            preds = model(images, single_test=resize_input)
            preds2 = model(images2, single_test=resize_input)
            if images.is_cuda:
                torch.cuda.synchronize()
            durations.append(time.perf_counter() - start_time)

            save_image_batch_to_disk([preds, preds2],
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args, is_inchannel=True)
            if images.is_cuda:
                torch.cuda.empty_cache()

    print("******** Testing finished in", args.test_data, "dataset. *****")
    if durations:
        total_duration = np.array(durations)
        print("Average time per image: %.4f" % total_duration.mean(), "seconds")
        print("Time spend in the Dataset: %.4f" % total_duration.sum(), "seconds")
    else:
        print("No samples were processed.")
208
+
209
def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    ``True``; this parser accepts the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


def parse_args(is_testing=True, pl_opt_dir = 'output/teed'):
    """Parse command line arguments.

    Args:
        is_testing: Default for ``--is_testing``.
        pl_opt_dir: Default for ``--pl_opt_dir``.

    Returns:
        Tuple ``(args, train_inf)``: the parsed namespace and the
        ``dataset_info`` result for the training dataset.
    """
    parser = argparse.ArgumentParser(description='TEED model')
    parser.add_argument('--choose_test_data',
                        type=int,
                        default=-1,  # UDED=15
                        help='Choose a dataset for testing: 0 - 15')

    # ----------- test -------
    # Several defaults below depend on the chosen test dataset, so that one
    # option is read first. A separate parser with parse_known_args() is used
    # so the remaining options (not registered yet) are ignored here instead
    # of aborting with "unrecognized arguments".
    pre_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    pre_parser.add_argument('--choose_test_data', type=int, default=-1)
    pre_args, _ = pre_parser.parse_known_args()
    TEST_DATA = DATASET_NAMES[pre_args.choose_test_data]  # max 8
    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)

    # Training settings
    # BIPED-B2=1, BIPED-B3=2, just for evaluation, using LDC trained with 2 or 3 blocks
    TRAIN_DATA = DATASET_NAMES[0]  # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
    train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
    train_dir = train_inf['data_dir']

    # Data parameters
    parser.add_argument('--input_dir',
                        type=str,
                        default=train_dir,
                        help='the path to the directory with the input data.')
    parser.add_argument('--input_val_dir',
                        type=str,
                        default=test_inf['data_dir'],
                        help='the path to the directory with the input data for validation.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='checkpoints',
                        help='the path to output the results.')
    parser.add_argument('--train_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TRAIN_DATA,
                        help='Name of the dataset.')  # TRAIN_DATA,BIPED-B3
    parser.add_argument('--test_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TEST_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_list',
                        type=str,
                        default=test_inf['test_list'],
                        help='Dataset sample indices list.')
    parser.add_argument('--train_list',
                        type=str,
                        default=train_inf['train_list'],
                        help='Dataset sample indices list.')
    parser.add_argument('--is_testing',
                        type=_str2bool,
                        default=is_testing,
                        help='Script in testing mode.')
    parser.add_argument('--predict_all',
                        type=_str2bool,
                        default=False,
                        help='True: Generate all TEED outputs in all_edges ')
    parser.add_argument('--up_scale',
                        type=_str2bool,
                        default=False,  # for upscaling the test set
                        help='True: up scale x1.5 test image')  # Just for test

    parser.add_argument('--resume',
                        type=_str2bool,
                        default=False,
                        help='use previous trained data')  # Just for test
    parser.add_argument('--checkpoint_data',
                        type=str,
                        default='5/5_model.pth',  # 37 for biped 60 MDBD
                        help='Checkpoint path.')
    parser.add_argument('--test_img_width',
                        type=int,
                        default=test_inf['img_width'],
                        help='Image width for testing.')
    parser.add_argument('--test_img_height',
                        type=int,
                        default=test_inf['img_height'],
                        help='Image height for testing.')
    parser.add_argument('--res_dir',
                        type=str,
                        default='result',
                        help='Result directory')
    parser.add_argument('--use_gpu',
                        type=int,
                        default=0,
                        help='use GPU')
    parser.add_argument('--log_interval_vis',
                        type=int,
                        default=200,  # 100
                        help='Interval to visualize predictions. 200')
    parser.add_argument('--show_log', type=int, default=20, help='display logs')
    parser.add_argument('--epochs',
                        type=int,
                        default=8,
                        metavar='N',
                        help='Number of training epochs (default: 8).')
    parser.add_argument('--lr', default=8e-4, type=float,
                        help='Initial learning rate. =1e-3')  # 1e-3
    # The list-valued options take one or more values so that a value given
    # on the command line has the same (list) type as the default.
    parser.add_argument('--lrs', default=[8e-5], type=float, nargs='+',
                        help='LR for epochs')  # [7e-5]
    parser.add_argument('--wd', type=float, default=2e-4, metavar='WD',
                        help='weight decay (Good 5e-4/1e-4 )')  # good 12e-5
    parser.add_argument('--adjust_lr', default=[4], type=int, nargs='+',
                        help='Learning rate step size.')  # [4] [6,9,19]
    parser.add_argument('--version_notes',
                        default='TEED BIPED+BRIND-trainingdataLoader BRIND light AF -USNet--noBN xav init normal bdcnLoss2+cats2loss +DoubleFusion-3AF, AF sum',
                        type=str,
                        help='version notes')
    parser.add_argument('--batch_size',
                        type=int,
                        default=8,
                        metavar='B',
                        help='the mini-batch size (default: 8)')
    parser.add_argument('--workers',
                        default=8,
                        type=int,
                        help='The number of workers for the dataloaders.')
    parser.add_argument('--tensorboard',
                        type=_str2bool,
                        default=True,
                        help='Use Tensorboard for logging.')
    parser.add_argument('--img_width',
                        type=int,
                        default=300,
                        help='Image width for training.')  # BIPED 352/300 BRIND 256 MDBD 480
    parser.add_argument('--img_height',
                        type=int,
                        default=300,
                        help='Image height for training.')  # BIPED 352/300 BSDS 352/320
    parser.add_argument('--channel_swap',
                        default=[2, 1, 0],
                        type=int,
                        nargs='+')
    parser.add_argument('--resume_chpt',
                        default='result/resume/',
                        type=str,
                        help='resume training')
    parser.add_argument('--pl_opt_dir',
                        default=pl_opt_dir,
                        type=str,
                        help='pl output directory')
    parser.add_argument('--crop_img',
                        default=True,
                        type=_str2bool,
                        help='If true crop training images, else resize images to match image width and height.')
    # NOTE(review): the dataset means are presumably per-channel lists
    # (e.g. [103.939, 116.779, 123.68, 137.86]) -- confirm against
    # dataset_info.
    parser.add_argument('--mean_test',
                        default=test_inf['mean'],
                        type=float,
                        nargs='+')
    parser.add_argument('--mean_train',
                        default=train_inf['mean'],
                        type=float,
                        nargs='+')  # [103.939,116.779,123.68,137.86] [104.00699, 116.66877, 122.67892]

    args = parser.parse_args()
    return args, train_inf
358
+
359
+
360
def main(args, train_inf):
    """Run TEED in testing or training mode according to ``args.is_testing``.

    Testing mode restores the checkpoint and runs ``test`` on the test
    dataset. Training mode trains for ``args.epochs`` epochs; after every
    epoch it runs ``validate_one_epoch`` and saves a checkpoint.

    Args:
        args: Namespace produced by ``parse_args``.
        train_inf: Training dataset information (currently unused here).

    Raises:
        ValueError: If the learning-rate schedule needs more values than
            ``args.lrs`` provides.
    """
    # Tensorboard summary writer
    tb_writer = None
    training_dir = os.path.join(args.output_dir, args.train_data)
    os.makedirs(training_dir, exist_ok=True)
    checkpoint_path = os.path.join('./teed', args.output_dir)
    checkpoint_path = os.path.join(checkpoint_path, args.train_data, args.checkpoint_data)
    if args.tensorboard and not args.is_testing:
        # from tensorboardX import SummaryWriter  # previous torch version
        from torch.utils.tensorboard import SummaryWriter  # for torch 1.4 or greater
        tb_writer = SummaryWriter(log_dir=training_dir)
        # saving training settings
        training_notes = [args.version_notes + ' RL= ' + str(args.lr) + ' WD= '
                          + str(args.wd) + ' image size = ' + str(args.img_width)
                          + ' adjust LR=' + str(args.adjust_lr) + ' LRs= '
                          + str(args.lrs) + ' Loss Function= BDCNloss2 + CAST-loss2.py '
                          + str(time.asctime()) + ' trained on ' + args.train_data]
        # Context manager guarantees the file is closed even if the write fails.
        with open(os.path.join(training_dir, 'training_settings.txt'), 'w') as info_txt:
            info_txt.write(str(training_notes))
        print("Training details> ", training_notes)

    # Get computing device
    device = torch.device('cpu' if torch.cuda.device_count() == 0
                          else 'cuda')
    # torch.cuda.set_device(args.use_gpu) # set a desired gpu

    print(f"Number of GPU's available: {torch.cuda.device_count()}")
    print(f"Pytorch version: {torch.__version__}")
    print(f'Trainimage mean: {args.mean_train}')
    print(f'Test image mean: {args.mean_test}')

    # Instantiate model and move it to the computing device
    model = TED().to(device)
    ini_epoch = 0
    if not args.is_testing:
        if args.resume:
            # NOTE(review): the resume directory and start epoch are hard-coded.
            checkpoint_path2 = os.path.join(args.output_dir, 'BIPED-54-B4', args.checkpoint_data)
            ini_epoch = 8
            model.load_state_dict(torch.load(checkpoint_path2,
                                             map_location=device))

        # Training dataset loading...
        dataset_train = BipedDataset(args.input_dir,
                                     img_width=args.img_width,
                                     img_height=args.img_height,
                                     train_mode='train',
                                     arg=args
                                     )
        dataloader_train = DataLoader(dataset_train,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    # Test dataset loading...
    dataset_val = TestDataset(args.input_val_dir,
                              test_data=args.test_data,
                              img_width=args.test_img_width,
                              img_height=args.test_img_height,
                              test_list=args.test_list, arg=args
                              )
    dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=args.workers)
    # Testing
    if_resize_img = args.test_data not in ['BIPED', 'CID', 'MDBD']
    if args.is_testing:
        output_dir = args.pl_opt_dir
        print(f"output_dir: {output_dir}")

        test(checkpoint_path, dataloader_val, model, device,
             output_dir, args, if_resize_img)

        # Count parameters:
        num_param = count_parameters(model)
        print('-------------------------------------------------------')
        print('TED parameters:')
        print(num_param)
        print('-------------------------------------------------------')
        return

    criterion1 = cats_loss  # applied to the last (fused) prediction
    criterion2 = bdcn_loss2  # applied to the remaining predictions
    criterion = [criterion1, criterion2]
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)

    # Count parameters:
    num_param = count_parameters(model)
    print('-------------------------------------------------------')
    print('TEED parameters:')
    print(num_param)
    print('-------------------------------------------------------')

    # Learning-rate schedule: at each epoch listed in adjust_lr the next value
    # of set_lr becomes the optimizer's learning rate. Scalars are accepted
    # and treated as one-element schedules.
    adjust_lr = args.adjust_lr
    set_lr = args.lrs  # [25e-4, 5e-6]
    if isinstance(adjust_lr, int):
        adjust_lr = [adjust_lr]
    if isinstance(set_lr, (int, float)):
        set_lr = [set_lr]
    if adjust_lr is not None:
        # Fail before training starts rather than with an IndexError after
        # several epochs of work.
        planned_steps = sum(1 for e in range(ini_epoch, args.epochs)
                            if e in adjust_lr)
        if planned_steps > len(set_lr):
            raise ValueError(
                f"adjust_lr triggers {planned_steps} learning-rate changes "
                f"but lrs provides only {len(set_lr)} value(s)")

    # Main training loop
    seed = 1021
    k = 0
    try:
        for epoch in range(ini_epoch, args.epochs):
            if epoch % 5 == 0:  # before 7
                # Re-seed every fifth epoch with a new seed.
                seed = seed + 1000
                np.random.seed(seed)
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                print("------ Random seed applied-------------")
            # adjust learning rate
            if adjust_lr is not None:
                if epoch in adjust_lr:
                    lr2 = set_lr[k]
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr2
                    k += 1
            # Create output directories
            output_dir_epoch = os.path.join(args.output_dir, args.train_data, str(epoch))
            img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
            os.makedirs(output_dir_epoch, exist_ok=True)
            os.makedirs(img_test_dir, exist_ok=True)
            print("**************** Validating the training from the scratch **********")

            avg_loss = train_one_epoch(epoch, dataloader_train,
                                       model, criterion,
                                       optimizer,
                                       device,
                                       args.log_interval_vis,
                                       tb_writer=tb_writer,
                                       args=args)
            validate_one_epoch(epoch,
                               dataloader_val,
                               model,
                               device,
                               img_test_dir,
                               arg=args, test_resize=if_resize_img)

            # Save model after end of every epoch (unwrap a .module wrapper
            # when the model has one).
            torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                       os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
            if tb_writer is not None:
                tb_writer.add_scalar('loss',
                                     avg_loss,
                                     epoch + 1)
            print('Last learning rate> ', optimizer.param_groups[0]['lr'])
    finally:
        # Flush and release the event file even if training is interrupted.
        if tb_writer is not None:
            tb_writer.close()

    num_param = count_parameters(model)
    print('-------------------------------------------------------')
    print('TEED parameters:')
    print(num_param)
    print('-------------------------------------------------------')
525
+
526
if __name__ == '__main__':
    # Script entry point: parse the options, then hand them straight to main().
    is_testing = True  # True to use TEED for testing
    main(*parse_args(is_testing=is_testing))