| import configargparse |
|
|
|
|
| def parse_train_args(): |
| parser = configargparse.ArgumentParser( |
| formatter_class=configargparse.ArgumentDefaultsHelpFormatter, |
| config_file_parser_class=configargparse.YAMLConfigFileParser, |
| allow_abbrev=False, |
| ) |
| parser.add_argument( |
| "-c", |
| "--config", |
| default="_utils/example_config.yaml", |
| is_config_file=True, |
| help="config file path", |
| ) |
| parser.add_argument("-d", "--d_model", type=int, default=256) |
| parser.add_argument("-w", "--window", type=int, default=10) |
| parser.add_argument("--spatial_pos_cutoff", type=int, default=256) |
| parser.add_argument("--num_encoder_layers", type=int, default=6) |
| parser.add_argument("--num_decoder_layers", type=int, default=6) |
| parser.add_argument("--pos_embed_per_dim", type=int, default=32) |
| parser.add_argument("--feat_embed_per_dim", type=int, default=8) |
| parser.add_argument("--dropout", type=float, default=0.00) |
| parser.add_argument( |
| "--attn_positional_bias", |
| type=str, |
| choices=["rope", "bias", "none"], |
| default="rope", |
| ) |
| parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16) |
| parser.add_argument("--attn_dist_mode", default="v0") |
| parser.add_argument( |
| "--causal_norm", |
| type=str, |
| choices=["none", "linear", "softmax", "quiet_softmax"], |
| default="quiet_softmax", |
| ) |
|
|
| args, unknown_args = parser.parse_known_args() |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| return args |
|
|