|
|
import argparse
|
|
|
import random
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
def str2bool(v):
    """Interpret a command-line value as a boolean.

    Used as the ``type=`` callable of boolean options in this file.

    Args:
        v: A ``bool`` (returned unchanged) or a value whose string form is one
            of ``yes/true/t/y/1`` (-> ``True``) or ``no/false/f/n/0``
            (-> ``False``). Matching is case-insensitive and ignores
            surrounding whitespace.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling.
    """
    if isinstance(v, bool):
        return v
    # Normalise once. str() makes non-string inputs (e.g. ints 0/1, None) take
    # the same clean error path instead of failing with AttributeError on .lower().
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
|
|
|
|
|
|
|
class BaseOptions:
    """Command-line options shared by runs, plus derived output paths.

    ``initialize`` registers the arguments on an ``argparse.ArgumentParser``;
    ``parse_args`` parses them and attaches derived fields (run name paths,
    per-process batch size, master port).
    """

    def initialize(self, parser):
        """Register the base command-line arguments on ``parser``.

        Args:
            parser: An ``argparse.ArgumentParser`` to add arguments to.

        Returns:
            The same ``parser``, so subclasses can chain further arguments.
        """
        parser.add_argument("--name", type=str)
        parser.add_argument("--model", type=str, default="dot", choices=["dot", "of", "pt"])
        parser.add_argument("--datetime", type=str, default=None)
        parser.add_argument("--data_root", type=str)
        parser.add_argument("--height", type=int, default=512)
        parser.add_argument("--width", type=int, default=512)
        # Float default so the value has the same type as a user-supplied one.
        parser.add_argument("--aspect_ratio", type=float, default=1.0)
        parser.add_argument("--batch_size", type=int)
        parser.add_argument("--num_tracks", type=int, default=2048)
        parser.add_argument("--sim_tracks", type=int, default=2048)
        parser.add_argument("--alpha_thresh", type=float, default=0.8)
        # `--is_train` alone means True; `--is_train false` etc. also accepted.
        parser.add_argument("--is_train", type=str2bool, nargs='?', const=True, default=False)

        parser.add_argument('--worker_idx', type=int, default=0)
        parser.add_argument("--num_workers", type=int, default=2)

        # Estimator options.
        parser.add_argument("--estimator_config", type=str, default="configs/raft_patch_8.json")
        parser.add_argument("--estimator_path", type=str, default="checkpoints/cvo_raft_patch_8.pth")
        parser.add_argument("--flow_mode", type=str, default="direct", choices=["direct", "chain", "warm_start"])

        # Refiner options.
        parser.add_argument("--refiner_config", type=str, default="configs/raft_patch_4_alpha.json")
        parser.add_argument("--refiner_path", type=str, default="checkpoints/movi_f_raft_patch_4_alpha.pth")

        # Tracker options.
        parser.add_argument("--tracker_config", type=str, default="configs/cotracker2_patch_4_wind_8.json")
        parser.add_argument("--tracker_path", type=str, default="checkpoints/movi_f_cotracker2_patch_4_wind_8.pth")
        parser.add_argument("--sample_mode", type=str, default="all", choices=["all", "first", "last"])

        parser.add_argument("--cell_size", type=int, default=1)
        parser.add_argument("--cell_time_steps", type=int, default=20)

        parser.add_argument("--interpolation_version", type=str, default="torch3d", choices=["torch3d", "torch"])
        return parser

    def parse_args(self, argv=None):
        """Build the parser, parse arguments, and attach derived settings.

        Args:
            argv: Optional list of argument strings. ``None`` (the default)
                lets argparse read the process command line, as before.

        Returns:
            The parsed ``argparse.Namespace``. In addition to the parsed options:
            ``datetime`` is filled with the current local time
            (``%Y-%m-%d-%H-%M-%S``) when not given; ``checkpoint_path``,
            ``log_path`` and ``result_path`` are set from the run name; and,
            when the namespace has a ``world_size`` attribute, ``batch_size``
            is divided by it (if provided) and ``master_port`` is set to a
            random port string in ``[10001, 19999]``.
        """
        parser = self.initialize(argparse.ArgumentParser())
        args = parser.parse_args(argv)

        if args.datetime is None:
            args.datetime = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

        # Run name shared by the checkpoint, log and result directories.
        name = f"{args.datetime}_{args.name}_{args.model}"
        # `split` and `world_size` are not registered here; presumably a
        # subclass's initialize adds them.
        if hasattr(args, 'split'):
            name += f"_{args.split}"
        args.checkpoint_path = f"checkpoints/{name}"
        args.log_path = f"logs/{name}"
        args.result_path = f"results/{name}"

        if hasattr(args, 'world_size'):
            # --batch_size has no default; dividing None would raise TypeError.
            if args.batch_size is not None:
                # Split the global batch evenly across `world_size` processes.
                args.batch_size = args.batch_size // args.world_size
            args.master_port = f'{10000 + random.randrange(1, 10000)}'

        return args
|
|
|
|