| import logging | |
| from typing import Callable, Literal, Optional, Union | |
| from datasets import Dataset, Value | |
| from transformers import AutoTokenizer | |
| from ..trainer.utils import ConstantLengthDataset | |
# Feature schemas that `get_formatting_func_from_dataset` recognises for automatic formatting.
FORMAT_MAPPING = {
    # A single column holding a list of messages, each with string "content" and "role" fields.
    "chatml": [
        {
            "content": Value(dtype="string", id=None),
            "role": Value(dtype="string", id=None),
        }
    ],
    # The complete feature set of an instruction dataset: two string columns, "completion" and "prompt".
    # It is compared with `==` against the whole `dataset.features`, so extra columns do not match.
    "instruction": {
        "completion": Value(dtype="string", id=None),
        "prompt": Value(dtype="string", id=None),
    },
}
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
    r"""
    Return a callable that renders a conversational dataset to text with the tokenizer's chat template.

    The returned function calls ``tokenizer.apply_chat_template(..., tokenize=False)`` on the
    conversation(s) stored under ``messages_field``.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose ``apply_chat_template`` is used.
        messages_field (str): Column holding the conversations, ``"messages"`` or ``"conversations"``.

    Returns:
        Callable: ``format_dataset(examples)``. If the first entry of ``examples[messages_field]``
        is a list (a batch of conversations), it returns a list with one rendered template per
        conversation. Otherwise the column value is treated as a single conversation and its
        rendered template is returned directly.

    Note:
        ``format_dataset`` inspects ``examples[messages_field][0]``, so an empty column value
        raises ``IndexError``.
    """

    def format_dataset(examples):
        conversations = examples[messages_field]
        # A batched call gives a list of conversations (each itself a list);
        # a single-example call gives one conversation (a list of message dicts).
        if isinstance(conversations[0], list):
            return [tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in conversations]
        return tokenizer.apply_chat_template(conversations, tokenize=False)

    return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Return a callable that renders an instruction ("prompt"/"completion") dataset with a chat template.

    Each prompt/completion pair is turned into a two-message conversation (the prompt as the
    ``"user"`` turn, the completion as the ``"assistant"`` turn). That conversation is passed to
    ``tokenizer.apply_chat_template(..., tokenize=False)``.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose ``apply_chat_template`` is used.

    Returns:
        Callable: ``format_dataset(examples)``. When ``examples["prompt"]`` is a list (batched
        call), it returns a list with one rendered template per pair and raises ``ValueError``
        if the ``"prompt"`` and ``"completion"`` lists differ in length. Otherwise it returns the
        single rendered template.
    """

    def to_conversation(prompt, completion):
        # Shape one pair as the list of role/content messages a chat template consumes.
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]

    def format_dataset(examples):
        prompts = examples["prompt"]
        completions = examples["completion"]
        if isinstance(prompts, list):
            # Batched call: the two columns must be parallel lists. Check explicitly so a
            # mismatch fails with a clear error instead of an IndexError or silently dropped rows.
            if len(prompts) != len(completions):
                raise ValueError(
                    f"'prompt' and 'completion' batches differ in length: {len(prompts)} != {len(completions)}"
                )
            return [
                tokenizer.apply_chat_template(to_conversation(prompt, completion), tokenize=False)
                for prompt, completion in zip(prompts, completions)
            ]
        return tokenizer.apply_chat_template(to_conversation(prompts, completions), tokenize=False)

    return format_dataset
def get_formatting_func_from_dataset(
    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML`: a `"messages"` or `"conversations"` column whose feature is [{"content": str, "role": str}]
      (`"messages"` is checked first)
    - `instruction`: features consisting of exactly the two string columns `"prompt"` and `"completion"`

    Args:
        dataset (Union[Dataset, ConstantLengthDataset]): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported, else None.
        Inputs that are not a `Dataset` (e.g. a `ConstantLengthDataset`) always yield None.
    """
    if not isinstance(dataset, Dataset):
        return None

    features = dataset.features

    # ChatML: the conversation column may be named either way; "messages" takes precedence.
    for messages_field in ("messages", "conversations"):
        if messages_field in features and features[messages_field] == FORMAT_MAPPING["chatml"]:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, messages_field)

    # Instruction: compared against the whole feature set, so any extra column disables auto-formatting.
    if features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None