Rename scripts/preference_dataset.py to scripts/custom_datasets/preference_dataset.py
af6cb0e
verified
| from datasets import load_dataset | |
| from torchtune.data import StackExchangedPairedTemplate | |
| from torchtune.datasets._preference import PreferenceDataset | |
| from torchtune.modules.tokenizers import Tokenizer | |
| from typing import Optional, Tuple, List | |
def extract_assistant_content(sample):
    """Collapse the 'chosen'/'rejected' message lists down to the final reply text.

    Args:
        sample (dict): Mapping containing 'prompt', 'chosen', and 'rejected'
            keys, where 'chosen' and 'rejected' are lists of message dicts
            (each with a 'content' field).

    Returns:
        dict: The same sample object, mutated in place so that 'chosen' and
        'rejected' hold the text content of their last message.
    """
    # Both preference columns get the identical treatment, so loop over them.
    for key in ("chosen", "rejected"):
        sample[key] = sample[key][-1]["content"]
    return sample
class ModifiedPreferenceDataset(PreferenceDataset):
    """PreferenceDataset variant that yields each sample as a fixed 4-tuple
    (chosen ids, chosen labels, rejected ids, rejected labels) rather than a dict."""

    # Fixed unpacking order for the prepared-sample dict.
    _OUTPUT_KEYS = (
        "chosen_input_ids",
        "chosen_labels",
        "rejected_input_ids",
        "rejected_labels",
    )

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int], List[int]]:
        prepared = self._prepare_sample(self._data[index])
        return tuple(prepared[key] for key in self._OUTPUT_KEYS)
def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
) -> ModifiedPreferenceDataset:
    """
    Preference dataset for the 'mlabonne/orpo-dpo-mix-40k' dataset.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and
            label token id lists. Default is 8192.
        data_dir (str): Directory to store the downloaded dataset. Default is "data".

    Returns:
        ModifiedPreferenceDataset: The modified preference dataset built from the
        'mlabonne/orpo-dpo-mix-40k' dataset, with each sample's chosen/rejected
        message lists reduced to the last reply's text via
        ``extract_assistant_content``.
    """
    # NOTE: the docstring previously advertised `data_dir` but the value was
    # hard-coded to "data"; it is now an actual keyword argument with the same
    # default, so existing callers are unaffected.
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        # Identity mapping — column names in the HF dataset already match the
        # names PreferenceDataset expects; kept explicit for easy adaptation.
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir=data_dir,
    )