import argparse

import configargparse

# Accepted spellings for boolean options (compared case-insensitively).
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "on", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "off", "0"})


def _str2bool(value):
    """Convert a command-line / config-file value to a ``bool``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``--flag False`` would
    silently leave the flag enabled. Config-file values are handed to the
    parser as text as well, so they need the same treatment.

    Args:
        value: The raw option value. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not one of the recognized
            spellings; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_train_args():
    """Build the training argument parser and parse ``sys.argv``.

    Options may also be supplied through a YAML config file given with
    ``-c/--config`` (default ``_utils/example_config.yaml``). Boolean
    options that take a value (e.g. ``--compress false``) accept
    true/false, yes/no, on/off and 1/0.

    Unrecognized arguments are ignored rather than rejected, because the
    parser uses ``parse_known_args``.

    Returns:
        The parsed namespace, one attribute per option defined below.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    # NOTE: boolean options use _str2bool, not `bool` (bool("False") is True).
    parser.add_argument("--timestamp", type=_str2bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=_str2bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)
    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)
    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=_str2bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=_str2bool, default=True, help="compress dataset"
    )
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=_str2bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=_str2bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    # Unknown arguments are deliberately tolerated (not rejected), e.g. extra
    # flags such as `--input_test` that other entry points may pass along.
    args, _unknown_args = parser.parse_known_args()
    return args