File size: 6,352 Bytes
c39435c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
from itertools import chain
from typing import Any, Dict, List, Optional

import torch
from datasets import load_dataset,load_from_disk
from transformers import AutoTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)


def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: AutoTokenizer,
    seq_len: int = 2048,
    ctx_len: Optional[int] = None,
    return_offsets: bool = False
) -> Dict[str, List[List[int]]]:
    """
    Tokenize the input text and split into chunks of specified context length.

    Args:
        examples:
            Dictionary containing the input text under the ``'text'`` key.
        tokenizer:
            Initialized tokenizer (any callable returning ``{'input_ids': ...}``).
        seq_len:
            Total sequence length for each training sample. Default: 2048.
        ctx_len:
            Max contiguous length to preserve (will not be split). Default: `None`.
        return_offsets:
            Return cumulative offsets for concatenated inputs. Default: `False`.

    Returns:
        Dictionary containing tokenized and chunked input ids, and optionally offsets.
    """
    text = examples['text']
    input_ids = tokenizer(text)['input_ids']
    # further split each input into chunks of length `ctx_len` if provided
    if ctx_len is not None:
        input_ids = [seq[i:i+ctx_len] for seq in input_ids for i in range(0, len(seq), ctx_len)]
    # guard against an empty batch: `cumsum` over an empty tensor has no last element
    if not input_ids:
        return {'input_ids': [], 'offsets': []} if return_offsets else {'input_ids': []}
    # cumulative end position of each (sub-)sequence in the concatenated stream
    lens = torch.tensor([len(seq) for seq in input_ids]).cumsum(0)
    # drop the trailing remainder so every yielded sample is exactly `seq_len` long
    total_len = int(lens[-1]) // seq_len * seq_len

    input_ids = list(chain(*input_ids))
    # each yielded sample is of length `seq_len`
    input_ids = [input_ids[i:i+seq_len] for i in range(0, total_len, seq_len)]

    if not return_offsets:
        return {'input_ids': input_ids}

    # insert chunk boundaries (multiples of `seq_len`) into the cumulative
    # offsets, then express every offset relative to its chunk start
    offsets = torch.cat((lens, torch.arange(0, total_len, seq_len))).unique().sort()[0] % seq_len
    # split offsets at the chunk start positions (offset == 0) and close each
    # chunk's offset list with the sentinel `seq_len`
    offsets = [i.tolist() + [seq_len] for i in offsets.tensor_split(torch.where(offsets.eq(0))[0][1:])][:len(input_ids)]
    return {'input_ids': input_ids, 'offsets': offsets}


def preprocess(
    dataset: str,
    name: Optional[str] = None,
    split: str = 'train',
    seed: int = 42,
    output: str = 'data',
    tokenizer: str = 'fla-hub/gla-1.3B-100B',
    num_proc: int = 64,
    batch_size: int = 2048,
    seq_len: int = 2048,
    ctx_len: Optional[int] = None,
    return_offsets: bool = False
) -> None:
    """
    Load, tokenize, and save the processed dataset.

    Args:
        dataset:
            Path or name of the dataset (CLI default: 'HuggingFaceFW/fineweb-edu').
        name:
            Name of the dataset configuration. Default: `None`.
        split:
            Dataset split to process. Default: 'train'.
        seed:
            Random seed for shuffling the dataset. Default: 42.
        output:
            Output directory. Default: 'data'.
        tokenizer:
            Tokenizer name. Default: 'fla-hub/gla-1.3B-100B'.
        num_proc:
            Number of processes for parallel processing. Default: 64.
        batch_size:
            Batch size for processing. Default: 2048.
        seq_len:
            Total sequence length for each training sample. Default: 2048.
        ctx_len:
            Max contiguous length to preserve (will not be split). Default: `None`.
        return_offsets:
            Return cumulative offsets for concatenated inputs. Default: `False`.
    """
    tokenized_path = f'{output}/{name}/{split}'

    if ctx_len is not None and ctx_len > seq_len:
        raise ValueError(f'ctx_len ({ctx_len}) must be less than or equal to seq_len ({seq_len})')

    logger.info(f'Loading tokenizer {tokenizer}')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f'Tokenizer initialized:\n {tokenizer}')

    logger.info(f'Loading dataset: {dataset}')
    # honor the `name` (configuration) and `split` arguments instead of
    # hard-coding the 'train' split
    dataset = load_dataset(path=dataset, name=name, split=split)
    dataset = dataset.shuffle(seed=seed)
    logger.info(f'Dataset loaded: {dataset}')
    # drop every original column so only the tokenized fields are kept
    remove_columns = list(next(iter(dataset)).keys())
    logger.info(f'Tokenizing and processing the dataset with batch size {batch_size}')
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer, seq_len, ctx_len, return_offsets),
        batched=True,
        batch_size=batch_size,
        remove_columns=remove_columns,
        num_proc=num_proc,
        desc="Running tokenizer on dataset"
    )

    logger.info(f'Saving processed dataset to {tokenized_path}')
    dataset.save_to_disk(tokenized_path, num_proc=num_proc)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess and tokenize dataset")
    parser.add_argument("--dataset", default="HuggingFaceFW/fineweb-edu", help="Path or name of the dataset")
    parser.add_argument("--name", default=None, help="Name of the dataset configuration")
    parser.add_argument("--split", default="train", help="Dataset split to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--output", default="data", help="Output directory")
    parser.add_argument("--tokenizer", default="fla-hub/gla-1.3B-100B", help="Tokenizer name")
    parser.add_argument("--num_proc", type=int, default=64, help="Number of processes for parallel processing")
    parser.add_argument("--batch_size", type=int, default=2048, help="Batch size for processing")
    parser.add_argument("--seq_len", type=int, default=2048, help="Total sequence length for each training sample")
    parser.add_argument("--ctx_len", type=int, default=None, help="Max contiguous length to preserve (will not be split)")
    parser.add_argument("--return_offsets", action="store_true", help="Return cumulative offsets for concatenated inputs")
    args = parser.parse_args()

    preprocess(
        dataset=args.dataset,
        name=args.name,
        split=args.split,
        seed=args.seed,
        output=args.output,
        tokenizer=args.tokenizer,
        num_proc=args.num_proc,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        ctx_len=args.ctx_len,
        return_offsets=args.return_offsets
    )