Spaces:
Sleeping
Sleeping
File size: 2,284 Bytes
9c4b1c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import json
import argparse
import ast
from trainer import train
def merge_params(json_params, cli_params):
    """Overlay command-line values onto the JSON settings and return them.

    Options declared with ``default=argparse.SUPPRESS`` only appear in
    *cli_params* when passed explicitly, so they override the JSON value only
    in that case. ``--wandb`` is a ``store_true`` flag and is therefore always
    present (``False`` when omitted); it is applied only when it is true or when
    the JSON has no ``"wandb"`` entry, so omitting the flag never disables
    logging that the JSON config enabled.

    Args:
        json_params: Settings loaded from the JSON config; updated in place.
        cli_params: Mapping of parsed command-line options (e.g. ``vars(args)``).

    Returns:
        The same *json_params* object, after the update.
    """
    overrides = dict(cli_params)  # copy: do not mutate the caller's mapping
    if not overrides.get("wandb", False) and "wandb" in json_params:
        # Flag not passed: keep the JSON's explicit choice instead of forcing False.
        overrides.pop("wandb", None)
    json_params.update(overrides)  # command-line values take precedence
    return json_params


def main():
    """Parse the command line, merge it over the JSON config, and train.

    The JSON file named by ``--config`` supplies the base parameters; options
    given explicitly on the command line override them. The merged dict is
    passed to ``trainer.train``.
    """
    args = setup_parser().parse_args()
    param = load_json(args.config)
    param = merge_params(param, vars(args))  # argparse Namespace -> dict, CLI wins
    train(param)
def load_json(settings_path) -> dict:
    """Read a JSON settings file and return its decoded contents.

    Args:
        settings_path: Path of the JSON file to read.

    Returns:
        dict: The decoded JSON value. It is expected to be an object, since
        ``main`` merges command-line options into it with ``update``.
    """
    # JSON text is UTF-8 by specification; be explicit so decoding does not
    # depend on the platform's locale encoding (e.g. cp1252 on Windows).
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)
def setup_parser():
    """Build the argument parser for the Prompt2Guard training script.

    ``--config`` always has a value (the default JSON path) and ``--wandb`` is
    a ``store_true`` flag, so both always appear in the parsed namespace.
    Every other option defaults to ``argparse.SUPPRESS`` and is therefore
    absent from the namespace unless passed explicitly on the command line.

    Returns:
        argparse.ArgumentParser: The configured parser (arguments not yet parsed).
    """

    def override(converter, text):
        # Keyword arguments for an option that is omitted from the namespace
        # when it is not given on the command line.
        return {"type": converter, "default": argparse.SUPPRESS, "help": text}

    # (flag, add_argument keyword arguments) in registration order; the order
    # determines how the options are listed in --help.
    option_specs = [
        (
            "--config",
            {
                "type": str,
                "default": "./configs/cddb_training.json",
                "help": "Json file of settings.",
            },
        ),
        ("--K", override(int, "Number of prompts.")),
        ("--batch_size", override(int, "Batch size for training.")),
        ("--batch_size_eval", override(int, "Batch size for evaluation.")),
        ("--torch_seed", override(int, "Seed for PyTorch random number generator.")),
        ("--lrate", override(float, "LR for task > 0.")),
        ("--init_lr", override(float, "Initial LR for task 0.")),
        ("--epochs", override(int, "Epochs for the other tasks.")),
        (
            "--wandb",
            {"action": "store_true", "help": "Enable Weights & Biases logging."},
        ),
        ("--warmup_epoch", override(int, "Number of warmup epochs.")),
        ("--topk_classes", override(int, "TopK classes.")),
        # literal_eval turns text such as "[True, False]" into a Python literal
        # without evaluating arbitrary code.
        (
            "--ensembling",
            override(ast.literal_eval, "List of boolean values for ensembling."),
        ),
    ]

    parser = argparse.ArgumentParser(description="Prompt2Guard - training part.")
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()
|