File size: 2,170 Bytes
aff3c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ce5a27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import configargparse


def parse_train_args(argv: "list[str] | None" = None):
    """Parse model hyperparameters from the command line and a YAML config.

    Values can be supplied on the command line or in a YAML config file named
    by ``-c/--config`` (``_utils/example_config.yaml`` by default). Option
    abbreviations are disabled (``allow_abbrev=False``), so every flag must be
    spelled out in full.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous zero-argument behaviour;
            pass an explicit list to use the parser programmatically without
            touching ``sys.argv``.

    Returns:
        The namespace of parsed values. Any arguments this parser does not
        declare are discarded (see the note at the parse call).
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        # NOTE(review): relative path, resolved against the current working
        # directory rather than this file's location — confirm the expected
        # launch directory.
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", type=str, default="v0")
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )

    # parse_known_args (not parse_args): flags this parser does not declare
    # are returned separately and dropped here instead of raising an error,
    # presumably so one command line / config can also carry options meant
    # for other code. NOTE(review): this also means a misspelled flag is
    # silently ignored and its default is used.
    args, _unknown_args = parser.parse_known_args(args=argv)
    return args