Spaces:
Runtime error
Runtime error
Update TEED/main.py
Browse files- TEED/main.py +6 -3
TEED/main.py
CHANGED
|
@@ -206,7 +206,7 @@ def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resiz
|
|
| 206 |
print("Average time per image: %f.4" % total_duration.mean(), "seconds")
|
| 207 |
print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
|
| 208 |
|
| 209 |
-
def parse_args(is_testing=True, pl_opt_dir
|
| 210 |
"""Parse command line arguments."""
|
| 211 |
parser = argparse.ArgumentParser(description='TEED model')
|
| 212 |
parser.add_argument('--choose_test_data',
|
|
@@ -214,16 +214,19 @@ def parse_args(is_testing=True, pl_opt_dir = 'output/teed'):
|
|
| 214 |
default=-1, # UDED=15
|
| 215 |
help='Choose a dataset for testing: 0 - 15')
|
| 216 |
|
|
|
|
|
|
|
|
|
|
| 217 |
# ----------- test -------0--
|
| 218 |
TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
|
| 219 |
test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
|
| 220 |
|
| 221 |
# Training settings
|
| 222 |
-
# BIPED-B2=1, BIPDE-B3=2, just for evaluation, using LDC trained with 2 or 3 bloacks
|
| 223 |
TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
|
| 224 |
train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
|
| 225 |
train_dir = train_inf['data_dir']
|
| 226 |
|
|
|
|
| 227 |
# Data parameters
|
| 228 |
parser.add_argument('--input_dir',
|
| 229 |
type=str,
|
|
@@ -366,7 +369,7 @@ def main(args, train_inf):
|
|
| 366 |
training_dir = os.path.join(args.output_dir,args.train_data)
|
| 367 |
os.makedirs(training_dir,exist_ok=True)
|
| 368 |
checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
|
| 369 |
-
checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.
|
| 370 |
if args.tensorboard and not args.is_testing:
|
| 371 |
# from tensorboardX import SummaryWriter # previous torch version
|
| 372 |
from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather
|
|
|
|
| 206 |
print("Average time per image: %f.4" % total_duration.mean(), "seconds")
|
| 207 |
print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds")
|
| 208 |
|
| 209 |
+
def parse_args(is_testing=True, pl_opt_dir='output/teed'):
|
| 210 |
"""Parse command line arguments."""
|
| 211 |
parser = argparse.ArgumentParser(description='TEED model')
|
| 212 |
parser.add_argument('--choose_test_data',
|
|
|
|
| 214 |
default=-1, # UDED=15
|
| 215 |
help='Choose a dataset for testing: 0 - 15')
|
| 216 |
|
| 217 |
+
# 新增的 epoch 参数
|
| 218 |
+
parser.add_argument('--epoch', type=int, required=True, help='Epoch number')
|
| 219 |
+
|
| 220 |
# ----------- test -------0--
|
| 221 |
TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
|
| 222 |
test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
|
| 223 |
|
| 224 |
# Training settings
|
|
|
|
| 225 |
TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
|
| 226 |
train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
|
| 227 |
train_dir = train_inf['data_dir']
|
| 228 |
|
| 229 |
+
|
| 230 |
# Data parameters
|
| 231 |
parser.add_argument('--input_dir',
|
| 232 |
type=str,
|
|
|
|
| 369 |
training_dir = os.path.join(args.output_dir,args.train_data)
|
| 370 |
os.makedirs(training_dir,exist_ok=True)
|
| 371 |
checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
|
| 372 |
+
checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.epochs), '5_model.pth')
|
| 373 |
if args.tensorboard and not args.is_testing:
|
| 374 |
# from tensorboardX import SummaryWriter # previous torch version
|
| 375 |
from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather
|