| from functools import partial | |
| from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple | |
| from .processors.feedback import preprocess_feedback_dataset | |
| from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example | |
| from .processors.pretrain import preprocess_pretrain_dataset | |
| from .processors.supervised import ( | |
| preprocess_packed_supervised_dataset, | |
| preprocess_supervised_dataset, | |
| print_supervised_dataset_example, | |
| ) | |
| from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example | |
| if TYPE_CHECKING: | |
| from transformers import ProcessorMixin, Seq2SeqTrainingArguments | |
| from transformers.tokenization_utils import PreTrainedTokenizer | |
| from ..hparams import DataArguments | |
| from .template import Template | |
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    """Pick the dataset preprocessing callable and its matching example printer for a stage.

    Both returned callables are ``functools.partial`` objects with their configuration
    already bound (always the tokenizer; plus, depending on the branch, the template,
    the processor and ``data_args``).

    Dispatch:
        - ``"pt"``: pretraining preprocessor + unsupervised printer.
        - ``"sft"`` when ``training_args.predict_with_generate`` is falsy: packed
          supervised preprocessor if ``data_args.packing`` is truthy, otherwise the plain
          supervised preprocessor; paired with the supervised printer.
        - ``"rm"``: pairwise preprocessor + pairwise printer.
        - ``"kto"``: feedback preprocessor + supervised printer.
        - anything else (including ``"sft"`` with ``predict_with_generate``): unsupervised
          preprocessor + unsupervised printer.

    Args:
        data_args: Data arguments bound into every preprocessor.
        training_args: Only ``predict_with_generate`` is read, and only for ``"sft"``.
        stage: Training stage selecting the branch.
        template: Bound into every preprocessor except the pretraining one.
        tokenizer: Bound into every preprocessor and every printer.
        processor: Bound into the plain-supervised, pairwise, feedback and unsupervised
            preprocessors (not the pretraining or packed-supervised ones).

    Returns:
        A ``(preprocess_func, print_function)`` tuple.
    """

    def bind_templated(target: Callable, **extra) -> Callable:
        # Keyword order is template, tokenizer, any extras (processor), then data_args.
        return partial(target, template=template, tokenizer=tokenizer, **extra, data_args=data_args)

    if stage == "pt":
        # Pretraining binds only the tokenizer and data args: no template, no processor.
        return (
            partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args),
            partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
        )

    # `training_args` is consulted only for the SFT stage (short-circuit keeps it lazy).
    if stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # The packed supervised path is not handed the processor.
            preprocessor = bind_templated(preprocess_packed_supervised_dataset)
        else:
            preprocessor = bind_templated(preprocess_supervised_dataset, processor=processor)
        return preprocessor, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    if stage == "rm":
        return (
            bind_templated(preprocess_pairwise_dataset, processor=processor),
            partial(print_pairwise_dataset_example, tokenizer=tokenizer),
        )

    if stage == "kto":
        # KTO examples are displayed with the supervised printer.
        return (
            bind_templated(preprocess_feedback_dataset, processor=processor),
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    # Fallback: unsupervised preprocessing, which also serves SFT in generation mode.
    return (
        bind_templated(preprocess_unsupervised_dataset, processor=processor),
        partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
    )