"""Factory for creating datasets based on configuration."""

from taoTrain.config import TrainingConfig, TrainingModeEnum
from taoTrain.data.pretrain_jsonl import PretrainJSONLDataset
from taoTrain.data.sft_jsonl import SFTJSONLDataset
from taoTrain.data.rl_jsonl import RLJSONLDataset

# The HuggingFace-backed datasets are optional. If any of these imports
# fails, all three names are set to None and the "huggingface" backend is
# simply not registered in DatasetFactory.DATASETS.
try:
    from taoTrain.data.hf_pretrain import PretrainDataset
    from taoTrain.data.hf_sft import SFTDataset
    from taoTrain.data.hf_rl import RLDataset
except ImportError:
    PretrainDataset = None
    SFTDataset = None
    RLDataset = None


class DatasetFactory:
    """Factory for creating datasets based on configuration.

    Dataset classes are stored in the class-level ``DATASETS`` registry,
    keyed by ``(mode, backend)`` tuples. The registry is shared and mutable:
    ``register_dataset`` adds to (or overwrites entries in) the same dict
    that ``create_dataset`` reads.
    """

    # Registry of dataset classes by mode and backend
    DATASETS = {
        (TrainingModeEnum.PRETRAIN, "jsonl"): PretrainJSONLDataset,
        (TrainingModeEnum.SFT, "jsonl"): SFTJSONLDataset,
        (TrainingModeEnum.RL, "jsonl"): RLJSONLDataset,
    }

    # Only advertise the HuggingFace backend when its optional imports
    # succeeded at module import time.
    if PretrainDataset is not None:
        DATASETS.update({
            (TrainingModeEnum.PRETRAIN, "huggingface"): PretrainDataset,
            (TrainingModeEnum.SFT, "huggingface"): SFTDataset,
            (TrainingModeEnum.RL, "huggingface"): RLDataset,
        })

    @staticmethod
    def create_dataset(
        config: TrainingConfig,
        split: str = "train",
    ):
        """
        Create dataset instance based on configuration.

        The backend is "jsonl" when ``config.dataset.local`` is truthy and
        "huggingface" otherwise; the mode is taken from ``config.mode``.

        Args:
            config: Training configuration
            split: Dataset split (train, validation, test) - primarily for
                HuggingFace datasets. Ignored for the JSONL backend.

        Returns:
            Dataset instance matching the configured mode and backend

        Raises:
            ImportError: If the HuggingFace backend is selected but the
                optional HuggingFace dataset modules could not be imported
            ValueError: If configuration is invalid or unsupported
                mode/backend combination
        """
        # Determine backend: JSONL or HuggingFace
        backend = "jsonl" if config.dataset.local else "huggingface"

        # Get mode
        mode = config.mode

        # Look up dataset class
        key = (mode, backend)
        dataset_class = DatasetFactory.DATASETS.get(key)

        if dataset_class is None:
            # Blame the missing optional dependency only when the optional
            # import really failed. If the HuggingFace classes are present
            # and the key is still missing, the mode itself is unsupported
            # and the ValueError below is the accurate diagnosis.
            if backend == "huggingface" and PretrainDataset is None:
                raise ImportError(
                    "HuggingFace dataset support requires the optional 'datasets' dependency. "
                    "Install project dependencies before using dataset.local=false."
                )
            raise ValueError(
                f"Unsupported dataset configuration: mode={mode.value}, backend={backend}. "
                f"Supported: {list(DatasetFactory.DATASETS.keys())}"
            )

        # Instantiate dataset
        if backend == "jsonl":
            # JSONL datasets don't use split parameter
            return dataset_class(config)

        # HuggingFace datasets use split parameter
        return dataset_class(config, split=split)

    @staticmethod
    def register_dataset(mode: TrainingModeEnum, backend: str, dataset_class):
        """
        Register a custom dataset class.

        An existing entry for the same ``(mode, backend)`` pair is replaced.

        Args:
            mode: Training mode (e.g., TrainingModeEnum.PRETRAIN)
            backend: Backend name (e.g., "jsonl", "huggingface")
            dataset_class: Dataset class to register
        """
        DatasetFactory.DATASETS[(mode, backend)] = dataset_class

    @staticmethod
    def list_available_datasets():
        """List all available dataset configurations.

        Returns:
            Dict keyed by ``"<mode value>_<backend>"``; each value is a dict
            with the entries "mode", "backend" and "class" (the registered
            class's ``__name__``).
        """
        return {
            f"{mode.value}_{backend}": {
                "mode": mode.value,
                "backend": backend,
                "class": dataset_class.__name__,
            }
            for (mode, backend), dataset_class in DatasetFactory.DATASETS.items()
        }