File size: 2,284 Bytes
9c4b1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import argparse
import ast
from trainer import train


def main():
    """Parse CLI arguments, merge them over the JSON config, and start training.

    Settings are loaded from the JSON file named by ``--config``. Every option
    given explicitly on the command line overrides the matching JSON key;
    options declared with ``default=argparse.SUPPRESS`` are absent from the
    namespace when not passed, so they never clobber JSON values. The merged
    dict (which also carries the ``config`` path itself) is passed to
    ``train``.
    """
    args = vars(setup_parser().parse_args())  # argparse Namespace -> dict.
    param = load_json(args["config"])

    # ``--wandb`` is a store_true flag, so it is always present in ``args``
    # (False when omitted). A plain ``update`` would silently disable a
    # ``"wandb": true`` coming from the JSON config. Treat it as an opt-in:
    # logging is enabled if either the CLI flag or the config requests it.
    # ``pop`` with a default also tolerates a parser that suppresses the flag.
    cli_wandb = args.pop("wandb", False)
    param.update(args)  # CLI values override JSON values.
    param["wandb"] = bool(cli_wandb or param.get("wandb", False))

    train(param)


def load_json(settings_path) -> dict:
    """Load and return the JSON settings stored at ``settings_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses per
    RFC 8259) rather than with the platform's locale encoding, so configs
    containing non-ASCII text load identically on every OS.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)


def setup_parser():
    """Build the command-line parser for the Prompt2Guard training script.

    Apart from ``--config`` (which has a concrete default path) and the
    ``--wandb`` switch, every option defaults to ``argparse.SUPPRESS``: an
    option not given on the command line is left out of the parsed namespace
    entirely instead of appearing with a placeholder value.
    """
    parser = argparse.ArgumentParser(description="Prompt2Guard - training part.")
    parser.add_argument(
        "--config",
        type=str,
        default="./configs/cddb_training.json",
        help="Json file of settings.",
    )

    # (flag, converter, help) for the optional overrides, listed in
    # registration order; "--wandb" sits between the two groups and is added
    # on its own because it is a boolean switch rather than a typed value.
    leading_overrides = (
        ("--K", int, "Number of prompts."),
        ("--batch_size", int, "Batch size for training."),
        ("--batch_size_eval", int, "Batch size for evaluation."),
        ("--torch_seed", int, "Seed for PyTorch random number generator."),
        ("--lrate", float, "LR for task > 0."),
        ("--init_lr", float, "Initial LR for task 0."),
        ("--epochs", int, "Epochs for the other tasks."),
    )
    trailing_overrides = (
        ("--warmup_epoch", int, "Number of warmup epochs."),
        ("--topk_classes", int, "TopK classes."),
        # literal_eval accepts Python literals such as "[True, False]"
        # without evaluating arbitrary expressions.
        (
            "--ensembling",
            ast.literal_eval,
            "List of boolean values for ensembling.",
        ),
    )

    def register_override(flag, converter, description):
        parser.add_argument(
            flag, type=converter, default=argparse.SUPPRESS, help=description
        )

    for spec in leading_overrides:
        register_override(*spec)
    parser.add_argument(
        "--wandb", action="store_true", help="Enable Weights & Biases logging."
    )
    for spec in trailing_overrides:
        register_override(*spec)

    return parser


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()