# File: data_utils.py
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# from tokenizer import ParadigmTokenizerWrapper
from transformers import AutoTokenizer

import math
import os
from tqdm import tqdm
import pickle

TRAIN_PATH_10M = '01-data/clean_train_10M'
DATASETS = ['bnc_spoken', 'childes', 'gutenberg', 'open_subtitles', 'simple_wiki', 'switchboard']

class FullBabyLMDataset(Dataset):
    def __init__(self, cfg, pretokenized_data=None):
        tokenizer_path = cfg["tokenizer_dir"]

        # Use HF loader so tokenizer_class + auto_map are honored
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            local_files_only=True
        )
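        # The tokenizer directory is assumed to contain, roughly (illustrative
        # layout, not verified against this repo):
        #   tokenizer_config.json   <- names tokenizer_class / auto_map
        #   tokenizer.json          <- the fast-tokenizer model itself
        #   <custom module>.py      <- picked up via trust_remote_code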

        # Pull specials directly from the wrapper (it *is* a PreTrainedTokenizerFast)
        self.model_bos = self.tokenizer.bos_token_id
        self.model_eos = self.tokenizer.eos_token_id
        self.model_pad = self.tokenizer.pad_token_id

        if pretokenized_data is not None:
            self.data = pretokenized_data
            return

        # Tokenize, split and reconstruct each dataset
        self.data = []
        dataset_folder = TRAIN_PATH_10M  # using the 10M setting here

        for dataset in DATASETS:
            dataset_path = os.path.join(dataset_folder, f'{dataset}.train')
            with open(dataset_path, 'r', encoding='utf-8') as f:
                all_text = f.read()  # read the whole file; newlines preserved
            print(f'Opened {dataset_path}')

            # Tokenize in BATCH mode so indexing [0] is correct
            tokenized_dataset = self.tokenizer([all_text])['input_ids'][0]
            print(f'Tokenized {dataset_path}; {len(tokenized_dataset)} tokens total')

            # Chunk into datapoints
            chunk_size = cfg["datapoint_length"]
            num_chunks = math.ceil(len(tokenized_dataset) / chunk_size)
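            # e.g. 10_000 tokens with chunk_size 512 -> ceil(19.53) = 20 chunks,
            # where only the final chunk may be shorter than chunk_size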
            for curr_chunk in tqdm(range(num_chunks), desc=f"Chunking {dataset}"):
                start = curr_chunk * chunk_size
                end = (curr_chunk + 1) * chunk_size
                chunk_tokens = tokenized_dataset[start:end]
                if isinstance(chunk_tokens, torch.Tensor):
                    chunk_tokens = chunk_tokens.tolist()
                self.data.append(chunk_tokens)
            print(f"Chunked {dataset_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Add BOS/EOS here (sequence length + 2)
        return torch.as_tensor([self.model_bos] + self.data[idx] + [self.model_eos], dtype=torch.long)
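
# Hedged illustration of __getitem__: with pretokenized_data=[[5, 6, 7]] and a
# GPT-2-style tokenizer (bos == eos == 50256), dataset[0] returns
#   tensor([50256, 5, 6, 7, 50256])
# i.e. chunk length + 2, as the comment above notes.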

## General utilities ##
def load_babylm_data(cfg):
    # NOTE: the cache name varies with training_type, but FullBabyLMDataset
    # above always reads the 10M split (TRAIN_PATH_10M)
    num_words = "100M" if cfg["training_type"] == "strict" else "10M"
    cache_dir = '01-data/cached_train'
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f'train_gpt2_{num_words}.pkl')

    # Cache ONLY the tokenized chunks, not the Dataset object
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            token_chunks = pickle.load(f)
        full_babylm_dset = FullBabyLMDataset(cfg, pretokenized_data=token_chunks)
    else:
        tmp_dataset = FullBabyLMDataset(cfg)
        with open(filename, 'wb') as f:
            pickle.dump(tmp_dataset.data, f)
        full_babylm_dset = tmp_dataset

    collate_fn = get_collate_fn(full_babylm_dset.model_pad)
    dataloader = DataLoader(
        full_babylm_dset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,         # set >0 if your env supports it
        pin_memory=False       # set True on GPUs if it helps
    )
    return dataloader
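
# Hedged usage sketch (the cfg keys are those read above; the values are
# placeholders, not defaults from this repo):
#   cfg = {"tokenizer_dir": "path/to/local/tokenizer", "training_type": "loose",
#          "datapoint_length": 512, "batch_size": 8}
#   loader = load_babylm_data(cfg)
#   input_tokens, target_tokens, target_mask = next(iter(loader))
#   # input_tokens/target_tokens: (batch, seq); target_mask: bool, same shape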

def get_collate_fn(model_pad):
    def collate_fn(batch):
        tokens = pad_sequence(batch, padding_value=model_pad, batch_first=True)
        input_tokens = tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        target_mask = input_tokens != model_pad
        # Ensure the first position is always trainable: with GPT-2-style
        # tokenizers pad can share an id with BOS, which would mask column 0
        target_mask[:, 0] = True
        return input_tokens, target_tokens, target_mask
    return collate_fn
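
if __name__ == "__main__":
    # Minimal smoke test of the collate function on synthetic ids; pad id 0 and
    # the token values are illustrative assumptions, not tokenizer constants.
    _collate = get_collate_fn(model_pad=0)
    batch = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2])]
    inp, tgt, mask = _collate(batch)
    # pad_sequence pads to length 4; the one-token shift gives shape (2, 3)
    print(inp.shape, tgt.shape, mask.shape)
    # Padded positions of the shorter sequence are excluded from the loss
    print(mask)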