| """ |
| Alpaca Clean dataset with Llama3-Instruct prompt formatting |
| """ |
|
|
from functools import partial
from os.path import join
from typing import Any, Optional

from datasets import load_metric, load_dataset

from .utils import (
    get_lm_loader, get_seq2seq_loader,
    get_tokenizer_from_config,
    download_scrolls_metric as download_metric
)
from .utils.packing import ConcatDataset


SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request."


def encode_response(response: str, tokenizer) -> list[int]:
    """Tokenize an assistant response and append end-of-sequence tokens."""
    tokens = tokenizer.encode(response.strip(), add_special_tokens=False)
    tokens.append(tokenizer.eos_token_id)
    try:
        # Llama 3 tokenizers also define a separate <|end_of_text|> token
        tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
    except KeyError:
        pass  # tokenizer has no <|end_of_text|>; EOS alone terminates
    return tokens
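# Illustrative behavior (token ids are schematic, not real values): with a
# Llama-3 tokenizer, encode_response("Hi!", tokenizer) yields the ids for
# "Hi!" followed by eos_token_id and the id of "<|end_of_text|>"; for
# tokenizers without that added token, only eos_token_id is appended.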
|
|
|
|
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
              preprocess_config: dict, **loader_kwargs: Any):
    """Load Alpaca Clean, apply chat templating, and return train/val/test dataloaders."""
    cache_dir = dataset_config['cache_dir']
    input_len = dataset_config['chunk_size']
    concat_data = dataset_config['concat_data']
    load_from_cache_file = False  # re-run the tokenization .map() each call

    # Mistral chat templates do not accept a system role, so drop the
    # system prompt for those models ('istral' matches Mistral / mistral)
    if 'istral' in pretrained_model_config['pretrained_model_name_or_path']:
        system_prompt = ''
    else:
        system_prompt = SYSTEM_PROMPT
|
|
    tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
    tokenizer_name = tokenizer_name.split('/')[-1]
    save_path = join(cache_dir, f'{name}_{tokenizer_name}')  # cache path for processed data (unused below)
| |
| |
    tokenizer = get_tokenizer_from_config(pretrained_model_config)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')

    # Left-pad so that, in batched generation, real tokens sit at the end of
    # each row, adjacent to the newly generated ones
    tokenizer.padding_side = 'left'
|
|
| |
    # Drop keys that datasets.load_dataset() does not accept
    ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name']
    dataset_kwargs = {k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
    # Hold out the first and last 100 samples; validation and test share this
    # held-out split (the test split simply drops labels for generation)
    train_set = load_dataset(**dataset_kwargs, split='train[100:-100]')
    val_set = load_dataset(**dataset_kwargs, split='train[:100]+train[-100:]')
    test_set = load_dataset(**dataset_kwargs, split='train[:100]+train[-100:]')
|
|
| |
    # Template + tokenize each split; the test split excludes labels from inputs
    train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
                                      include_label=True, system_prompt=system_prompt),
                              remove_columns=list(train_set.features),
                              load_from_cache_file=load_from_cache_file)
    val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
                                  include_label=True, system_prompt=system_prompt),
                          remove_columns=list(val_set.features),
                          load_from_cache_file=load_from_cache_file)
    test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
                                    include_label=False, system_prompt=system_prompt),
                            remove_columns=list(test_set.features),
                            load_from_cache_file=load_from_cache_file)
|
|
| |
    if concat_data:
        # Pack variable-length samples into fixed-size chunks of `input_len` tokens
        train_set = ConcatDataset(train_set, chunk_size=input_len)
        val_set = ConcatDataset(val_set, chunk_size=input_len)
| |
| |
    dataloaders = {
        'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
        'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
        'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
    }
    # Evaluation metric, loaded from the SCROLLS gov_report metric script (ROUGE)
    metric = load_metric(download_metric(), 'gov_report')

    # Attach the tokenizer and metric so downstream evaluation can reach them
    for k in dataloaders:
        dataloaders[k].dataset.tokenizer = tokenizer
        dataloaders[k].dataset.metric = metric
    return dataloaders
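# Minimal usage sketch. Keys in `dataset_config` beyond the ones read above are
# forwarded to datasets.load_dataset(); the values below (dataset path, chunk
# size, batch size) are illustrative assumptions, not canonical:
#
#   dataloaders = load_data(
#       name='alpaca_clean',
#       dataset_config={'path': 'yahma/alpaca-cleaned', 'cache_dir': './cache',
#                       'chunk_size': 1024, 'concat_data': True},
#       pretrained_model_config={
#           'pretrained_model_name_or_path': 'meta-llama/Meta-Llama-3-8B-Instruct'},
#       preprocess_config={},
#       batch_size=1,
#   )
#   train_batch = next(iter(dataloaders['train']))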
|
|
|
|
def template_and_tokenize(sample, tokenizer, include_label: bool = True,
                          system_prompt: Optional[str] = None):
    """Apply the tokenizer's chat template to one Alpaca sample and tokenize.

    If `include_label` is True, the response is appended to the input ids and
    prompt positions are masked with -100 in `labels`; otherwise the response
    is returned only as the label, for generation-style evaluation.
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT

    prompt = sample['instruction']
    if sample['input'] != '':
        prompt += f"\n\n{sample['input']}"

    # Skip the system message entirely when the system prompt is empty (e.g., Mistral)
    messages = [
        {"role": "system", "content": system_prompt},
    ] if system_prompt != '' else []
    messages.append({"role": "user", "content": prompt})
    prompt_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
    )
    answer = encode_response(sample['output'], tokenizer)
    if include_label:
        # Train on the response only; -100 masks prompt positions from the loss
        input_ids = prompt_ids + answer
        labels = [-100] * len(prompt_ids) + answer
    else:
        # Generation-style eval: the input is the prompt alone and the encoded
        # reference response is kept as the label
        input_ids = prompt_ids
        labels = answer
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
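# Illustrative output (token ids schematic), for include_label=True:
#   {'input_ids':      [p0, ..., pK, a0, ..., aM],        # prompt + response ids
#    'attention_mask': [1, ..., 1],                       # all real tokens
#    'labels':         [-100, ..., -100, a0, ..., aM]}    # prompt masked out
# With include_label=False, input_ids holds only the prompt and labels holds
# only the encoded reference response.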
|
|
|
|