File size: 2,127 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#  Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from inspect import signature
from typing import Any, Dict, Protocol, Union, runtime_checkable

import hydra
from omegaconf import DictConfig, OmegaConf, read_write

from lcm.utils.common import promote_config

# Config key whose value identifies the trainer callable (class or function).
# The value is resolved with hydra.utils.get_object (presumably a dotted import
# path), and the key is deleted from the config before the config is promoted
# and handed to that callable.
TRAINER_KEY = "_trainer_"


@runtime_checkable
class Trainer(Protocol):
    """Structural interface every LCM trainer satisfies.

    Any object exposing a ``run`` method conforms.  Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, Trainer)`` only verifies that a
    ``run`` attribute exists; it does not check its signature.
    """

    def run(self) -> Any:
        """Run the trainer; the return value is implementation-defined."""


def _parse_training_config(train_config: DictConfig):
    """Return the trainer callable and its typed config from ``train_config``.

    ``train_config`` must contain a ``TRAINER_KEY`` entry. That entry is
    resolved with ``hydra.utils.get_object`` into a callable (class or
    function) that must take exactly one parameter named ``config``. The rest
    of the config is converted with ``promote_config`` to the type annotated
    on that parameter.

    Note: ``train_config`` is modified in place. Interpolations are resolved
    and the ``TRAINER_KEY`` entry is deleted, with read-only mode lifted
    temporarily to allow the deletion.

    Returns:
        A ``(trainer_callable, typed_config)`` tuple.

    Raises:
        ValueError: if the trainer key is absent, the callable cannot be
            located or has the wrong signature, or resolving/promoting the
            config fails.
    """
    # The train_config should have 2 keys "_target_" and "_trainer_".
    # The config is set to read-only within the stopes module __init__.
    if TRAINER_KEY not in train_config:
        # Explicit raise instead of `assert`, which `python -O` strips out.
        raise ValueError(
            f"The trainer configuration is missing a {TRAINER_KEY} configuration, "
            "you need to specify a Callable to initialize your config."
        )
    trainer_cls_or_func = train_config.get(TRAINER_KEY)
    try:
        trainer_obj = hydra.utils.get_object(trainer_cls_or_func)
        sign = signature(trainer_obj)
        if len(sign.parameters) != 1 or "config" not in sign.parameters:
            raise ValueError(
                f'{trainer_cls_or_func} should take a single argument called "config"'
            )
        param_type = sign.parameters["config"].annotation

        OmegaConf.resolve(train_config)
        with read_write(train_config):
            # Use the shared constant so the key name is defined in one place.
            del train_config[TRAINER_KEY]

        typed_config = promote_config(train_config, param_type)
        return trainer_obj, typed_config
    except Exception as ex:
        # Build one message string. Passing two positional args to ValueError
        # made str(err) render as a tuple.
        raise ValueError(
            f"couldn't parse the train config: {train_config}. {ex}"
        ) from ex


def get_trainer(train_config: DictConfig) -> Trainer:
    """Build the trainer described by ``train_config``.

    ``_parse_training_config`` splits the config into a trainer callable and
    a typed config. The callable is then invoked with that typed config and
    the result is returned. ``train_config`` is modified in place by the
    parsing step.
    """
    factory, typed_cfg = _parse_training_config(train_config)
    trainer = factory(typed_cfg)
    return trainer


def _is_missing(config: Union[DictConfig, Dict], attr: str) -> bool:
    """Return True when ``attr`` has no usable value in ``config``.

    A value is considered missing when the key/attribute is absent or its
    value is falsy (``None``, ``""``, ``0``, an empty container, ...).

    For a ``DictConfig`` it is also missing when ``OmegaConf.is_missing``
    reports it as a mandatory value that has not been set.
    """
    if isinstance(config, dict):
        # Same rule as the DictConfig path below: absent or falsy means
        # missing. `not` also guarantees a real bool rather than the raw value.
        return attr not in config or not config[attr]
    if OmegaConf.is_missing(config, attr):
        return True
    return not hasattr(config, attr) or not getattr(config, attr)