import json
from dataclasses import asdict, dataclass, field, make_dataclass
from typing import List, Optional
|
|
|
|
@dataclass
class Config:
    """
    Default settings consumed by ``load_config``.

    ``load_config`` starts from ``asdict(Config())`` and then overlays values
    from a JSON file and from the caller's ``args`` on top of these defaults.
    """

    # Paths and dataset selection
    # Default value only; load_config reads the JSON file named by args.config.
    config: str = "config/default.json"
    loader: str = "loaders/newsroom.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own list (no shared mutable default).
    validation_datasets: List = field(default_factory=list)

    # Training loop
    batch_size: int = 4
    learning_rate: float = 0.00001
    k_samples: int = 1
    # validate_config accepts only "max" or "mean".
    sample_aggregation: str = "max"
    # Presumably None means "no limit" -- TODO confirm against the training loop.
    max_val_steps: Optional[int] = None
    max_train_steps: Optional[int] = None
    max_train_seconds: Optional[int] = None
    print_every: int = 10
    save_every: int = 100
    eval_every: int = 100
    verbose: bool = True

    # Model
    encoder_model_id: str = "distilroberta-base"

    # Names of reward components; presumably resolved to classes elsewhere -- verify.
    rewards: tuple = (
        "FluencyReward",
        "BiEncoderSimilarity",
        "GaussianLength",
    )
|
|
|
|
def validate_config(args):
    """
    Check constraints on a loaded configuration object.

    Args:
        args: object exposing a ``sample_aggregation`` attribute
            (for example the object returned by ``load_config``).

    Raises:
        AssertionError: if ``args.sample_aggregation`` is not "max" or "mean".
    """
    allowed = ("max", "mean")
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; AssertionError is kept so existing callers that
    # catch it are unaffected.
    if args.sample_aggregation not in allowed:
        raise AssertionError(
            f"sample_aggregation must be one of {allowed}, "
            f"got {args.sample_aggregation!r}"
        )
|
|
|
|
def load_config(args):
    """
    Build a settings dataclass by layering values from three sources.

    Sources, applied in this order (later ones overwrite earlier ones on
    overlapping keys):
      - the defaults declared on ``Config``
      - ``args.config``: path to a JSON config file (read only when truthy)
      - every attribute of ``args`` (e.g. from ``argparse``)

    Keys that appear in the JSON file or in ``args`` but are not declared on
    ``Config`` are kept and become fields of the returned object.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size

    Args:
        args: namespace-like object whose ``vars()`` are merged last.

    Returns:
        An instance of a dataclass named "Config", created at call time, with
        one field per merged key.

    Raises:
        OSError: if ``args.config`` names a file that cannot be opened.
        json.JSONDecodeError: if that file is not valid JSON.
        Whatever ``validate_config`` raises for invalid merged settings.
    """
    settings = asdict(Config())

    # getattr: tolerate namespaces that do not define a ``config`` option.
    config_path = getattr(args, "config", None)
    if config_path:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(config_path, encoding="utf-8") as f:
            settings.update(json.load(f))

    # NOTE(review): every attribute of ``args`` wins here, including ones that
    # are None because an argparse option was not supplied -- that can mask a
    # value from the JSON file. Confirm this is intended.
    settings.update(vars(args))

    # Pass bare field names so make_dataclass annotates each one as
    # typing.Any. Passing (name, value) pairs would make every *value* the
    # field's type annotation.
    dynamic_config_cls = make_dataclass("Config", fields=list(settings))
    config_object = dynamic_config_cls(**settings)
    validate_config(config_object)
    return config_object
|
|