Spaces:
Sleeping
Sleeping
| import configargparse | |
| def parse_train_args(): | |
| parser = configargparse.ArgumentParser( | |
| formatter_class=configargparse.ArgumentDefaultsHelpFormatter, | |
| config_file_parser_class=configargparse.YAMLConfigFileParser, | |
| allow_abbrev=False, | |
| ) | |
| parser.add_argument( | |
| "-c", | |
| "--config", | |
| default="_utils/example_config.yaml", | |
| is_config_file=True, | |
| help="config file path", | |
| ) | |
| parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda") | |
| parser.add_argument("-o", "--outdir", type=str, default="runs") | |
| parser.add_argument("--name", type=str, help="Name to append to timestamp") | |
| parser.add_argument("--timestamp", type=bool, default=True) | |
| parser.add_argument( | |
| "-m", | |
| "--model", | |
| type=str, | |
| default="", | |
| help="load this model at start (e.g. to continue training)", | |
| ) | |
| parser.add_argument( | |
| "--ndim", type=int, default=2, help="number of spatial dimensions" | |
| ) | |
| parser.add_argument("-d", "--d_model", type=int, default=256) | |
| parser.add_argument("-w", "--window", type=int, default=10) | |
| parser.add_argument("--epochs", type=int, default=100) | |
| parser.add_argument("--warmup_epochs", type=int, default=10) | |
| parser.add_argument( | |
| "--detection_folders", | |
| type=str, | |
| nargs="+", | |
| default=["TRA"], | |
| help=( | |
| "Subfolders to search for detections. Defaults to `TRA`, which corresponds" | |
| " to using only the GT." | |
| ), | |
| ) | |
| parser.add_argument("--downscale_temporal", type=int, default=1) | |
| parser.add_argument("--downscale_spatial", type=int, default=1) | |
| parser.add_argument("--spatial_pos_cutoff", type=int, default=256) | |
| parser.add_argument("--from_subfolder", action="store_true") | |
| # parser.add_argument("--train_samples", type=int, default=50000) | |
| parser.add_argument("--num_encoder_layers", type=int, default=6) | |
| parser.add_argument("--num_decoder_layers", type=int, default=6) | |
| parser.add_argument("--pos_embed_per_dim", type=int, default=32) | |
| parser.add_argument("--feat_embed_per_dim", type=int, default=8) | |
| parser.add_argument("--dropout", type=float, default=0.00) | |
| parser.add_argument("--num_workers", type=int, default=4) | |
| parser.add_argument("--batch_size", type=int, default=1) | |
| parser.add_argument("--max_tokens", type=int, default=None) | |
| parser.add_argument("--delta_cutoff", type=int, default=2) | |
| parser.add_argument("--lr", type=float, default=1e-4) | |
| parser.add_argument( | |
| "--attn_positional_bias", | |
| type=str, | |
| choices=["rope", "bias", "none"], | |
| default="rope", | |
| ) | |
| parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16) | |
| parser.add_argument("--attn_dist_mode", default="v0") | |
| parser.add_argument("--mixedp", type=bool, default=True) | |
| parser.add_argument("--dry", action="store_true") | |
| parser.add_argument("--profile", action="store_true") | |
| parser.add_argument( | |
| "--features", | |
| type=str, | |
| choices=[ | |
| "none", | |
| "regionprops", | |
| "regionprops2", | |
| "patch", | |
| "patch_regionprops", | |
| "wrfeat", | |
| ], | |
| default="wrfeat", | |
| ) | |
| parser.add_argument( | |
| "--causal_norm", | |
| type=str, | |
| choices=["none", "linear", "softmax", "quiet_softmax"], | |
| default="quiet_softmax", | |
| ) | |
| parser.add_argument("--div_upweight", type=float, default=2) | |
| parser.add_argument("--augment", type=int, default=3) | |
| parser.add_argument("--tracking_frequency", type=int, default=-1) | |
| parser.add_argument("--sanity_dist", action="store_true") | |
| parser.add_argument("--preallocate", type=bool, default=False) | |
| parser.add_argument("--only_prechecks", action="store_true") | |
| parser.add_argument( | |
| "--compress", type=bool, default=True, help="compress dataset" | |
| ) | |
| parser.add_argument("--seed", type=int, default=None) | |
| parser.add_argument( | |
| "--logger", | |
| type=str, | |
| default="tensorboard", | |
| choices=["tensorboard", "wandb", "none"], | |
| ) | |
| parser.add_argument("--wandb_project", type=str, default="trackastra") | |
| parser.add_argument( | |
| "--crop_size", | |
| type=int, | |
| # required=True, | |
| nargs="+", | |
| default=None, | |
| help="random crop size for augmentation", | |
| ) | |
| parser.add_argument( | |
| "--weight_by_ndivs", | |
| type=bool, | |
| default=True, | |
| help="Oversample windows that contain divisions", | |
| ) | |
| parser.add_argument( | |
| "--weight_by_dataset", | |
| type=bool, | |
| default=False, | |
| help=( | |
| "Inversely weight datasets by number of samples (to counter dataset size" | |
| " imbalance)" | |
| ), | |
| ) | |
| args, unknown_args = parser.parse_known_args() | |
| # # Hack to allow for --input_test | |
| # allowed_unknown = ["input_test"] | |
| # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset( | |
| # set(allowed_unknown) | |
| # ): | |
| # raise ValueError(f"Unknown args: {unknown_args}") | |
| # pprint(vars(args)) | |
| # for backward compatibility | |
| # if args.attn_positional_bias == "True": | |
| # args.attn_positional_bias = "bias" | |
| # elif args.attn_positional_bias == "False": | |
| # args.attn_positional_bias = False | |
| # if args.train_samples == 0: | |
| # raise NotImplementedError( | |
| # "--train_samples must be > 0, full dataset pass not supported." | |
| # ) | |
| return args |