| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| |
|
| | from fairseq.dataclass.initialize import add_defaults, hydra_init |
| | from fairseq_cli.train import main as pre_main |
| | from fairseq import distributed_utils, metrics |
| | from fairseq.dataclass.configs import FairseqConfig |
| | from fairseq.dataclass.utils import omegaconf_no_object_check |
| | from fairseq.utils import reset_logging |
| |
|
| | import hydra |
| | from hydra.core.hydra_config import HydraConfig |
| | import torch |
| | from omegaconf import OmegaConf, open_dict |
| |
|
| |
|
# Module-wide logger. The name is spelled out explicitly (instead of using
# __name__), presumably so it is stable even when this file runs as __main__.
logger: logging.Logger = logging.getLogger("fairseq_cli.hydra_train")
| |
|
| |
|
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Hydra task function: train on the composed config.

    Args:
        cfg: the configuration composed by Hydra from ``fairseq/config``.

    Returns:
        The value produced by ``_hydra_main``: the smoothed best validation
        metric, or ``float("inf")`` when none was recorded. It is returned
        (not discarded) so that the value Hydra receives from this task
        function matches the ``-> float`` annotation.
    """
    return _hydra_main(cfg)
| |
|
| |
|
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    """Run fairseq training for ``cfg`` and return the best validation value.

    Args:
        cfg: the fairseq configuration. Defaults are filled in, then it is
            rebuilt as a resolved, struct-mode OmegaConf object before
            training starts.
        **kwargs: forwarded unchanged to ``distributed_utils.call_main``.

    Returns:
        The smoothed ``"valid"`` value of ``cfg.checkpoint.best_checkpoint_metric``
        from ``metrics``, or ``float("inf")`` if it is missing, ``None``, or
        cannot be read.

    Raises:
        Any exception raised during training, unless
        ``cfg.common.suppress_crashes`` is set, in which case it is logged and
        the best value is still returned.
    """
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()
    elif HydraConfig.initialized():
        # Only true when launched through Hydra (e.g. via hydra_main): keep a
        # resolved copy of Hydra's job logging config on cfg. Presumably this
        # lets spawned workers re-apply it — TODO confirm against call_main.
        with open_dict(cfg):
            cfg.job_logging_cfg = OmegaConf.to_container(
                HydraConfig.get().job_logging, resolve=True
            )

    # Resolve interpolations and turn enums into strings, then rebuild cfg as a
    # plain OmegaConf container and lock it in struct mode.
    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        # BaseException is intentional: with suppress_crashes set, every
        # failure (including interrupts) is logged instead of propagated.
        if not cfg.common.suppress_crashes:
            raise
        logger.error("Crashed! " + str(e))

    # Report the best validation value (presumably for Hydra sweepers).
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Best-effort lookup: a missing metric falls back to inf below, but
        # KeyboardInterrupt/SystemExit are no longer swallowed here.
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
| |
|
| |
|
def cli_main():
    """Command-line entry point: choose the Hydra config name, init Hydra, and train.

    The config name is read from Hydra's own argument parser. If that fails
    for any ordinary error, a warning is logged and ``"config"`` is used.
    """
    try:
        # hydra._internal is a private module; it may change between Hydra
        # versions, which is why failures fall back to the default name.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()
| |
|
| |
|
# Script entry point: start the CLI only when executed directly, never on import.
if __name__ == "__main__":
    cli_main()
| |
|