| """data handling specific to SFT""" | |
| import functools | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Union | |
| from datasets import ( | |
| Dataset, | |
| DatasetDict, | |
| concatenate_datasets, | |
| load_dataset, | |
| load_from_disk, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import HFValidationError | |
| from transformers import PreTrainedTokenizerBase | |
| from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH | |
| from axolotl.datasets import TokenizedPromptDataset | |
| from axolotl.prompt_strategies import load | |
| from axolotl.prompt_tokenizers import ( | |
| AlpacaMultipleChoicePromptTokenizingStrategy, | |
| AlpacaPromptTokenizingStrategy, | |
| AlpacaReflectionPTStrategy, | |
| GPTeacherPromptTokenizingStrategy, | |
| JeopardyPromptTokenizingStrategy, | |
| OpenAssistantPromptTokenizingStrategy, | |
| SummarizeTLDRPromptTokenizingStrategy, | |
| ) | |
| from axolotl.prompters import ( | |
| AlpacaPrompter, | |
| GPTeacherPrompter, | |
| JeopardyPrompter, | |
| MultipleChoiceConcisePrompter, | |
| MultipleChoiceExplainPrompter, | |
| Prompter, | |
| ReflectAlpacaPrompter, | |
| SummarizeTLDRPrompter, | |
| UnsupportedPrompter, | |
| ) | |
| from axolotl.utils.data.pretraining import wrap_pretraining_dataset | |
| from axolotl.utils.data.utils import md5 | |
| from axolotl.utils.dict import DictDefault | |
| from axolotl.utils.distributed import is_main_process, zero_first | |
| from axolotl.utils.trainer import ( | |
| calculate_total_num_steps, | |
| process_datasets_for_packing, | |
| ) | |
| LOG = logging.getLogger("axolotl") | |


def prepare_dataset(cfg, tokenizer):
    """Load and tokenize the configured SFT (or streaming pretraining) datasets
    and compute the total number of training steps."""
    prompters = []
    if not cfg.pretraining_dataset:
        with zero_first(is_main_process()):
            if cfg.test_datasets:
                train_dataset, _, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
                )
                _, eval_dataset, _ = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
                )
            else:
                train_dataset, eval_dataset, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
                )
    else:
        path = cfg.pretraining_dataset
        split = "train"
        name = None
        if isinstance(cfg.pretraining_dataset, list) and isinstance(
            cfg.pretraining_dataset[0], dict
        ):
            path = cfg.pretraining_dataset[0]["path"]
            name = cfg.pretraining_dataset[0]["name"]
            if "split" in cfg.pretraining_dataset[0]:
                split = cfg.pretraining_dataset[0]["split"]
        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            tokenizer,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )

        train_dataset = wrap_pretraining_dataset(
            load_dataset(path, streaming=True, split=split, name=name),
            tokenizer,
            cfg,
            ds_wrapper_partial,
            max_tokens=cfg.sequence_len,
            batch_size=cfg.micro_batch_size,
            seed=cfg.seed or 42,
            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
        )
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None
        return train_dataset, eval_dataset, cfg.max_steps, prompters

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`."
            )

    if cfg.max_steps:
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
    return train_dataset, eval_dataset, total_num_steps, prompters
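
# Illustrative usage from a training entrypoint (hypothetical `cfg` and
# `tokenizer` objects produced by axolotl's config and model loading):
#
#     train_ds, eval_ds, total_num_steps, prompters = prepare_dataset(cfg, tokenizer)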


def load_tokenized_prepared_datasets(
    tokenizer,
    cfg,
    default_dataset_prepared_path,
    split="train",
) -> Tuple[DatasetDict, List[Prompter]]:
    """Load the configured raw datasets, tokenize them, and cache the merged,
    prepared result on disk (and optionally on the hub)."""
    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
    tokenizer_name = cfg.tokenizer_config
    ds_hash = str(
        md5(
            (
                str(cfg.sequence_len)
                + "@"
                + str(cfg.sample_packing)
                + "@"
                + str(cfg.eval_sample_packing)
                + "@"
                + str(cfg.group_by_length)
                + "@"
                + "|".join(
                    sorted(
                        [
                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                            for d in cfg_datasets
                        ]
                    )
                )
                + "|"
                + tokenizer_name
            )
        )
    )
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(default_dataset_prepared_path) / ds_hash
    )
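    # The md5 hash keys the on-disk cache: changing sequence_len, the packing
    # flags, the dataset configs, or the tokenizer yields a new prepared
    # directory of the form <dataset_prepared_path>/<ds_hash>/.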
    dataset = None
    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
            dataset = load_dataset(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                token=use_auth_token,
            )
            dataset = dataset[split]
    except Exception:  # pylint: disable=broad-except # nosec
        pass

    # pylint: disable=duplicate-code
    if dataset:
        # already loaded the prepared dataset from the hub above; nothing to do
        ...
    elif (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
            )

        if cfg.seed:
            seed = cfg.seed
        else:
            LOG.info("No seed provided, using default seed of 42")
            seed = 42

        datasets = []

        def for_d_in_datasets(dataset_configs):
            for dataset in dataset_configs:
                if dataset.name and isinstance(dataset.name, list):
                    for name in dataset.name:
                        yield DictDefault({**dataset, "name": name})
                else:
                    yield dataset
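
        # For example (hypothetical config), one entry with a list of names
        # expands into one config per name:
        #
        #     datasets:
        #       - path: org/multi-config-ds
        #         name: [en, fr]
        #         type: alpaca
        #
        # yields two configs, one with name "en" and one with name "fr".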
        # pylint: disable=invalid-name
        for config_dataset in for_d_in_datasets(cfg_datasets):
            ds: Optional[Union[Dataset, DatasetDict]] = None
            ds_from_hub = False
            try:
                # probe the hub with a cheap streaming load before committing
                # to a full download; the result is discarded
                load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=True,
                    token=use_auth_token,
                )
                ds_from_hub = True
            except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
                pass

            ds_from_cloud = False
            storage_options = {}
            remote_file_system = None
            if config_dataset.path.startswith("s3://"):
                try:
                    import aiobotocore.session  # type: ignore
                    import s3fs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "s3:// paths require aiobotocore and s3fs to be installed"
                    ) from exc

                # Takes credentials from ~/.aws/credentials for default profile
                s3_session = aiobotocore.session.AioSession(profile="default")
                storage_options = {"session": s3_session}
                remote_file_system = s3fs.S3FileSystem(**storage_options)
            elif config_dataset.path.startswith(
                "gs://"
            ) or config_dataset.path.startswith("gcs://"):
                try:
                    import gcsfs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "gs:// or gcs:// paths require gcsfs to be installed"
                    ) from exc

                # gcsfs will use default credentials from the environment else anon
                # https://gcsfs.readthedocs.io/en/latest/#credentials
                storage_options = {"token": None}
                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
            # TODO: Figure out how to get auth creds passed
            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
            #     try:
            #         import adlfs
            #     except ImportError as exc:
            #         raise ImportError(
            #             "adl:// or abfs:// paths require adlfs to be installed"
            #         ) from exc
            #     # Gen 1
            #     storage_options = {
            #         "tenant_id": TENANT_ID,
            #         "client_id": CLIENT_ID,
            #         "client_secret": CLIENT_SECRET,
            #     }
            #     # Gen 2
            #     storage_options = {
            #         "account_name": ACCOUNT_NAME,
            #         "account_key": ACCOUNT_KEY,
            #     }
            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
            try:
                if remote_file_system and remote_file_system.exists(
                    config_dataset.path
                ):
                    ds_from_cloud = True
            except (FileNotFoundError, ConnectionError):
                pass
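
            # Illustrative (hypothetical) remote dataset entry that reaches
            # the s3 branch above:
            #
            #     datasets:
            #       - path: s3://my-bucket/data/train.parquet
            #         type: alpaca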
            # prefer local dataset, even if hub exists
            local_path = Path(config_dataset.path)
            if local_path.exists():
                if local_path.is_dir():
                    if config_dataset.data_files:
                        ds_type = get_ds_type(config_dataset)
                        ds = load_dataset(
                            ds_type,
                            name=config_dataset.name,
                            data_files=config_dataset.data_files,
                            streaming=False,
                            split=None,
                        )
                    else:
                        ds = load_from_disk(config_dataset.path)
                elif local_path.is_file():
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                    )
                else:
                    raise ValueError(
                        "unhandled dataset load: local path exists, but is neither a directory nor a file"
                    )
            elif ds_from_hub:
                ds = load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=False,
                    data_files=config_dataset.data_files,
                    token=use_auth_token,
                )
            elif ds_from_cloud and remote_file_system:
                if remote_file_system.isdir(config_dataset.path):
                    ds = load_from_disk(
                        config_dataset.path,
                        storage_options=storage_options,
                    )
                elif remote_file_system.isfile(config_dataset.path):
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                        storage_options=storage_options,
                    )
            elif config_dataset.path.startswith("https://"):
                ds_type = get_ds_type(config_dataset)
                ds = load_dataset(
                    ds_type,
                    name=config_dataset.name,
                    data_files=config_dataset.path,
                    streaming=False,
                    split=None,
                    storage_options=storage_options,
                )
            else:
                if isinstance(config_dataset.data_files, str):
                    fp = hf_hub_download(
                        repo_id=config_dataset.path,
                        repo_type="dataset",
                        filename=config_dataset.data_files,
                    )
                elif isinstance(config_dataset.data_files, list):
                    fp = []
                    for file in config_dataset.data_files:
                        fp.append(
                            hf_hub_download(
                                repo_id=config_dataset.path,
                                repo_type="dataset",
                                filename=file,
                            )
                        )
                else:
                    raise ValueError(
                        "data_files must be either a string or list of strings"
                    )
                ds = load_dataset(
                    "json",
                    name=config_dataset.name,
                    data_files=fp,
                    streaming=False,
                    split=None,
                )
            if not ds:
                raise ValueError("unhandled dataset load")

            d_base_type = d_prompt_style = None
            d_type = config_dataset.type
            if isinstance(d_type, str):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

            if isinstance(ds, DatasetDict):
                if config_dataset.split and config_dataset.split in ds:
                    ds = ds[config_dataset.split]
                elif split in ds:
                    ds = ds[split]
                else:
                    raise ValueError(
                        f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: ...'"
                    )

            # support for using a subset of the data
            if config_dataset.shards:
                shards_idx = config_dataset.get("shards_idx", 0)
                ds = ds.shuffle(seed=seed).shard(
                    num_shards=config_dataset.shards, index=shards_idx
                )

            dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                config_dataset=config_dataset,
                tokenizer=tokenizer,
                cfg=cfg,
                dataset=ds,
                d_base_type=d_base_type,
                d_prompt_style=d_prompt_style,
            )
            datasets.append(dataset_wrapper)
            prompters.append(dataset_prompter)

        LOG.info("merging datasets")
        dataset = concatenate_datasets(datasets)

        if len(datasets) > 1:
            if cfg.shuffle_merged_datasets:
                LOG.debug("shuffle merged datasets")
                dataset = dataset.shuffle(seed=seed)
            else:
                LOG.debug("NOT shuffling merged datasets")

        dataset, _ = process_datasets_for_packing(cfg, dataset, None)

        if cfg.local_rank == 0:
            LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
            dataset.save_to_disk(str(prepared_ds_path))
            if cfg.push_dataset_to_hub:
                LOG.info(
                    f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                )
                dataset.push_to_hub(
                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                )

    return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
    """
    Get the dataset type from the path if it's not specified
    """
    ds_type = "json"
    if config_dataset.ds_type:
        ds_type = config_dataset.ds_type
    elif ".parquet" in config_dataset.path:
        ds_type = "parquet"
    elif ".arrow" in config_dataset.path:
        ds_type = "arrow"
    elif ".csv" in config_dataset.path:
        ds_type = "csv"
    elif ".txt" in config_dataset.path:
        ds_type = "text"
    return ds_type
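
# Example calls (hypothetical paths); extension sniffing falls back to "json":
#
#     get_ds_type(DictDefault({"path": "data/train.parquet"}))  # -> "parquet"
#     get_ds_type(DictDefault({"path": "data/train.csv"}))      # -> "csv"
#     get_ds_type(DictDefault({"path": "data/train.jsonl"}))    # -> "json"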


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
    split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
    """Load tokenized prepared datasets, optionally shard them, and split into
    train/eval sets."""
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path, split=split
    )

    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
        LOG.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )

    if split == "train" and cfg.val_set_size:
        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
        to_hash_train = (
            dataset._fingerprint  # pylint: disable=protected-access
            + "|"
            + str(cfg.val_set_size)
            + "|"
            + "train"
            + "|"
            + str(cfg.seed or 42)
        )
        to_hash_test = (
            dataset._fingerprint  # pylint: disable=protected-access
            + "|"
            + str(cfg.val_set_size)
            + "|"
            + "test"
            + "|"
            + str(cfg.seed or 42)
        )
        train_fingerprint = md5(to_hash_train)
        test_fingerprint = md5(to_hash_test)

        dataset = dataset.train_test_split(
            test_size=cfg.val_set_size,
            shuffle=False,
            seed=cfg.seed or 42,
            train_new_fingerprint=train_fingerprint,
            test_new_fingerprint=test_fingerprint,
        )

        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
    elif split == "test":
        train_dataset = None
        eval_dataset = dataset
    else:
        train_dataset = dataset
        eval_dataset = None

    return train_dataset, eval_dataset, prompters
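
# Note: `val_set_size` follows `datasets.Dataset.train_test_split` semantics,
# so it may be a fraction or an absolute row count (hypothetical values):
#
#     val_set_size: 0.05   # hold out 5% of the train split for eval
#     val_set_size: 2000   # or hold out exactly 2000 rows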


def get_dataset_wrapper(
    config_dataset,
    tokenizer,
    cfg,
    d_base_type,
    dataset,
    d_prompt_style=None,
):
    """Pick the prompter and tokenizing strategy a dataset config calls for and
    wrap the dataset with them."""
    dataset_wrapper = None
    dataset_prompter = None

    ds_kwargs = {
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    if (
        isinstance(dataset, Dataset)
        and "input_ids" in dataset.features
        and "attention_mask" in dataset.features
        and "labels" in dataset.features
    ):
        # dataset is already tokenized, just drop it straight in
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = dataset
    elif isinstance(config_dataset.type, DictDefault):
        ds_strategy = load(
            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
        )
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif d_base_type == "alpaca":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "explainchoice":
        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "concisechoice":
        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "summarizetldr":
        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "jeopardy":
        dataset_prompter = JeopardyPrompter(d_prompt_style)
        ds_strategy = JeopardyPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "oasst":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = OpenAssistantPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "gpteacher":
        dataset_prompter = GPTeacherPrompter(d_prompt_style)
        ds_strategy = GPTeacherPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "reflection":
        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaReflectionPTStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    else:
        suffix = ""
        if ":load_" in config_dataset.type:
            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
        LOG.error(
            f"unhandled prompt tokenization strategy: {config_dataset.type}.{suffix}"
        )
        raise ValueError(
            f"unhandled prompt tokenization strategy: {config_dataset.type}.{suffix}"
        )
    return dataset_wrapper, dataset_prompter
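
# Illustrative (hypothetical) dataset entry that exercises the "alpaca" branch
# above; "alpaca:chat" is parsed by load_tokenized_prepared_datasets into
# d_base_type="alpaca" and d_prompt_style="chat":
#
#     datasets:
#       - path: tatsu-lab/alpaca
#         type: alpaca:chat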