| from __future__ import annotations | |
| import yaml | |
def parse_yaml(config_yaml: str) -> dict:
    r"""Parse a YAML file and return its deserialized contents.

    Args:
        config_yaml: Path to the YAML file to read.

    Returns:
        The deserialized document. For a config file this is usually a dict,
        but ``yaml.load`` returns ``None`` for an empty file and a list or
        scalar when the top-level node is not a mapping.
    """
    # Decode explicitly as UTF-8: the default ``open`` encoding depends on the
    # platform locale, which garbles non-ASCII content on some systems.
    with open(config_yaml, "r", encoding="utf-8") as fr:
        # NOTE(security): FullLoader is not safe for untrusted input; switch to
        # ``yaml.safe_load`` if this path can point at externally supplied files.
        return yaml.load(fr, Loader=yaml.FullLoader)
class LinearWarmUp:
    r"""Linear learning-rate warm-up schedule.

    Calling an instance with a step index returns a multiplicative factor.
    The factor grows linearly from 0 at step 0 to 1 at ``warm_up_steps``, then
    stays at 1 for all later steps. Presumably this is used as an ``lr_lambda``
    for a step-based LR scheduler; confirm against callers.

    Args:
        warm_up_steps: Number of steps over which the factor ramps up. A value
            of 0 or less disables warm-up, so the factor is always 1.0.
    """

    def __init__(self, warm_up_steps: int) -> None:
        self.warm_up_steps = warm_up_steps

    def __call__(self, step: int) -> float:
        """Return the learning-rate factor for ``step``."""
        # With no warm-up period the ramp below would divide by zero at step 0
        # (and a negative period would yield factors above 1), so short-circuit.
        if self.warm_up_steps <= 0:
            return 1.0
        if step <= self.warm_up_steps:
            return step / self.warm_up_steps
        return 1.0
def pad_or_truncate(x: list, length: int, pad_value) -> list:
    """Return a new list of exactly ``length`` items built from ``x``.

    If ``x`` has more than ``length`` items, only its first ``length`` items are
    kept. If it has fewer, it is right-padded with ``pad_value``. The input
    list itself is never modified.

    Args:
        x: Input list.
        length: Target length; must be non-negative.
        pad_value: Value used to fill the list up to ``length``. The same
            object is repeated, so a mutable value is shared between slots.

    Returns:
        A new list whose length equals ``length``.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        # A negative slice bound would silently drop items from the end
        # (``x[:-1]``) instead of producing a list of the requested length.
        raise ValueError(f"length must be non-negative, got {length}")
    if len(x) >= length:
        return x[:length]
    return x + [pad_value] * (length - len(x))