# FinalVision/_utils/track_args.py
# Uploaded by Shengxiao0709 ("Upload 78 files"), commit 8f72b1f (verified).
import configargparse
def _str_to_bool(value):
    """Convert a command-line / config-file string into a ``bool``.

    Using ``type=bool`` with argparse is a known pitfall: any non-empty string
    is truthy, so ``bool("False")`` is ``True``. With it, ``--mixedp False``
    (or a config-file ``false`` that reaches the type function as a string)
    would silently turn the option ON. This converter parses the usual
    spellings instead.

    Args:
        value: Raw option value, e.g. ``"true"``, ``"False"``, ``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false, yes/no, 1/0), got {value!r}"
    )


def parse_train_args(argv=None):
    """Build the training command-line interface and parse its arguments.

    Options can be given on the command line or in a YAML config file passed
    via ``-c/--config`` (default ``_utils/example_config.yaml``). The config
    file is handled by configargparse, under which explicit command-line
    values override config-file values, which override the defaults below.

    Args:
        argv: List of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as before. Pass a list to drive the
            parser programmatically (e.g. from tests or notebooks).

    Returns:
        An ``argparse.Namespace`` with one attribute per option declared
        here. Options that are not declared here are ignored, not rejected.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    # Boolean options use _str_to_bool rather than type=bool, so that
    # "False"/"false"/"0" actually mean False (bool("False") is True).
    parser.add_argument("--timestamp", type=_str_to_bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=_str_to_bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)
    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)
    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=_str_to_bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=_str_to_bool, default=True, help="compress dataset"
    )
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=_str_to_bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=_str_to_bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    # parse_known_args (not parse_args) so that extra, undeclared options
    # (e.g. ``--input_test``) do not abort parsing; they are discarded.
    args, _unknown_args = parser.parse_known_args(args=argv)
    return args