File size: 18,002 Bytes
01fcd60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
"""
Dataset Tokenizer
=================
Tokenizes datasets and saves as binary files for mmap loading.

Usage:
    # HuggingFace dataset
    python -m freqformer.tokenize_data --dataset wikitext --subset wikitext-2-raw-v1 --out data/wiki2

    # Custom text files
    python -m freqformer.tokenize_data --files train.txt val.txt --out data/custom

    # Append more data to existing binary files (does not overwrite)
    python -m freqformer.tokenize_data --dataset openwebtext --out data/mixed --append
    python -m freqformer.tokenize_data --dataset allenai/dolma3_pool --out data/mixed --append --streaming
"""

import argparse
import os
import re
import struct
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer


HEADER_MAGIC = "FREQTOK1".encode("ascii")  # 8-byte file signature
# Fixed header layout, packed little-endian without padding:
#   magic(8) | version(4) | num_tokens(8) | seq_len(4) | vocab_size(4) | reserved(4)
HEADER_SIZE = 8 + 4 + 8 + 4 + 4 + 4  # = 32 bytes


def write_binary(tokens: list[int], path: str, vocab_size: int):
    """Write *tokens* to *path* as a fresh FREQTOK1 file, overwriting any existing file.

    The 32-byte header stores version 1, the token count, seq_len 0 (not chunked)
    and *vocab_size*. Token IDs follow as uint16 when ``vocab_size < 65536``,
    otherwise as uint32. Parent directories are created as needed.
    """
    parent_dir = os.path.dirname(path) or "."
    os.makedirs(parent_dir, exist_ok=True)

    token_dtype = np.uint32 if vocab_size >= 65536 else np.uint16
    payload = np.array(tokens, dtype=token_dtype).tobytes()
    # magic | version | num_tokens | seq_len | vocab_size | reserved
    header = struct.pack("<8sIQIII", HEADER_MAGIC, 1, len(tokens), 0, vocab_size, 0)

    with open(path, "wb") as out:
        out.write(header)
        out.write(payload)

    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f"  Saved {path}: {len(tokens):,} tokens ({size_mb:.1f} MB)")


def append_binary(tokens: list[int], path: str, vocab_size: int):
    """Append *tokens* to an existing FREQTOK1 file, or create it if it doesn't exist.

    The existing header is validated (magic, vocab_size, and that the file size
    matches the declared token count). Then the new tokens are written at EOF, and
    only afterwards is ``num_tokens`` in the header updated.

    Args:
        tokens: Token IDs to append.
        path: Target ``.bin`` file.
        vocab_size: Vocabulary size of the tokenizer that produced *tokens*; it must
            equal the value stored in the file.

    Raises:
        ValueError: The file is not a valid FREQTOK1 file, the vocab_size differs,
            the file size disagrees with its header, or a token ID does not fit the
            file's dtype.
    """
    if not os.path.exists(path):
        write_binary(tokens, path, vocab_size)
        return

    # Read the whole fixed-size header at once:
    # magic | version | num_tokens | seq_len | vocab_size | reserved
    with open(path, "rb") as f:
        header = f.read(HEADER_SIZE)
    if len(header) < HEADER_SIZE or header[:8] != HEADER_MAGIC:
        raise ValueError(f"Cannot append: {path} is not a valid FREQTOK1 file")
    _magic, _version, old_num_tokens, _seq_len, old_vocab_size, _reserved = struct.unpack(
        "<8sIQIII", header
    )

    if old_vocab_size != vocab_size:
        raise ValueError(
            f"Cannot append: vocab_size mismatch ({old_vocab_size} in file vs {vocab_size} new). "
            f"Use the same tokenizer."
        )

    dtype = np.uint16 if old_vocab_size < 65536 else np.uint32
    itemsize = np.dtype(dtype).itemsize

    # New data goes to EOF. If the file holds more or fewer bytes than the header
    # declares, the appended tokens would land at the wrong offset and be misread.
    data_bytes = os.path.getsize(path) - HEADER_SIZE
    if data_bytes != old_num_tokens * itemsize:
        raise ValueError(
            f"Cannot append: {path} header declares {old_num_tokens:,} tokens "
            f"({old_num_tokens * itemsize:,} bytes) but the file holds {data_bytes:,} data bytes. "
            f"Recreate the file."
        )

    # Every ID must fit the existing on-disk dtype (inclusive bounds).
    if tokens:
        info = np.iinfo(dtype)
        lo, hi = min(tokens), max(tokens)
        if lo < info.min or hi > info.max:
            bad = lo if lo < info.min else hi
            dtype_name = np.dtype(dtype).name
            raise ValueError(
                f"Cannot append: token ID {bad} exceeds {dtype_name} range but existing file "
                f"uses {dtype_name}. Recreate the file with a larger vocab_size tokenizer."
            )

    arr = np.array(tokens, dtype=dtype)
    new_total = old_num_tokens + len(tokens)

    with open(path, "r+b") as f:
        # Write the payload first and bump the count last. If writing the payload
        # raises, the header still describes only the previously valid tokens.
        f.seek(0, 2)  # seek to end
        f.write(arr.tobytes())
        f.flush()
        # num_tokens lives at offset 12 = magic(8) + version(4)
        f.seek(12)
        f.write(struct.pack("<Q", new_total))

    size_mb = os.path.getsize(path) / 1024 / 1024
    print(f"  Appended to {path}: +{len(tokens):,} tokens (total: {new_total:,}, {size_mb:.1f} MB)")


