File size: 1,491 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from larm.common.config import Config
from larm.common.registry import registry
class Task:
    """Build per-dataset environments, generation managers, datasets and a model.

    Everything is driven by ``config``: ``config.datasets_cfg`` selects the
    dataset builders (looked up by name in ``registry``), while
    ``config.method`` / ``config.model_cfg`` select and configure the model.
    """

    def __init__(self, config: Config, wandb=None):
        """Store the task configuration and an optional logging handle.

        Args:
            config: Project configuration object; must expose
                ``datasets_cfg``, ``model_cfg`` and ``method``.
            wandb: Optional handle stored as ``self.wandb``; presumably a
                Weights & Biases run — it is not used by the methods here.
        """
        self.config = config
        self.wandb = wandb

    def _dataset_builders(self):
        """Yield ``(name, builder)`` for every dataset in ``config.datasets_cfg``.

        ``datasets_cfg`` is iterated for names and indexed by name for each
        dataset's own config (presumably a mapping such as an OmegaConf
        ``DictConfig`` — confirm). Each builder class is fetched from
        ``registry`` by the dataset name and instantiated with that config.

        Raises:
            ValueError: If no dataset is configured.
        """
        datasets_config = self.config.datasets_cfg
        # Explicit check rather than ``assert`` so validation still happens
        # when Python runs with ``-O`` (which strips assert statements).
        if len(datasets_config) == 0:
            raise ValueError("At least one dataset has to be specified.")
        for name in datasets_config:
            builder_cls = registry.get_builder_class(name)
            yield name, builder_cls(datasets_config[name])

    def build_env_and_generator(self):
        """Return ``{dataset_name: (env_cls, generation_manager_cls)}``.

        The classes are returned uninstantiated, exactly as reported by each
        dataset's builder.

        Raises:
            ValueError: If no dataset is configured.
        """
        return {
            name: (builder.get_env_cls(), builder.get_generation_manager_cls())
            for name, builder in self._dataset_builders()
        }

    def build_dataset(self):
        """Return ``{dataset_name: builder.build_datasets()}`` for every dataset.

        Raises:
            ValueError: If no dataset is configured.
        """
        return {
            name: builder.build_datasets()
            for name, builder in self._dataset_builders()
        }

    def build_model(self):
        """Instantiate the model registered under ``config.method``.

        The class is looked up in ``registry`` and built via its
        ``from_config`` constructor using ``config.model_cfg``.
        """
        model_cls = registry.get_model_class(self.config.method)
        return model_cls.from_config(self.config.model_cfg)
|