| | import importlib |
| | import json |
| | import logging |
| | import os |
| | from datetime import datetime |
| | from functools import reduce, partial |
| | from operator import getitem |
| | from pathlib import Path |
| |
|
| | from hw_asr import text_encoder as text_encoder_module |
| | from hw_asr.base.base_text_encoder import BaseTextEncoder |
| | from hw_asr.logger import setup_logging |
| | from hw_asr.text_encoder import CTCCharTextEncoder |
| | from hw_asr.utils import read_json, write_json, ROOT_PATH |
| |
|
| |
|
class ConfigParser:
    """
    Holds a parsed experiment configuration together with the run directories
    derived from it, and builds objects described by config entries.
    """

    def __init__(self, config, resume=None, modification=None, run_id=None):
        """
        class to parse configuration json file. Handles hyperparameters for training,
        initializations of modules, checkpoint saving and logging module.
        :param config: Dict containing configurations, hyperparameters for training.
                       contents of `config.json` file for example. Must provide
                       `config["trainer"]["save_dir"]` and `config["name"]`.
        :param resume: String, path to the checkpoint being loaded.
        :param modification: Dict {keychain: value}, specifying position values to be replaced
                             from config dict.
        :param run_id: Unique Identifier for training processes.
                       Used to save checkpoints and training log. Timestamp is being used as default
        :raises FileExistsError: if the run directory already exists and `run_id`
                                 is not the empty string.
        """
        # apply the requested overrides before anything reads the config
        self._config = _update_config(config, modification)
        self.resume = resume
        # created lazily by get_text_encoder()
        self._text_encoder = None

        # root directory under which models and logs of every run are stored
        save_dir = Path(self.config["trainer"]["save_dir"])

        exper_name = self.config["name"]
        if run_id is None:
            # default run id is a month-day_hour-minute-second timestamp
            run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        self._save_dir = str(save_dir / "models" / exper_name / run_id)
        self._log_dir = str(save_dir / "log" / exper_name / run_id)

        # an existing directory is tolerated only when run_id is the empty string
        exist_ok = run_id == ""
        self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
        self.log_dir.mkdir(parents=True, exist_ok=exist_ok)

        # keep a copy of the effective config next to the checkpoints
        write_json(self.config, self.save_dir / "config.json")

        # configure the logging module for this run's log directory
        setup_logging(self.log_dir)
        self.log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}

    @classmethod
    def from_args(cls, args, options=""):
        """
        Initialize this class from some cli arguments. Used in train, test.

        :param args: an object offering `add_argument` / `parse_args`, or a tuple
                     that is used as the already-parsed arguments (presumably a
                     namedtuple, since `device`, `resume` and `config` attributes
                     are read from it -- confirm against callers).
        :param options: iterable of custom option descriptors, each exposing
                        `flags`, `type` and `target` (the config keychain to override).
        :return: a new instance built from the loaded config and the overrides.
        """
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        if args.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is not None:
            # when resuming, the config stored next to the checkpoint is the base
            resume = Path(args.resume)
            cfg_fname = resume.parent / "config.json"
        else:
            msg_no_cfg = "Configuration file need to be specified. " \
                         "Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            resume = None
            cfg_fname = Path(args.config)

        config = read_json(cfg_fname)
        if args.config and resume:
            # an explicitly given config overrides top-level keys of the resumed one
            config.update(read_json(args.config))

        # map each custom cli option onto the config keychain it overrides
        modification = {
            opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options
        }
        return cls(config, resume, modification)

    @staticmethod
    def init_obj(obj_dict, default_module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj(config['param'], module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`

        :param obj_dict: dict with keys 'type' and 'args', and optionally 'module'
                         (a dotted module name imported in place of `default_module`).
        :param default_module: module searched for 'type' when 'module' is absent.
        :raises AssertionError: if a keyword argument is also present in 'args'.
        """
        if "module" in obj_dict:
            default_module = importlib.import_module(obj_dict["module"])

        module_name = obj_dict["type"]
        # copy so the config entry itself is never mutated
        module_args = dict(obj_dict["args"])
        assert all(
            k not in module_args for k in kwargs
        ), "Overwriting kwargs given in config file is not allowed"
        module_args.update(kwargs)
        return getattr(default_module, module_name)(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.

        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.

        :raises AssertionError: if a keyword argument is also present in the
                                config entry's 'args'.
        """
        module_name = self[name]["type"]
        # copy so the config entry itself is never mutated
        module_args = dict(self[name]["args"])
        assert all(
            k not in module_args for k in kwargs
        ), "Overwriting kwargs given in config file is not allowed"
        module_args.update(kwargs)
        return partial(getattr(module, module_name), *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """
        Return the logger called `name`, with its level set from `verbosity`.

        :param verbosity: key of `self.log_levels` (0: WARNING, 1: INFO, 2: DEBUG).
        :raises AssertionError: if `verbosity` is not a key of `self.log_levels`.
        """
        msg_verbosity = "verbosity option {} is invalid. Valid options are {}.".format(
            verbosity, self.log_levels.keys()
        )
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    def get_text_encoder(self) -> "BaseTextEncoder":
        """
        Return the text encoder for this config, creating it on first use.

        Without a "text_encoder" entry a default `CTCCharTextEncoder()` is used.
        Otherwise the entry is either a {'type': ..., 'args': ...} dict handled
        by `init_obj`, or a bare class name, which is instantiated without
        arguments. The created encoder is cached and returned on later calls.
        """
        if self._text_encoder is None:
            if "text_encoder" not in self._config:
                self._text_encoder = CTCCharTextEncoder()
            else:
                encoder_spec = self._config["text_encoder"]
                if isinstance(encoder_spec, str):
                    # a bare name has no 'args' to index; normalise it to the
                    # dict form so both spellings share a single code path
                    encoder_spec = {"type": encoder_spec, "args": {}}
                self._text_encoder = self.init_obj(
                    encoder_spec, default_module=text_encoder_module
                )
        return self._text_encoder

    # read-only views of the values computed in __init__
    @property
    def config(self):
        """The configuration dict after overrides were applied."""
        return self._config

    @property
    def save_dir(self):
        """Directory of this run's checkpoints, as a Path."""
        return Path(self._save_dir)

    @property
    def log_dir(self):
        """Directory of this run's logs, as a Path."""
        return Path(self._log_dir)

    @classmethod
    def get_default_configs(cls):
        """Build an instance from `hw_asr/config.json` under ROOT_PATH."""
        config_path = ROOT_PATH / "hw_asr" / "config.json"
        with config_path.open() as f:
            return cls(json.load(f))

    @classmethod
    def get_test_configs(cls):
        """Build an instance from `hw_asr/tests/config.json` under ROOT_PATH."""
        config_path = ROOT_PATH / "hw_asr" / "tests" / "config.json"
        with config_path.open() as f:
            return cls(json.load(f))
| |
|
| |
|
| | |
| | def _update_config(config, modification): |
| | if modification is None: |
| | return config |
| |
|
| | for k, v in modification.items(): |
| | if v is not None: |
| | _set_by_path(config, k, v) |
| | return config |
| |
|
| |
|
| | def _get_opt_name(flags): |
| | for flg in flags: |
| | if flg.startswith("--"): |
| | return flg.replace("--", "") |
| | return flags[0].replace("--", "") |
| |
|
| |
|
def _set_by_path(tree, keys, value):
    """
    Assign value inside the nested object tree at the position named by keys.

    :param tree: nested indexable object; it is modified in place.
    :param keys: string of keys joined with ";", e.g. "optimizer;args;lr".
    :param value: the value stored under the last key.
    """
    *parent_path, leaf = keys.split(";")
    # every key but the last selects the container that receives the value
    parent = _get_by_path(tree, parent_path)
    parent[leaf] = value
| |
|
| |
|
| | def _get_by_path(tree, keys): |
| | """Access a nested object in tree by sequence of keys.""" |
| | return reduce(getitem, keys, tree) |
| |
|