def _init_binary(path: str, vocab_size: int):
    """Create (or truncate) *path* so it holds only a FREQTOK1 header with zero tokens.

    The header's num_tokens field is written as 0; callers patch it after writing
    data. Returns the token dtype implied by *vocab_size*: uint16 below 65536,
    otherwise uint32.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # magic | version=1 | num_tokens=0 | seq_len=0 | vocab_size | reserved=0
    empty_header = struct.pack("<8sIQIII", HEADER_MAGIC, 1, 0, 0, vocab_size, 0)
    with open(path, "wb") as fh:
        fh.write(empty_header)
    return np.uint32 if vocab_size >= 65536 else np.uint16


def tokenize_hf_dataset(
    dataset_name: str, subset: str, tokenizer_name: str, out_dir: str,
    max_samples: int = 0, split: str = "train", val_ratio: float = 0.02,
    streaming: bool = False, append: bool = False, num_proc: int = 0,
    max_tokens: int = 0, seed: int = 1234, skip_errors: bool = False,
):
    """Tokenize a HuggingFace dataset into ``train.bin`` / ``val.bin`` under *out_dir*.

    Non-streaming: uses datasets.map() with num_proc for parallel tokenization.
    Streaming: batched tokenization with direct disk writes (low RAM). In streaming
    mode *dataset_name* may also be a local directory. It is scanned recursively for
    ``*.jsonl``, ``*.jsonl.zst``, ``*.json`` and ``*.json.gz`` files, which are loaded
    with the ``json`` builder.

    Every document is tokenized with ``add_special_tokens=False``. It is written to
    val.bin with probability *val_ratio* (RNG seeded by *seed*), otherwise to
    train.bin. Documents are concatenated back to back.

    When the run ends, normally or by an exception, each output header's
    ``num_tokens`` is set to the number of whole tokens actually on disk. The
    files therefore stay self-consistent.

    Args:
        dataset_name: HF dataset name/path (or a local JSON directory when streaming).
        subset: Config name passed as ``name=``; not used when loading local JSON files.
        tokenizer_name: Name/path for ``AutoTokenizer.from_pretrained``.
        out_dir: Directory that receives train.bin (and val.bin if val_ratio > 0).
        max_samples: 0 = no limit. Non-streaming keeps the first N rows; streaming
            stops after N non-empty texts.
        split: Dataset split to load.
        val_ratio: Probability that a document goes to val.bin (0 disables val.bin).
        streaming: Stream the dataset instead of materializing it.
        append: Extend existing output files instead of recreating them.
        num_proc: Processes for datasets.map(); 0 means min(cpu_count, 16). Only the
            non-streaming path uses it.
        max_tokens: Stop once this many NEW tokens have been written (0 = no limit).
            Tokens already present in appended files do not count. The check runs per
            document (non-streaming) or per queued example (streaming, where writes
            happen per batch), so the output may overshoot the limit.
        seed: Seed for the train/val routing RNG.
        skip_errors: Streaming over local JSON files only. When reading fails and the
            offending file can be parsed from the error message, exclude that file
            and resume.

    Raises:
        ValueError: An existing output file (append mode) is not FREQTOK1, has a
            different vocab_size, or its size disagrees with its header; the streamed
            dataset yields no examples; or every data file has been excluded.
    """
    from datasets import load_dataset
    import multiprocessing

    label = f"{dataset_name}/{subset}" if subset else dataset_name
    print(f"Loading dataset: {label} (split={split}, streaming={streaming})")
    if max_samples:
        print(f"  Max samples: {max_samples:,}")
    if max_tokens:
        print(f"  Max tokens: {max_tokens:,}")

    ds_kwargs = {}
    if subset:
        ds_kwargs["name"] = subset
    if streaming:
        ds_kwargs["streaming"] = True

    # When streaming, a local directory is expanded into an explicit file list so
    # that individual corrupt files can be excluded later (see skip_errors).
    data_files = None
    dataset_path = Path(dataset_name)
    if streaming and dataset_path.is_dir():
        patterns = ["*.jsonl", "*.jsonl.zst", "*.json", "*.json.gz"]
        found = {str(p) for pattern in patterns for p in dataset_path.rglob(pattern)}
        if found:
            data_files = sorted(found)

    def _load_streaming_dataset(exclude_files: set[str] | None = None):
        """Load the dataset; for local JSON files, leave out any excluded paths."""
        if data_files:
            files = [f for f in data_files if f not in (exclude_files or ())]
            if not files:
                raise ValueError("All data files were excluded; cannot continue streaming.")
            return load_dataset("json", data_files=files, split=split, streaming=True)
        return load_dataset(dataset_name, split=split, **ds_kwargs)

    ds = _load_streaming_dataset()

    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = len(tokenizer)  # includes added special tokens
    dtype = np.uint16 if vocab_size < 65536 else np.uint32
    itemsize = np.dtype(dtype).itemsize

    if num_proc <= 0:
        num_proc = min(multiprocessing.cpu_count(), 16)

    # Detect text column
    if streaming:
        first = next(iter(ds), None)
        if first is None:
            raise ValueError(f"Dataset {label} (split={split}) yielded no examples")
        text_col = "text" if "text" in first else list(first.keys())[0]
    else:
        text_col = "text" if "text" in ds.column_names else ds.column_names[0]

    print(f"  text_col='{text_col}', vocab_size={vocab_size}, num_proc={num_proc}")

    train_path = os.path.join(out_dir, "train.bin")
    val_path = os.path.join(out_dir, "val.bin")

    def _read_existing_count(path: str) -> int:
        """Validate an existing output file and return the token count from its header."""
        with open(path, "rb") as f:
            header = f.read(HEADER_SIZE)
        if len(header) < HEADER_SIZE or header[:8] != HEADER_MAGIC:
            raise ValueError(f"{path} is not a valid FREQTOK1 file")
        _magic, _version, num_tokens, _seq_len, file_vocab, _reserved = struct.unpack(
            "<8sIQIII", header
        )
        if file_vocab != vocab_size:
            raise ValueError(
                f"Cannot append: vocab_size mismatch ({file_vocab} in file vs {vocab_size} new). "
                f"Use the same tokenizer."
            )
        # New tokens are written at EOF. Bytes the header does not account for
        # would shift everything appended after them.
        data_bytes = os.path.getsize(path) - HEADER_SIZE
        if data_bytes != num_tokens * itemsize:
            raise ValueError(
                f"Cannot append: {path} header declares {num_tokens:,} tokens "
                f"({num_tokens * itemsize:,} bytes) but the file holds {data_bytes:,} data bytes. "
                f"Recreate the file."
            )
        return num_tokens

    # Output state shared by both code paths; populated by _open_outputs().
    train_f = None
    val_f = None
    train_written = 0
    val_written = 0
    base_tokens = 0  # tokens already in the files before this run (append mode)
    rng = None

    def _open_outputs() -> None:
        """Create or validate the output files and open them positioned at EOF."""
        nonlocal train_f, val_f, train_written, val_written, base_tokens, rng
        if append:
            if not os.path.exists(train_path):
                _init_binary(train_path, vocab_size)
            if val_ratio > 0 and not os.path.exists(val_path):
                _init_binary(val_path, vocab_size)
            train_written = _read_existing_count(train_path)
            val_written = _read_existing_count(val_path) if val_ratio > 0 else 0
        else:
            _init_binary(train_path, vocab_size)
            if val_ratio > 0:
                _init_binary(val_path, vocab_size)
            train_written = 0
            val_written = 0
        base_tokens = train_written + val_written
        rng = np.random.default_rng(seed)
        train_f = open(train_path, "r+b")
        train_f.seek(0, 2)
        if val_ratio > 0:
            val_f = open(val_path, "r+b")
            val_f.seek(0, 2)

    def _write_doc(ids) -> None:
        """Send one document's token IDs to val.bin (probability val_ratio) or train.bin."""
        nonlocal train_written, val_written
        arr = np.array(ids, dtype=dtype)
        if val_ratio > 0 and rng.random() < val_ratio:
            val_f.write(arr.tobytes())
            val_written += len(arr)
        else:
            train_f.write(arr.tobytes())
            train_written += len(arr)

    def _produced() -> int:
        """Tokens written by this run; excludes pre-existing tokens in append mode."""
        return train_written + val_written - base_tokens

    def _finalize(f) -> None:
        """Set the header's num_tokens to the number of whole tokens actually on disk."""
        end = f.seek(0, 2)
        num_tokens = (end - HEADER_SIZE) // itemsize
        # Drop a partially written trailing token, if a write was cut short.
        f.truncate(HEADER_SIZE + num_tokens * itemsize)
        f.seek(12)  # num_tokens offset: magic(8) + version(4)
        f.write(struct.pack("<Q", num_tokens))

    def _close_outputs() -> None:
        """Finalize headers and close files; used on both success and failure."""
        try:
            for f in (train_f, val_f):
                if f is not None:
                    _finalize(f)
        finally:
            for f in (train_f, val_f):
                if f is not None:
                    f.close()

    # ---- Non-streaming: parallel tokenization via datasets.map() ----
    if not streaming:
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))

        def _tok_batch(batch):
            encoded = tokenizer(batch[text_col], add_special_tokens=False)
            return {"input_ids": encoded["input_ids"]}

        print(f"\nTokenizing with {num_proc} processes...")
        ds_tok = ds.map(
            _tok_batch, batched=True, batch_size=10000,
            num_proc=num_proc, remove_columns=ds.column_names,
            desc="Tokenizing",
        )

        # Output files are only created/opened once tokenization has succeeded.
        row_count = 0
        log_every = 100000
        try:
            _open_outputs()
            for row in ds_tok:
                ids = row["input_ids"]
                if not ids:
                    continue
                _write_doc(ids)
                row_count += 1
                if row_count % log_every == 0:
                    total_tokens = train_written + val_written
                    print(f"  {row_count:>10,} docs | {total_tokens:>12,} tokens "
                          f"(train: {train_written:,}, val: {val_written:,})")
                if max_tokens and _produced() >= max_tokens:
                    break
        finally:
            _close_outputs()

        print(f"  Total: {train_written + val_written:,} tokens")
        print(f"  Train: {train_written:,} tokens, Val: {val_written:,} tokens")
        return

    # ---- Streaming: per-batch disk writes (constant RAM) ----
    batch_size = 10000
    log_every = 100000
    count = 0      # non-empty texts queued for tokenization (what max_samples limits)
    consumed = 0   # examples pulled from the stream, empty ones included (resume offset)
    batch_texts: list[str] = []
    exclude_files: set[str] = set()
    skip_errors_enabled = skip_errors and data_files is not None

    def _flush_batch() -> None:
        """Tokenize the queued texts, write each document, and clear the queue."""
        if not batch_texts:
            return
        encoded = tokenizer(
            batch_texts,
            add_special_tokens=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        for ids in encoded["input_ids"]:
            _write_doc(ids)
        batch_texts.clear()

    def _extract_bad_file(err: Exception) -> str | None:
        """Best-effort parse of the failing file path out of a datasets error message."""
        msg = str(err)
        match = re.search(r"file '.*::([^']+)'", msg)
        if match:
            return match.group(1)
        match = re.search(r"::([^'\s]+)", msg)
        if match:
            return match.group(1)
        return None

    try:
        _open_outputs()
        print(f"\nTokenizing (streaming, batch_size={batch_size:,}, direct writes)...")
        while True:
            if skip_errors_enabled and exclude_files:
                ds = _load_streaming_dataset(exclude_files)
            if skip_errors_enabled and consumed > 0:
                # NOTE(review): examples already read from a file that was just
                # excluded are still counted in `consumed`. If files are streamed in
                # order, this skips that many examples from the files after it.
                ds = ds.skip(consumed)
            try:
                for example in ds:
                    if max_samples and count >= max_samples:
                        break
                    if max_tokens and _produced() >= max_tokens:
                        break
                    consumed += 1
                    text = example[text_col]
                    if not text or not text.strip():
                        continue
                    batch_texts.append(text)
                    count += 1

                    if len(batch_texts) >= batch_size:
                        _flush_batch()

                    if count % log_every == 0:
                        total_tokens = train_written + val_written
                        print(
                            f"  {count:>10,} samples | {total_tokens:>12,} tokens "
                            f"(train: {train_written:,}, val: {val_written:,})"
                        )
                break
            except Exception as err:
                if not skip_errors_enabled:
                    raise
                bad_file = _extract_bad_file(err)
                if not bad_file or bad_file in exclude_files:
                    raise
                print(f"  Warning: failed to read {bad_file}; skipping and resuming...")
                exclude_files.add(bad_file)
        # Tokenize whatever is left in the final partial batch.
        _flush_batch()
    finally:
        _close_outputs()

    print(f"  Total: {count:,} samples, {train_written + val_written:,} tokens written")
    print(f"  Train: {train_written:,} tokens, Val: {val_written:,} tokens")


def tokenize_files(files: list[str], tokenizer_name: str, out_dir: str, append: bool = False):
    """Tokenize each raw UTF-8 text file into ``<out_dir>/<stem>.bin``.

    Each file is read whole and encoded without special tokens. The result is
    written with append_binary when *append* is true; otherwise write_binary is
    used, which overwrites. Inputs that share a file stem map to the same output path.
    """
    print(f"Loading tokenizer: {tokenizer_name}")
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab = len(tok)  # counts added special tokens too

    if append:
        save = append_binary
    else:
        save = write_binary

    for source in files:
        print(f"\nTokenizing {source}...")
        raw_text = Path(source).read_text(encoding="utf-8")
        ids = tok.encode(raw_text, add_special_tokens=False)
        target = os.path.join(out_dir, Path(source).stem + ".bin")
        save(ids, target, vocab)


def main():
    """CLI entry point: parse arguments and tokenize either a dataset or raw files.

    --dataset takes precedence over --files. If neither is provided, argparse
    reports an error and exits.
    """
    parser = argparse.ArgumentParser(description="Tokenize datasets for FreqFormer training")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--dataset", {"type": str, "default": None, "help": "HuggingFace dataset name"}),
        ("--subset", {"type": str, "default": None, "help": "Dataset subset/config"}),
        ("--files", {"nargs": "+", "default": None, "help": "Raw text files to tokenize"}),
        ("--tokenizer", {"type": str, "default": "freqformer/tokenizer", "help": "Tokenizer name/path"}),
        ("--out", {"type": str, "required": True, "help": "Output directory"}),
        ("--max_samples", {"type": int, "default": 0, "help": "Max samples to take (0=all)"}),
        ("--split", {"type": str, "default": "train", "help": "Dataset split to use"}),
        ("--val_ratio", {"type": float, "default": 0.02, "help": "Fraction for validation split"}),
        ("--streaming", {"action": "store_true", "help": "Use streaming mode for large datasets"}),
        ("--append", {"action": "store_true", "help": "Append to existing binary files instead of overwriting"}),
        ("--num_proc", {"type": int, "default": 0, "help": "Number of CPU processes for tokenization (0=auto)"}),
        ("--max_tokens", {"type": int, "default": 0, "help": "Max tokens to produce (0=all)"}),
        ("--seed", {"type": int, "default": 1234, "help": "RNG seed for streaming val split"}),
        ("--skip_errors", {"action": "store_true", "help": "Skip corrupted JSON files in streaming mode"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    if args.append:
        print("Mode: APPEND (will add to existing data files)")

    if not (args.dataset or args.files):
        parser.error("Specify --dataset or --files")

    if args.dataset:
        tokenize_hf_dataset(
            dataset_name=args.dataset,
            subset=args.subset,
            tokenizer_name=args.tokenizer,
            out_dir=args.out,
            max_samples=args.max_samples,
            split=args.split,
            val_ratio=args.val_ratio,
            streaming=args.streaming,
            append=args.append,
            num_proc=args.num_proc,
            max_tokens=args.max_tokens,
            seed=args.seed,
            skip_errors=args.skip_errors,
        )
    else:
        tokenize_files(args.files, args.tokenizer, args.out, append=args.append)

    print(f"\nDone! Data saved to {args.out}/")
    print(f"Use --data_dir {args.out} when training.")


if __name__ == "__main__":
    main()