| from larm.common.config import Config | |
| from larm.common.registry import registry | |
| class Task: | |
| def __init__(self, config: Config, wandb=None): | |
| self.config = config | |
| self.wandb = wandb | |
| def build_env_and_generator(self): | |
| env_and_gms = dict() | |
| datasets_config = self.config.datasets_cfg | |
| assert len(datasets_config) > 0, "At least one dataset has to be specified." | |
| for name in datasets_config: | |
| dataset_config = datasets_config[name] | |
| builder = registry.get_builder_class(name)(dataset_config) | |
| env_cls = builder.get_env_cls() | |
| generation_manager_cls = builder.get_generation_manager_cls() | |
| env_and_gms[name] = (env_cls, generation_manager_cls) | |
| return env_and_gms | |
| def build_dataset(self): | |
| datasets = dict() | |
| datasets_config = self.config.datasets_cfg | |
| assert len(datasets_config) > 0, "At least one dataset has to be specified." | |
| for name in datasets_config: | |
| dataset_config = datasets_config[name] | |
| builder = registry.get_builder_class(name)(dataset_config) | |
| dataset = builder.build_datasets() | |
| datasets[name] = dataset | |
| return datasets | |
| def build_model(self): | |
| model_config = self.config.model_cfg | |
| model_cls = registry.get_model_class(self.config.method) | |
| model = model_cls.from_config(model_config) | |
| return model | |