"""Data loading for BabyLM training.

Tokenizes the cleaned 10M-word BabyLM training corpora, chunks the token
stream into fixed-length datapoints, caches the chunks to disk, and exposes a
DataLoader that yields (input, target, mask) batches for next-token prediction.
"""

import math
import os
import pickle

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

# Cleaned 10M-word BabyLM training split and the sub-corpora it contains.
TRAIN_PATH_10M = '01-data/clean_train_10M'
DATASETS = ['bnc_spoken', 'childes', 'gutenberg', 'open_subtitles', 'simple_wiki', 'switchboard']


class FullBabyLMDataset(Dataset):
    """Fixed-length chunks of the tokenized BabyLM training corpora."""

    def __init__(self, cfg, pretokenized_data=None):
        tokenizer_path = cfg["tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            local_files_only=True
        )

        # Special-token ids used for wrapping and padding (assumed to be
        # defined by the local tokenizer).
        self.model_bos = self.tokenizer.bos_token_id
        self.model_eos = self.tokenizer.eos_token_id
        self.model_pad = self.tokenizer.pad_token_id

        # If chunks were already tokenized (e.g. loaded from the pickle cache),
        # reuse them and skip the expensive tokenization below.
        if pretokenized_data is not None:
            self.data = pretokenized_data
            return

        self.data = []
        dataset_folder = TRAIN_PATH_10M

        for dataset in DATASETS:
            # Read the whole sub-corpus into a single string.
            dataset_path = os.path.join(dataset_folder, f'{dataset}.train')
            with open(dataset_path, 'r', encoding='utf-8') as f:
                all_text = ' '.join(f.readlines())
            print(f'Opened {dataset_path}')

            # Tokenize the full text in one call; [0] unwraps the single sequence.
            tokenized_dataset = self.tokenizer([all_text])['input_ids'][0]
            print(f'Tokenized {dataset_path}; {len(tokenized_dataset)} tokens total')

            # Split the token stream into fixed-length chunks (the last chunk
            # of each sub-corpus may be shorter than chunk_size).
            chunk_size = cfg["datapoint_length"]
            num_chunks = math.ceil(len(tokenized_dataset) / chunk_size)
            for curr_chunk in tqdm(range(num_chunks), desc=f"Chunking {dataset}"):
                start = curr_chunk * chunk_size
                end = (curr_chunk + 1) * chunk_size
                chunk_tokens = tokenized_dataset[start:end]
                if isinstance(chunk_tokens, torch.Tensor):
                    chunk_tokens = chunk_tokens.tolist()
                self.data.append(chunk_tokens)
            print(f"Chunked {dataset_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Wrap each chunk with BOS/EOS and return it as a 1-D LongTensor.
        return torch.as_tensor([self.model_bos] + self.data[idx] + [self.model_eos], dtype=torch.long)


def load_babylm_data(cfg):
    # The cache filename records which track was requested; note that the
    # dataset built above always reads the 10M-word split (TRAIN_PATH_10M).
    num_words = "100M" if cfg["training_type"] == "strict" else "10M"
    cache_dir = '01-data/cached_train'
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f'train_gpt2_{num_words}.pkl')

    # Reuse cached token chunks when available; otherwise tokenize from
    # scratch and write the cache for the next run.
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            token_chunks = pickle.load(f)
        full_babylm_dset = FullBabyLMDataset(cfg, pretokenized_data=token_chunks)
    else:
        tmp_dataset = FullBabyLMDataset(cfg)
        with open(filename, 'wb') as f:
            pickle.dump(tmp_dataset.data, f)
        full_babylm_dset = tmp_dataset

    collate_fn = get_collate_fn(full_babylm_dset.model_eos, full_babylm_dset.model_pad)
    dataloader = DataLoader(
        full_babylm_dset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    return dataloader


def get_collate_fn(model_eos, model_pad):
    # model_eos is currently unused but kept so the call site stays stable.
    def collate_fn(batch):
        # Pad every sequence in the batch to the length of the longest one.
        tokens = pad_sequence(batch, padding_value=model_pad, batch_first=True)
        # Shift by one position for next-token prediction.
        input_tokens = tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        # Ignore padded positions in the loss ...
        target_mask = input_tokens != model_pad
        # ... but always keep the first position: it holds BOS, which can share
        # an id with PAD (GPT-2-style tokenizers) and would otherwise be masked out.
        target_mask[:, 0] = True
        return input_tokens, target_tokens, target_mask
    return collate_fn


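# --- Usage sketch ------------------------------------------------------------
# A minimal example of driving this module end to end. The cfg keys mirror the
# ones read above ("tokenizer_dir", "datapoint_length", "training_type",
# "batch_size"); the concrete values are illustrative assumptions, not
# project-mandated defaults.
if __name__ == '__main__':
    example_cfg = {
        'tokenizer_dir': '02-tokenizer',    # assumed path to a saved tokenizer
        'datapoint_length': 512,            # tokens per chunk (assumed)
        'training_type': 'strict_small',    # anything but 'strict' -> 10M cache file
        'batch_size': 32,                   # assumed
    }
    loader = load_babylm_data(example_cfg)
    inputs, targets, mask = next(iter(loader))
    # inputs and targets are shifted views of the same padded batch, so both
    # have shape (batch_size, max_len - 1); mask is True on non-pad positions.
    print(inputs.shape, targets.shape, mask.shape)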