Spaces:
Sleeping
Sleeping
File size: 5,539 Bytes
8f72b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import configargparse
def _str2bool(value):
    """Convert a command-line / config-file string into a bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``, so
    ``--flag False`` (or ``flag: false`` in the YAML config, whose values
    also reach ``type`` as strings) would silently leave the flag enabled.
    This converter parses the usual boolean spellings instead.

    Args:
        value: Raw argument value. Values that are already ``bool`` are
            returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_train_args(argv=None):
    """Build the training argument parser and parse the arguments.

    Values can be given on the command line or in a YAML config file passed
    via ``-c/--config`` (read with configargparse's ``YAMLConfigFileParser``).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the previous behaviour of parsing the process's
            command line.

    Returns:
        The parsed namespace. Unrecognised arguments are tolerated and
        discarded rather than raising an error.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    # Boolean options use _str2bool: plain `type=bool` treats "False" as True.
    parser.add_argument("--timestamp", type=_str2bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=_str2bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)
    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)
    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=_str2bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=_str2bool, default=True, help="compress dataset"
    )
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=_str2bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=_str2bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )
    # parse_known_args (not parse_args) deliberately tolerates extra flags.
    # Earlier code mentioned allowing e.g. `--input_test`, presumably consumed
    # elsewhere — confirm. The unknown arguments themselves are discarded.
    args, _unknown_args = parser.parse_known_args(argv)
    return args