File size: 1,491 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from larm.common.config import Config
from larm.common.registry import registry


class Task:
    """Build datasets, environment/generation-manager classes and the model.

    Everything is driven by the ``config`` object handed to the constructor:
    the build methods read its ``datasets_cfg``, ``model_cfg`` and ``method``
    attributes and resolve the concrete classes through the project
    ``registry``.
    """

    def __init__(self, config: Config, wandb=None):
        """Store the configuration and the optional ``wandb`` handle.

        Args:
            config: Configuration object. The build methods access
                ``config.datasets_cfg``, ``config.model_cfg`` and
                ``config.method``.
            wandb: Optional object kept on ``self.wandb``; this class only
                stores it and never uses it itself.
        """
        self.config = config
        self.wandb = wandb

    def _iter_dataset_builders(self):
        """Yield ``(name, builder)`` for every entry of ``config.datasets_cfg``.

        Each key of ``datasets_cfg`` is used twice: to look up the builder
        class in the registry and to select the per-dataset configuration
        that the builder is instantiated with. Builders are created lazily,
        one per iteration step, so callers consume a builder before the next
        one is constructed.

        Raises:
            AssertionError: If ``datasets_cfg`` is empty.
        """
        datasets_config = self.config.datasets_cfg

        # Raise explicitly instead of using an ``assert`` statement: asserts
        # are removed under ``python -O``, which would let an empty config
        # slip through and silently produce an empty result. The exception
        # type and message are kept identical for existing callers.
        if len(datasets_config) == 0:
            raise AssertionError("At least one dataset has to be specified.")

        for name in datasets_config:
            builder_cls = registry.get_builder_class(name)
            yield name, builder_cls(datasets_config[name])

    def build_env_and_generator(self):
        """Collect the environment and generation-manager class per dataset.

        Returns:
            dict: Maps each dataset name to the tuple
            ``(env_cls, generation_manager_cls)`` obtained from that
            dataset's builder via ``get_env_cls()`` and
            ``get_generation_manager_cls()``.

        Raises:
            AssertionError: If no dataset is configured.
        """
        return {
            name: (builder.get_env_cls(), builder.get_generation_manager_cls())
            for name, builder in self._iter_dataset_builders()
        }

    def build_dataset(self):
        """Build every configured dataset.

        Returns:
            dict: Maps each dataset name to whatever that dataset's builder
            returns from ``build_datasets()``.

        Raises:
            AssertionError: If no dataset is configured.
        """
        return {
            name: builder.build_datasets()
            for name, builder in self._iter_dataset_builders()
        }

    def build_model(self):
        """Instantiate the model registered under ``config.method``.

        Returns:
            The object produced by the registered model class's
            ``from_config`` called with ``config.model_cfg``.
        """
        model_config = self.config.model_cfg
        model_cls = registry.get_model_class(self.config.method)
        return model_cls.from_config(model_config)