StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""Factory for creating datasets based on configuration."""
from taoTrain.config import TrainingConfig, TrainingModeEnum
from taoTrain.data.pretrain_jsonl import PretrainJSONLDataset
from taoTrain.data.sft_jsonl import SFTJSONLDataset
from taoTrain.data.rl_jsonl import RLJSONLDataset
try:
from taoTrain.data.hf_pretrain import PretrainDataset
from taoTrain.data.hf_sft import SFTDataset
from taoTrain.data.hf_rl import RLDataset
except ImportError:
PretrainDataset = None
SFTDataset = None
RLDataset = None
class DatasetFactory:
    """Factory that resolves a (training mode, backend) pair to a dataset class.

    The registry lives on the class itself, so entries added through
    ``register_dataset`` are shared by every caller in the process.
    """

    # Registry of dataset classes keyed by (mode, backend).
    DATASETS = {
        (TrainingModeEnum.PRETRAIN, "jsonl"): PretrainJSONLDataset,
        (TrainingModeEnum.SFT, "jsonl"): SFTJSONLDataset,
        (TrainingModeEnum.RL, "jsonl"): RLJSONLDataset,
    }

    # The HuggingFace-backed classes come from a guarded import at module
    # level; all three names are set to None together when that import
    # fails, so checking one of them is sufficient.
    if PretrainDataset is not None:
        DATASETS.update({
            (TrainingModeEnum.PRETRAIN, "huggingface"): PretrainDataset,
            (TrainingModeEnum.SFT, "huggingface"): SFTDataset,
            (TrainingModeEnum.RL, "huggingface"): RLDataset,
        })

    @staticmethod
    def create_dataset(
        config: TrainingConfig,
        split: str = "train",
    ):
        """
        Create a dataset instance based on configuration.

        The backend is "jsonl" when ``config.dataset.local`` is truthy and
        "huggingface" otherwise; the mode is taken from ``config.mode``.

        Args:
            config: Training configuration
            split: Dataset split (train, validation, test). Only forwarded
                to non-JSONL dataset classes.

        Returns:
            Dataset instance matching the configured mode and backend

        Raises:
            ImportError: If the HuggingFace backend is requested but the
                optional HuggingFace dataset classes could not be imported
            ValueError: If no dataset class is registered for the
                mode/backend combination
        """
        backend = "jsonl" if config.dataset.local else "huggingface"
        mode = config.mode
        key = (mode, backend)

        if key not in DatasetFactory.DATASETS:
            # Only blame the optional dependency when its import actually
            # failed. If the HuggingFace classes loaded fine, the key is
            # missing because the mode itself is unsupported, and reporting
            # a missing package would send the user down the wrong path.
            if backend == "huggingface" and PretrainDataset is None:
                raise ImportError(
                    "HuggingFace dataset support requires the optional 'datasets' dependency. "
                    "Install project dependencies before using dataset.local=false."
                )
            # Fall back to the raw mode so a non-enum value still yields a
            # readable ValueError instead of an AttributeError.
            mode_label = getattr(mode, "value", mode)
            raise ValueError(
                f"Unsupported dataset configuration: mode={mode_label}, backend={backend}. "
                f"Supported: {list(DatasetFactory.DATASETS.keys())}"
            )

        dataset_class = DatasetFactory.DATASETS[key]

        # JSONL datasets are constructed from the config alone; every other
        # backend also receives the requested split.
        if backend == "jsonl":
            return dataset_class(config)
        return dataset_class(config, split=split)

    @staticmethod
    def register_dataset(mode: TrainingModeEnum, backend: str, dataset_class) -> None:
        """
        Register a custom dataset class.

        An existing entry for the same (mode, backend) pair is overwritten.

        Args:
            mode: Training mode (e.g., TrainingModeEnum.PRETRAIN)
            backend: Backend name (e.g., "jsonl", "huggingface")
            dataset_class: Dataset class to register
        """
        DatasetFactory.DATASETS[(mode, backend)] = dataset_class

    @staticmethod
    def list_available_datasets() -> dict:
        """
        List all available dataset configurations.

        Returns:
            Dict keyed by "<mode>_<backend>", each value holding the
            "mode", "backend" and "class" (class name) of one registry entry
        """
        configs = {}
        for (mode, backend), dataset_class in DatasetFactory.DATASETS.items():
            # Tolerate modes registered as plain values rather than enum members.
            mode_label = getattr(mode, "value", mode)
            configs[f"{mode_label}_{backend}"] = {
                "mode": mode_label,
                "backend": backend,
                "class": dataset_class.__name__,
            }
        return configs