import argparse
from typing import Optional, Sequence

import configargparse


def parse_train_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse model/training hyperparameters from the command line and a YAML config.

    configargparse merges three sources: explicit command-line flags take
    precedence over values read from the YAML file given by ``-c/--config``,
    which in turn take precedence over the defaults declared below.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; pass an explicit list for
            programmatic use or tests.

    Returns:
        Namespace holding the parsed arguments.

    Note:
        Unrecognized arguments are tolerated and silently discarded (via
        ``parse_known_args``) instead of raising an error.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        # Require full option names; prefix abbreviations are not expanded.
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        # Relative path, so it is resolved against the current working directory.
        # NOTE(review): configargparse appears to error if this default file
        # is missing — confirm before relying on running without a config.
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )

    # Unknown flags are deliberately ignored rather than rejected.
    args, _unknown_args = parser.parse_known_args(args=argv)
    return args