# File: data_utils.py

import math
import os
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# from tokenizer import ParadigmTokenizerWrapper
from transformers import AutoTokenizer
from tqdm import tqdm

TRAIN_PATH_10M = '01-data/clean_train_10M'
DATASETS = ['bnc_spoken', 'childes', 'gutenberg', 'open_subtitles',
            'simple_wiki', 'switchboard']


class FullBabyLMDataset(Dataset):
    def __init__(self, cfg, pretokenized_data=None):
        tokenizer_path = cfg["tokenizer_dir"]
        # Use the HF loader so tokenizer_class + auto_map are honored
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        # Pull specials directly from the wrapper (it *is* a PreTrainedTokenizerFast)
        self.model_bos = self.tokenizer.bos_token_id
        self.model_eos = self.tokenizer.eos_token_id
        self.model_pad = self.tokenizer.pad_token_id

        if pretokenized_data is not None:
            self.data = pretokenized_data
            return

        # Tokenize, split, and reconstruct each dataset
        self.data = []
        dataset_folder = TRAIN_PATH_10M  # using the 10M setting here
        for dataset in DATASETS:
            dataset_path = os.path.join(dataset_folder, f'{dataset}.train')
            with open(dataset_path, 'r', encoding='utf-8') as f:
                all_text = ' '.join(f.readlines())
            print(f'Opened {dataset_path}')

            # Tokenize in BATCH mode so indexing [0] is correct
            tokenized_dataset = self.tokenizer([all_text])['input_ids'][0]
            print(f'Tokenized {dataset_path}; {len(tokenized_dataset)} tokens total')

            # Chunk into fixed-length datapoints (the last chunk may be shorter)
            chunk_size = cfg["datapoint_length"]
            num_chunks = math.ceil(len(tokenized_dataset) / chunk_size)
            for curr_chunk in tqdm(range(num_chunks), desc=f"Chunking {dataset}"):
                start = curr_chunk * chunk_size
                end = (curr_chunk + 1) * chunk_size
                chunk_tokens = tokenized_dataset[start:end]
                if isinstance(chunk_tokens, torch.Tensor):
                    chunk_tokens = chunk_tokens.tolist()
                self.data.append(chunk_tokens)
            print(f"Chunked {dataset_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Add BOS/EOS here (sequence length + 2)
        return torch.as_tensor(
            [self.model_bos] + self.data[idx] + [self.model_eos],
            dtype=torch.long,
        )


## General utilities ##

def load_babylm_data(cfg):
    num_words = "100M" if cfg["training_type"] == "strict" else "10M"
    # NOTE: tokenization above always reads TRAIN_PATH_10M, so a "100M" cache
    # name does not yet correspond to a 100M corpus.
    cache_dir = '01-data/cached_train'
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f'train_gpt2_{num_words}.pkl')

    # Cache ONLY the tokenized chunks, not the Dataset object
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            token_chunks = pickle.load(f)
        full_babylm_dset = FullBabyLMDataset(cfg, pretokenized_data=token_chunks)
    else:
        tmp_dataset = FullBabyLMDataset(cfg)
        with open(filename, 'wb') as f:
            pickle.dump(tmp_dataset.data, f)
        full_babylm_dset = tmp_dataset

    collate_fn = get_collate_fn(full_babylm_dset.model_eos, full_babylm_dset.model_pad)
    dataloader = DataLoader(
        full_babylm_dset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,    # set >0 if your env supports it
        pin_memory=False,  # set True on GPUs if it helps
    )
    return dataloader


def get_collate_fn(model_eos, model_pad):
    # model_eos is accepted for symmetry with the dataset's specials but is
    # not used in the masking below.
    def collate_fn(batch):
        # Right-pad to the longest sequence, then build shifted LM views
        tokens = pad_sequence(batch, padding_value=model_pad, batch_first=True)
        input_tokens = tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        # Train on every position whose *input* is a real token; note this
        # keeps the PAD immediately following EOS as a training target
        target_mask = input_tokens != model_pad
        # Ensure first position is always trainable (guards bos_token_id == pad_token_id)
        target_mask[:, 0] = True
        return input_tokens, target_tokens, target_mask
    return collate_fn
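

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test showing how load_babylm_data is meant to be called.
# The cfg keys mirror the ones read above; the concrete values (tokenizer
# path, chunk length, batch size, training type) are placeholder assumptions,
# not project defaults. Running this requires the 01-data corpus files and a
# saved tokenizer directory to exist on disk.
if __name__ == "__main__":
    demo_cfg = {
        "tokenizer_dir": "02-tokenizer",  # assumed path to a saved HF tokenizer
        "training_type": "strict-small",  # anything but "strict" selects the 10M cache name
        "datapoint_length": 128,          # assumed chunk length in tokens
        "batch_size": 8,
    }
    loader = load_babylm_data(demo_cfg)
    input_tokens, target_tokens, target_mask = next(iter(loader))
    # Inputs and targets are shifted views of the same padded batch,
    # so all three tensors share the shape (batch, seq_len - 1).
    print(input_tokens.shape, target_tokens.shape, target_mask.shape)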