# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union

from datasets import enable_caching

from swift.llm import DATASET_MAPPING, register_dataset_info
from swift.utils import get_logger

logger = get_logger()


@dataclass
class DataArguments:
    """Dataset-related arguments: which datasets to load and how to process them.

    NOTE(review): this class reads ``self.seed``, ``self.use_hf``, ``self.hub_token`` and calls
    ``self.parse_to_dict``, none of which are defined here - presumably they are supplied by
    another class that is combined with this one; confirm against the callers.

    Args:
        dataset (List[str]): Training datasets, each a dataset_id, dataset_dir or dataset_path.
        val_dataset (List[str]): Validation datasets, each a dataset_id, dataset_dir or dataset_path.
            When non-empty, ``split_dataset_ratio`` is reset to 0 in ``__post_init__``.
        split_dataset_ratio (float): Ratio to split the dataset for validation if val_dataset is empty.
            Default is 0.01.
        data_seed (Optional[int]): Seed for dataset shuffling. Default is None, in which case
            ``self.seed`` is used.
        dataset_num_proc (int): Number of processes to use for data loading and preprocessing.
            Default is 1.
        load_from_cache_file (bool): Forwarded through ``get_dataset_kwargs``. Default is False.
        dataset_shuffle (bool): Default is True. Not read inside this class.
        val_dataset_shuffle (bool): Default is False. Not read inside this class.
        streaming (bool): Flag to enable streaming of datasets. Default is False. When set,
            ``split_dataset_ratio`` is reset to 0 in ``__post_init__``.
        interleave_prob (Optional[List[float]]): Forwarded through ``get_dataset_kwargs``.
            Default is None.
        stopping_strategy (Literal): 'first_exhausted' or 'all_exhausted'. Default is 'first_exhausted'.
        shuffle_buffer_size (int): Forwarded through ``get_dataset_kwargs``. Default is 1000.
        download_mode (Literal): Mode for downloading datasets. Default is 'reuse_dataset_if_exists'.
        columns (Optional[Union[dict, str]]): Used for manual column mapping of datasets. The value is
            passed through ``self.parse_to_dict`` in ``__post_init__``.
        strict (bool): Forwarded through ``get_dataset_kwargs``. Default is False.
        remove_unused_columns (bool): Forwarded through ``get_dataset_kwargs``. Default is True.
        model_name (List[str]): List containing Chinese and English names of the model.
            Default is [None, None].
        model_author (List[str]): List containing Chinese and English names of the model author.
            Default is [None, None].
        custom_dataset_info (List[str]): Paths to custom dataset_info.json files. A single string is
            also accepted and wrapped into a one-element list. Default is an empty list.
    """
    # dataset_id or dataset_dir or dataset_path
    dataset: List[str] = field(
        default_factory=list,
        metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'},
    )
    val_dataset: List[str] = field(
        default_factory=list,
        metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'},
    )
    split_dataset_ratio: float = 0.01

    data_seed: Optional[int] = None
    dataset_num_proc: int = 1
    load_from_cache_file: bool = False
    dataset_shuffle: bool = True
    val_dataset_shuffle: bool = False
    streaming: bool = False
    interleave_prob: Optional[List[float]] = None
    stopping_strategy: Literal['first_exhausted', 'all_exhausted'] = 'first_exhausted'
    shuffle_buffer_size: int = 1000

    download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists'
    columns: Optional[Union[dict, str]] = None
    strict: bool = False
    remove_unused_columns: bool = True

    # Chinese name and English name
    model_name: List[str] = field(
        default_factory=lambda: [None, None],
        metadata={'help': "e.g. ['小黄', 'Xiao Huang']"},
    )
    model_author: List[str] = field(
        default_factory=lambda: [None, None],
        metadata={'help': "e.g. ['魔搭', 'ModelScope']"},
    )

    custom_dataset_info: List[str] = field(default_factory=list)  # .json

    def _init_custom_dataset_info(self):
        """Register every custom dataset_info.json path with ``register_dataset_info``."""
        info_paths = self.custom_dataset_info
        if isinstance(info_paths, str):
            # A lone path is normalised to a list and stored back on the instance.
            info_paths = [info_paths]
            self.custom_dataset_info = info_paths
        for info_path in info_paths:
            register_dataset_info(info_path)

    def __post_init__(self):
        """Fill in derived defaults, then register any custom dataset info files."""
        if self.data_seed is None:
            self.data_seed = self.seed
        self.columns = self.parse_to_dict(self.columns)

        has_val_dataset = len(self.val_dataset) > 0
        if has_val_dataset or self.streaming:
            # No automatic train/val split when a validation set is given or data is streamed.
            self.split_dataset_ratio = 0.
            msg = 'len(args.val_dataset) > 0' if has_val_dataset else 'args.streaming is True'
            logger.info(f'Because {msg}, setting split_dataset_ratio: {self.split_dataset_ratio}')

        self._init_custom_dataset_info()

    def get_dataset_kwargs(self):
        """Return the dataset-related settings as a keyword-argument dict.

        ``data_seed`` and ``dataset_num_proc`` are exposed under the keys ``seed`` and ``num_proc``;
        every other entry keeps its attribute name.
        """
        return dict(
            seed=self.data_seed,
            num_proc=self.dataset_num_proc,
            load_from_cache_file=self.load_from_cache_file,
            streaming=self.streaming,
            interleave_prob=self.interleave_prob,
            stopping_strategy=self.stopping_strategy,
            shuffle_buffer_size=self.shuffle_buffer_size,
            use_hf=self.use_hf,
            hub_token=self.hub_token,
            download_mode=self.download_mode,
            columns=self.columns,
            strict=self.strict,
            model_name=self.model_name,
            model_author=self.model_author,
            remove_unused_columns=self.remove_unused_columns,
        )