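from abc import ABC, abstractmethod
from typing import Dict, Optional, Type, Union

from datasets import DatasetDict
from omegaconf import OmegaConf

from larm.common import utils
from larm.data.envs.base_env import BaseEnv


class BaseDatasetBuilder(ABC):
    """Abstract base class for dataset builders.

    Subclasses must define ``DATASET_CONFIG_DICT`` and implement the abstract
    methods below. The config's ``mode`` key ("sft" or "grpo") decides which
    dataset-building method ``build_datasets()`` calls.
    """

    # Maps a config type name (e.g. "default") to a config file path that
    # utils.get_abs_path() resolves. Not set here; subclasses must define it.
    DATASET_CONFIG_DICT: Dict[str, str]

    def __init__(self, cfg: Optional[Union[Dict, str]] = None):
        super().__init__()

        if cfg is None:
            # No config given: load the builder's default config file.
            config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            # A path to a config file.
            config = load_dataset_config(cfg)
        else:
            # An already-loaded config, e.g. when called from runner.build_dataset().
            config = cfg

        # "mode" selects the training mode (defaults to "sft"); only the
        # sub-config stored under that mode's key is kept.
        self.mode = config.get("mode", "sft")
        self.config = config.get(self.mode)

    def build_datasets(self) -> DatasetDict:
        """Build the datasets for ``self.mode``.

        "sft" dispatches to ``_build_sft_datasets()`` and "grpo" to
        ``_build_rl_datasets()``.
        """
        method_builder_map = {
            "sft": self._build_sft_datasets,
            "grpo": self._build_rl_datasets,
        }

        if self.mode not in method_builder_map:
            raise ValueError(
                f"Unsupported datasets mode: {self.mode!r} "
                f"(expected one of {sorted(method_builder_map)})"
            )

        return method_builder_map[self.mode]()

    @abstractmethod
    def _build_sft_datasets(self) -> DatasetDict:
        """Build the datasets used for supervised fine-tuning ("sft" mode)."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def _build_rl_datasets(self) -> DatasetDict:
        """Build the datasets used for RL training ("grpo" mode)."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_env_cls(self) -> Type[BaseEnv]:
        """Return the environment class (not an instance) for this dataset."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_generation_manager_cls(self) -> type:
        """Return the generation manager class for this dataset."""
        raise NotImplementedError("Should be implemented by subclasses")

    @classmethod
    def default_config_path(cls, type: str = "default"):
        """Return the resolved path of the config registered under ``type``."""
        # `type` shadows the builtin; the name is kept so keyword callers still work.
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
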

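# Expected layout of a dataset config file, as read by load_dataset_config()
# and BaseDatasetBuilder.__init__() (a sketch; "my_dataset" and the contents
# of each mode section are hypothetical placeholders):
#
#   datasets:
#     my_dataset:        # only the first entry under `datasets` is used
#       mode: grpo       # selects the sub-config below; defaults to "sft"
#       sft:
#         ...            # becomes self.config when mode is "sft"
#       grpo:
#         ...            # becomes self.config when mode is "grpo"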
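def load_dataset_config(cfg_path: str) -> Dict:
    """Load a YAML dataset config and return its first entry under ``datasets``.

    The result is an OmegaConf ``DictConfig``. Any entries after the first one
    under the top-level ``datasets`` key are ignored.
    """
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]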

    return cfg
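

# Example: a minimal concrete builder (a sketch only; MyDatasetBuilder, the
# config path and MyEnv are hypothetical names, not part of larm). A subclass
# must define DATASET_CONFIG_DICT, which default_config_path() reads, and
# implement all four abstract methods before it can be instantiated.
#
#   class MyDatasetBuilder(BaseDatasetBuilder):
#       DATASET_CONFIG_DICT = {"default": "configs/datasets/my_dataset.yaml"}
#
#       def _build_sft_datasets(self) -> DatasetDict:
#           ...  # build the splits from self.config (the "sft" sub-config)
#
#       def _build_rl_datasets(self) -> DatasetDict:
#           ...  # build the splits from self.config (the "grpo" sub-config)
#
#       def get_env_cls(self):
#           return MyEnv  # a hypothetical BaseEnv subclass
#
#       def get_generation_manager_cls(self):
#           ...
#
#   builder = MyDatasetBuilder()      # cfg=None: loads the "default" config path
#   splits = builder.build_datasets()  # dispatches on the config's "mode"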