|
|
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Literal, Optional, Type, Union

from datasets import DatasetDict
from omegaconf import OmegaConf

from larm.common import utils
from larm.data.envs.base_env import BaseEnv
|
|
|
|
|
class BaseDatasetBuilder(ABC):
    """Abstract base for dataset builders driven by a mode-keyed config.

    The resolved config may carry a ``mode`` key (defaulting to ``"sft"``)
    and one section per mode; only the section for the active mode is kept
    on ``self.config``. ``build_datasets`` dispatches on ``self.mode`` to the
    matching abstract builder method.

    Subclasses implement the abstract methods below and, if they want to be
    constructible without an explicit ``cfg``, must assign the class
    attribute ``DATASET_CONFIG_DICT`` (read by ``default_config_path``).
    """

    # Maps a config "type" name (e.g. "default") to a config file location
    # that ``utils.get_abs_path`` resolves. Annotation only (no value is
    # assigned here), so subclasses must provide it.
    # NOTE(review): values assumed to be path strings — confirm in subclasses.
    DATASET_CONFIG_DICT: Dict[str, str]

    def __init__(self, cfg: Optional[Union[Dict, str, Path]] = None):
        """Resolve the builder config and select the section for its mode.

        Args:
            cfg: ``None`` to load the subclass's default config file, a path
                to a config file (``str`` or ``pathlib.Path``), or an
                already-loaded mapping (plain dict or OmegaConf config).
        """
        super().__init__()

        if cfg is None:
            config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, (str, Path)):
            # OmegaConf.load (used by load_dataset_config) accepts both
            # str and pathlib.Path, so treat both as file paths.
            config = load_dataset_config(cfg)
        else:
            config = cfg

        # Mode defaults to "sft" when the config does not specify one.
        self.mode = config.get("mode", "sft")
        # None when the config has no section named after the mode.
        self.config = config.get(self.mode)

    def build_datasets(self) -> DatasetDict:
        """Build the datasets for the configured mode.

        ``"sft"`` dispatches to ``_build_sft_datasets`` and ``"grpo"`` to
        ``_build_rl_datasets``.

        Returns:
            Whatever the mode-specific builder returns.

        Raises:
            ValueError: if ``self.mode`` is not one of the supported modes.
        """
        method_builder_map = {
            "sft": self._build_sft_datasets,
            "grpo": self._build_rl_datasets,
        }

        if self.mode not in method_builder_map:
            raise ValueError(
                f"Unsupported datasets mode: {self.mode!r} "
                f"(expected one of {sorted(method_builder_map)})"
            )

        return method_builder_map[self.mode]()

    @abstractmethod
    def _build_sft_datasets(self) -> DatasetDict:
        """Build the datasets used when ``mode == "sft"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def _build_rl_datasets(self) -> DatasetDict:
        """Build the datasets used when ``mode == "grpo"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_env_cls(self) -> Type[BaseEnv]:
        """Return the environment class (a ``BaseEnv`` subclass) for these datasets."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def get_generation_manager_cls(self) -> Type:
        """Return the generation-manager class used with these datasets."""
        # NOTE(review): the generation manager's base class is not visible in
        # this module, so the return type is left as a bare ``Type``.
        raise NotImplementedError("Should be implemented by subclasses")

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of one of the subclass's config files.

        Args:
            type: Key into ``cls.DATASET_CONFIG_DICT``. The name shadows the
                builtin ``type`` but is kept for keyword-argument compatibility.

        Returns:
            ``cls.DATASET_CONFIG_DICT[type]`` resolved via ``utils.get_abs_path``.
        """
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
|
|
|
|
|
|
|
|
def load_dataset_config(cfg_path: str) -> Dict:
    """Load a config file and return the first entry of its ``datasets`` section.

    The file must have a top-level ``datasets`` mapping. Only the entry under
    its first key (as ordered by ``keys()``) is returned; any further entries
    are ignored.

    Args:
        cfg_path: Path to the config file, in any form ``OmegaConf.load``
            accepts (e.g. ``str`` or ``pathlib.Path``).

    Returns:
        The OmegaConf node stored under the first key of ``datasets``.

    Raises:
        ValueError: if the file has no ``datasets`` section or it is empty.
    """
    datasets_cfg = OmegaConf.load(cfg_path).get("datasets")
    # Fail with a readable message instead of an IndexError/AttributeError
    # when the section is absent (None) or has no entries.
    if not datasets_cfg:
        raise ValueError(f"No 'datasets' entries found in config file: {cfg_path}")

    first_key = next(iter(datasets_cfg.keys()))
    return datasets_cfg[first_key]
|
|
|