File size: 1,937 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
from abc import ABC, abstractmethod
from typing import Dict, Literal, Optional, Type, Union

from datasets import DatasetDict
from omegaconf import OmegaConf

from larm.common import utils
from larm.data.envs.base_env import BaseEnv
class BaseDatasetBuilder(ABC):
    """Abstract base for dataset builders configured by a per-mode config section.

    Subclasses implement the mode-specific ``_build_*_datasets`` hooks and the
    ``get_*_cls`` accessors. To be constructible without an explicit config,
    a subclass must also define a class-level ``DATASET_CONFIG_DICT``, which
    ``default_config_path`` looks up and resolves via ``utils.get_abs_path``.
    """

    # Config name (e.g. "default") -> config location handed to
    # utils.get_abs_path (assumed to be str paths -- confirm in subclasses).
    # Declared as an annotation only: it is NOT set here, so
    # default_config_path raises AttributeError on a subclass that omits it.
    DATASET_CONFIG_DICT: Dict[str, str]

    def __init__(self, cfg: Optional[Union[Dict, str, os.PathLike]] = None):
        """Resolve the builder config and select the section for the active mode.

        Args:
            cfg: One of
                * ``None`` -- load the subclass's default config file
                  (``default_config_path()``);
                * a filesystem path (``str`` or ``os.PathLike``) -- load that
                  config file with ``load_dataset_config``;
                * an already-loaded config mapping (e.g. when called from
                  ``runner.build_dataset()``), used as-is.

        After loading, ``self.mode`` is the config's ``mode`` value (default
        ``"sft"``), and ``self.config`` is the sub-section stored under that
        mode name.
        """
        super().__init__()
        if cfg is None:
            # Build from the subclass's default config file.
            config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, (str, os.PathLike)):
            # os.fspath normalises PathLike objects (e.g. pathlib.Path) to a
            # plain path; without this they fell into the mapping branch below
            # and failed on `.get`.
            config = load_dataset_config(os.fspath(cfg))
        else:
            # Already-loaded config, e.g. when called from runner.build_dataset().
            config = cfg
        self.mode = config.get("mode", "sft")
        # NOTE(review): this is None when the config has no section named after
        # self.mode; the failure then surfaces later in the subclass hooks.
        self.config = config.get(self.mode)

    def build_datasets(self) -> DatasetDict:
        """Build the datasets for ``self.mode`` by dispatching to the matching hook.

        Returns:
            Whatever the selected ``_build_*_datasets`` hook returns.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes
                (``"sft"`` or ``"grpo"``).
        """
        method_builder_map = {
            "sft": self._build_sft_datasets,
            # The "grpo" mode is served by the RL dataset hook.
            "grpo": self._build_rl_datasets,
        }
        if self.mode not in method_builder_map:
            raise ValueError(
                f"Unsupported datasets mode {self.mode!r}; "
                f"expected one of {sorted(method_builder_map)}"
            )
        return method_builder_map[self.mode]()

    @abstractmethod
    def _build_sft_datasets(self) -> DatasetDict:
        """Build the datasets used when ``self.mode == "sft"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def _build_rl_datasets(self) -> DatasetDict:
        """Build the datasets used when ``self.mode == "grpo"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_env_cls(self) -> Type[BaseEnv]:
        """Return the environment class (a ``BaseEnv`` subclass) for this dataset."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_generation_manager_cls(self) -> type:
        """Return the generation-manager class used with this dataset.

        The previous ``BaseEnv`` return annotation was a copy-paste error. The
        concrete generation-manager base class is not visible in this module,
        so it is annotated as a plain class here.
        """
        raise NotImplementedError("Should be implemented by subclasses")

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of a named config from ``DATASET_CONFIG_DICT``.

        Args:
            type: Key into ``cls.DATASET_CONFIG_DICT``. It shadows the builtin
                ``type``; the name is kept for keyword-argument callers.

        Raises:
            AttributeError: If the subclass does not define ``DATASET_CONFIG_DICT``.
            KeyError: If ``type`` is not a key of ``DATASET_CONFIG_DICT``.
        """
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
def load_dataset_config(cfg_path: str) -> Dict:
    """Load a config file and return the first entry of its ``datasets`` section.

    The file must have a top-level ``datasets`` mapping. Only the first entry
    of that mapping (in file order) is returned; any further entries are
    ignored.

    Args:
        cfg_path: Path of the config file, passed directly to ``OmegaConf.load``.

    Returns:
        The config node stored under the first key of ``datasets``.

    Raises:
        ValueError: If ``datasets`` is empty or null, instead of the opaque
            IndexError/AttributeError this previously produced.
    """
    datasets_cfg = OmegaConf.load(cfg_path).datasets
    if not datasets_cfg:
        raise ValueError(
            f"No dataset entries found under 'datasets' in config file {cfg_path!r}"
        )
    # Take the first key lazily rather than materialising the full key list.
    first_name = next(iter(datasets_cfg.keys()))
    return datasets_cfg[first_name]
|