Spaces:
Running
Running
| import json | |
| import argparse | |
| import ast | |
| from trainer import train | |
def merge_cli_overrides(param: dict, cli_args: dict) -> dict:
    """Overlay parsed command-line values onto JSON settings.

    Every key in ``cli_args`` replaces the same key in ``param``, except
    ``wandb``. argparse's ``store_true`` action always records it (``False``
    when the flag is omitted), so it only overrides the JSON value when the
    flag was actually given. ``param`` is updated in place and returned, and
    it always contains a ``wandb`` key afterwards.

    Args:
        param: settings loaded from the JSON config file.
        cli_args: ``vars()`` of the parsed argparse namespace.

    Returns:
        The same ``param`` dict, now including the command-line overrides.
    """
    overrides = dict(cli_args)  # copy so the caller's mapping is not mutated
    if not overrides.get("wandb", False):
        # Flag not passed on the command line: keep whatever the JSON says.
        overrides.pop("wandb", None)
    param.update(overrides)
    # Preserve the old guarantee that "wandb" is always present.
    param.setdefault("wandb", False)
    return param


def main():
    """Run training with JSON settings overridden by command-line options.

    Loads the file named by ``--config``, overlays the options explicitly
    given on the command line, and passes the merged dict to ``train``.
    """
    args = setup_parser().parse_args()
    param = load_json(args.config)
    # Convert the argparse Namespace to a dict and merge CLI values over JSON.
    merge_cli_overrides(param, vars(args))
    train(param)
def load_json(settings_path) -> dict:
    """Read a JSON settings file and return its top-level object.

    Args:
        settings_path: path of the JSON file to read.

    Returns:
        The decoded settings as a dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON (a ValueError
            subclass).
        ValueError: if the top-level JSON value is not an object. The caller
            merges command-line options into the result with ``dict.update``,
            so anything else would fail later with a less helpful error.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(settings_path, encoding="utf-8") as data_file:
        param = json.load(data_file)
    if not isinstance(param, dict):
        raise ValueError(
            f"Expected a JSON object in {settings_path}, "
            f"got {type(param).__name__}."
        )
    return param
def setup_parser():
    """Build the argument parser for the Prompt2Guard training script.

    ``--config`` always has a value (``./configs/cddb_training.json`` unless
    given). ``--wandb`` is a ``store_true`` flag, so both are always present
    in the parsed namespace. Every other option defaults to
    ``argparse.SUPPRESS``: it is left out of the namespace unless supplied
    on the command line, so merging the namespace onto JSON settings only
    overrides what was explicitly given.

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser(description="Prompt2Guard - training part.")

    def add_override(flag, value_type, help_text):
        # Register an option that is absent from the namespace unless given.
        parser.add_argument(
            flag, type=value_type, default=argparse.SUPPRESS, help=help_text
        )

    parser.add_argument(
        "--config",
        type=str,
        default="./configs/cddb_training.json",
        help="Json file of settings.",
    )
    add_override("--K", int, "Number of prompts.")
    add_override("--batch_size", int, "Batch size for training.")
    add_override("--batch_size_eval", int, "Batch size for evaluation.")
    add_override("--torch_seed", int, "Seed for PyTorch random number generator.")
    add_override("--lrate", float, "LR for task > 0.")
    add_override("--init_lr", float, "Initial LR for task 0.")
    add_override("--epochs", int, "Epochs for the other tasks.")
    parser.add_argument(
        "--wandb", action="store_true", help="Enable Weights & Biases logging."
    )
    add_override("--warmup_epoch", int, "Number of warmup epochs.")
    add_override("--topk_classes", int, "TopK classes.")
    # literal_eval turns text such as "[True, False]" into a Python literal
    # without executing arbitrary code (unlike eval).
    add_override(
        "--ensembling", ast.literal_eval, "List of boolean values for ensembling."
    )
    return parser
if __name__ == "__main__":
    # Start training only when this file is executed as a script, so that
    # importing it (e.g. to reuse setup_parser) has no side effects.
    main()