# Data loading for BabyLM training: tokenize the corpus files, chunk them into
# fixed-length datapoints, and serve them through a DataLoader.
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoTokenizer

import math
import os
from tqdm import tqdm
import pickle

TRAIN_PATH_10M = '01-data/clean_train_10M'
DATASETS = ['bnc_spoken', 'childes', 'gutenberg', 'open_subtitles', 'simple_wiki', 'switchboard']


class FullBabyLMDataset(Dataset):
    def __init__(self, cfg, pretokenized_data=None):
        tokenizer_path = cfg["tokenizer_dir"]

        # Load the tokenizer from a local directory only (no hub download).
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            local_files_only=True
        )

        # Special-token ids used when building datapoints and collating batches.
        self.model_bos = self.tokenizer.bos_token_id
        self.model_eos = self.tokenizer.eos_token_id
        self.model_pad = self.tokenizer.pad_token_id

        # If cached token chunks were supplied, skip tokenization entirely.
        if pretokenized_data is not None:
            self.data = pretokenized_data
            return

        # Otherwise, tokenize and chunk each training corpus from scratch.
        self.data = []
        dataset_folder = TRAIN_PATH_10M

        for dataset in DATASETS:
            dataset_path = os.path.join(dataset_folder, f'{dataset}.train')
            with open(dataset_path, 'r', encoding='utf-8') as f:
                all_text = ' '.join(f.readlines())
            print(f'Opened {dataset_path}')

            # Tokenize the whole corpus file in a single pass.
            tokenized_dataset = self.tokenizer([all_text])['input_ids'][0]
            print(f'Tokenized {dataset_path}; {len(tokenized_dataset)} tokens total')

            # Split the token stream into chunks of cfg["datapoint_length"] tokens;
            # the final chunk of each corpus may be shorter.
            chunk_size = cfg["datapoint_length"]
            num_chunks = math.ceil(len(tokenized_dataset) / chunk_size)
            for curr_chunk in tqdm(range(num_chunks), desc=f"Chunking {dataset}"):
                start = curr_chunk * chunk_size
                end = (curr_chunk + 1) * chunk_size
                chunk_tokens = tokenized_dataset[start:end]
                if isinstance(chunk_tokens, torch.Tensor):
                    chunk_tokens = chunk_tokens.tolist()
                self.data.append(chunk_tokens)
            print(f"Chunked {dataset_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Wrap each chunk with BOS and EOS so sequence boundaries are explicit.
        return torch.as_tensor([self.model_bos] + self.data[idx] + [self.model_eos], dtype=torch.long)


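# A quick way to inspect individual datapoints, shown as a comment-only sketch
# (the tokenizer path is a placeholder assumption; cfg only needs the keys read
# in __init__ when no pretokenized data is passed):
#
#     cfg = {"tokenizer_dir": "02-tokenizers/gpt2", "datapoint_length": 128}
#     dset = FullBabyLMDataset(cfg)
#     print(len(dset))       # number of chunks across all corpora
#     print(dset[0][:10])    # first ids of datapoint 0, starting with the BOS id

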
def load_babylm_data(cfg):
    # num_words only affects the cache filename; FullBabyLMDataset above always
    # reads from TRAIN_PATH_10M when it has to tokenize from scratch.
    num_words = "100M" if cfg["training_type"] == "strict" else "10M"
    cache_dir = '01-data/cached_train'
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f'train_gpt2_{num_words}.pkl')

    # Reuse cached token chunks if present; otherwise tokenize from scratch and
    # write the cache for subsequent runs.
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            token_chunks = pickle.load(f)
        full_babylm_dset = FullBabyLMDataset(cfg, pretokenized_data=token_chunks)
    else:
        tmp_dataset = FullBabyLMDataset(cfg)
        with open(filename, 'wb') as f:
            pickle.dump(tmp_dataset.data, f)
        full_babylm_dset = tmp_dataset

    collate_fn = get_collate_fn(full_babylm_dset.model_eos, full_babylm_dset.model_pad)
    dataloader = DataLoader(
        full_babylm_dset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    return dataloader


def get_collate_fn(model_eos, model_pad):  # model_eos is currently unused
    def collate_fn(batch):
        # Pad variable-length datapoints to the longest sequence in the batch,
        # then shift by one position to form next-token prediction targets.
        tokens = pad_sequence(batch, padding_value=model_pad, batch_first=True)
        input_tokens = tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        target_mask = input_tokens != model_pad
        # The first input position is always BOS; keep it even if the BOS id
        # coincides with the pad id.
        target_mask[:, 0] = True
        return input_tokens, target_tokens, target_mask
    return collate_fn
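

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline): build a config,
# fetch one batch, and inspect the tensor shapes. The cfg values below are
# placeholder assumptions (in particular, "tokenizer_dir" must point to a real
# local tokenizer directory); replace them with the actual experiment config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = {
        "tokenizer_dir": "02-tokenizers/gpt2",  # placeholder path (assumption)
        "training_type": "strict-small",        # anything but "strict" selects the 10M cache name
        "datapoint_length": 128,                # tokens per chunk before BOS/EOS are added
        "batch_size": 8,
    }
    dataloader = load_babylm_data(cfg)
    input_tokens, target_tokens, target_mask = next(iter(dataloader))
    # Each tensor is (batch_size, L) with L at most datapoint_length + 1; shorter
    # final chunks are padded up to the longest sequence in the batch.
    print(input_tokens.shape, target_tokens.shape, target_mask.shape)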