plutosss commited on
Commit
195d616
·
verified ·
1 Parent(s): c4eace3

Update TEED/main.py

Browse files
Files changed (1) hide show
  1. TEED/main.py +6 -3
TEED/main.py CHANGED
@@ -206,7 +206,7 @@ def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resiz
206
  print("Average time per image: %f.4" % total_duration.mean(), "seconds")
207
  print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
208
 
209
- def parse_args(is_testing=True, pl_opt_dir = 'output/teed'):
210
  """Parse command line arguments."""
211
  parser = argparse.ArgumentParser(description='TEED model')
212
  parser.add_argument('--choose_test_data',
@@ -214,16 +214,19 @@ def parse_args(is_testing=True, pl_opt_dir = 'output/teed'):
214
  default=-1, # UDED=15
215
  help='Choose a dataset for testing: 0 - 15')
216
 
 
 
 
217
  # ----------- test -------0--
218
  TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
219
  test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
220
 
221
  # Training settings
222
- # BIPED-B2=1, BIPDE-B3=2, just for evaluation, using LDC trained with 2 or 3 bloacks
223
  TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
224
  train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
225
  train_dir = train_inf['data_dir']
226
 
 
227
  # Data parameters
228
  parser.add_argument('--input_dir',
229
  type=str,
@@ -366,7 +369,7 @@ def main(args, train_inf):
366
  training_dir = os.path.join(args.output_dir,args.train_data)
367
  os.makedirs(training_dir,exist_ok=True)
368
  checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
369
- checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.epoch), '5_model.pth')
370
  if args.tensorboard and not args.is_testing:
371
  # from tensorboardX import SummaryWriter # previous torch version
372
  from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather
 
206
  print("Average time per image: %f.4" % total_duration.mean(), "seconds")
207
  print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
208
 
209
+ def parse_args(is_testing=True, pl_opt_dir='output/teed'):
210
  """Parse command line arguments."""
211
  parser = argparse.ArgumentParser(description='TEED model')
212
  parser.add_argument('--choose_test_data',
 
214
  default=-1, # UDED=15
215
  help='Choose a dataset for testing: 0 - 15')
216
 
217
+ # 新增的 epoch 参数
218
+ parser.add_argument('--epoch', type=int, required=True, help='Epoch number')
219
+
220
  # ----------- test -------0--
221
  TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
222
  test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
223
 
224
  # Training settings
 
225
  TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
226
  train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
227
  train_dir = train_inf['data_dir']
228
 
229
+
230
  # Data parameters
231
  parser.add_argument('--input_dir',
232
  type=str,
 
369
  training_dir = os.path.join(args.output_dir,args.train_data)
370
  os.makedirs(training_dir,exist_ok=True)
371
  checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
372
+ checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.epochs), '5_model.pth')
373
  if args.tensorboard and not args.is_testing:
374
  # from tensorboardX import SummaryWriter # previous torch version
375
  from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather