# File size: 4,294 Bytes
# b08e9da
"""Streaming packed-token loader for the tiny-vocab physics MoE.

Yields contiguous `input_ids` blocks of length `seq_len+1` (so the trainer can
split into input[:-1]/labels[1:]). Scenes are serialized with the reduced
sim-only serialization, encoded with the trained tokenizer (which adds
<bos>..<eos> via its post-processor), then concatenated and chopped into
fixed blocks (standard causal-LM packing).

Train = interleave across the 24 train/ shards (all scenario types mixed).
Val   = a held-back slice from a few train-type val/ shards (same types,
        unseen scenes) so eval loss reflects in-distribution generalization.
"""
from __future__ import annotations
import json
from typing import Iterator, List

import numpy as np
import torch
from datasets import load_dataset, interleave_datasets
from tokenizers import Tokenizer

import physics_serialize as psz

# Dataset identifier handed to `load_dataset`; shards are addressed inside it
# as `<split>/<type>.tar.gz`.
REPO = "AlexWortega/physics-scenarios-packed"

# Scenario types streamed when the caller does not pass `types` explicitly
# (one shard per entry, 24 in total).
TRAIN_TYPES = [
    "avalanche",
    "basketball",
    "billiards",
    "breakout",
    "bridge",
    "chain",
    "conveyor",
    "dominos",
    "explosion",
    "funnel",
    "head_on",
    "jenga",
    "marble_run",
    "orbit",
    "pendulum",
    "pinball",
    "plinko",
    "projectile",
    "pyramid",
    "seesaw",
    "ski_jump",
    "tower",
    "wind",
    "wrecking_ball",
]


def _serialize_row(row) -> str | None:
    lines = row["jsonl"].decode().splitlines()
    if not lines:
        return None
    try:
        header = json.loads(lines[0])
        frames = [json.loads(x) for x in lines[1:] if x.startswith("{")]
    except Exception:
        return None
    if not frames:
        return None
    return psz.serialize_scene(header, frames)


def _shard_mix(split: str, types: List[str], seed: int):
    """Open one streaming dataset per scenario type and combine them.

    Each type ``t`` is loaded from the shard ``<split>/<t>.tar.gz`` of
    ``REPO``. When exactly one type is requested its dataset is returned
    as-is; otherwise the datasets are interleaved with ``seed`` and the
    ``"all_exhausted"`` stopping strategy.
    """
    shard_paths = [f"{split}/{scenario}.tar.gz" for scenario in types]
    streams = []
    for shard_path in shard_paths:
        stream = load_dataset(
            REPO,
            data_files={split: [shard_path]},
            split=split,
            streaming=True,
        )
        streams.append(stream)
    if len(streams) != 1:
        return interleave_datasets(
            streams, seed=seed, stopping_strategy="all_exhausted"
        )
    return streams[0]


def packed_token_stream(tokenizer_path: str, seq_len: int, split: str = "train",
                        seed: int = 0, types: List[str] | None = None,
                        shuffle_buffer: int = 0) -> Iterator[np.ndarray]:
    """Yield packed np.int64 token blocks of length ``seq_len + 1``.

    Rows from the shard mix are serialized, encoded with the tokenizer loaded
    from ``tokenizer_path``, appended to one running token buffer, and cut
    into consecutive fixed-size blocks. Rows that cannot be serialized are
    skipped. Tokens left over at the end of the stream that do not fill a
    whole block are discarded.

    Args:
        tokenizer_path: Path of the tokenizer file for ``Tokenizer.from_file``.
        seq_len: Model sequence length; each block holds ``seq_len + 1`` ids.
        split: Dataset split to stream.
        seed: Seed for interleaving and, when enabled, for shuffling.
        types: Scenario types to stream; falls back to ``TRAIN_TYPES`` when
            ``None`` or empty.
        shuffle_buffer: Shuffle buffer size; values ``<= 0`` disable shuffling.

    Raises:
        ValueError: If ``seq_len`` is negative. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if seq_len < 0:
        # A non-positive block size would make the packing loop never end.
        raise ValueError(f"seq_len must be non-negative, got {seq_len}")
    tok = Tokenizer.from_file(tokenizer_path)
    mix = _shard_mix(split, types or TRAIN_TYPES, seed)
    if shuffle_buffer > 0:
        mix = mix.shuffle(seed=seed, buffer_size=shuffle_buffer)
    block = seq_len + 1
    pending: List[int] = []
    for row in mix:
        text = _serialize_row(row)
        if text is None:
            continue
        # Per the module docstring, the encoded ids already carry <bos>..<eos>.
        pending.extend(tok.encode(text).ids)
        # Emit every complete block by offset, then drop the consumed prefix
        # once. Re-slicing the remainder after each block is quadratic when a
        # single scene spans many blocks.
        usable = len(pending) - len(pending) % block
        for start in range(0, usable, block):
            yield np.asarray(pending[start:start + block], dtype=np.int64)
        del pending[:usable]


def batch_iterator(tokenizer_path: str, seq_len: int, batch_size: int,
                   split: str = "train", seed: int = 0,
                   types: List[str] | None = None, shuffle_buffer: int = 0,
                   device: str = "cuda", infinite: bool = False):
    """Yield ``(input_ids[B,S], labels[B,S])`` int64 tensors on ``device``.

    Packed blocks of length ``seq_len + 1`` are stacked ``batch_size`` at a
    time; ``input_ids`` is every block without its last token and ``labels``
    is every block without its first token. A trailing group of fewer than
    ``batch_size`` blocks at the end of a pass is dropped.

    Args:
        tokenizer_path: Forwarded to ``packed_token_stream``.
        seq_len: Sequence length ``S`` of the yielded tensors.
        batch_size: Number of blocks ``B`` per batch; must be at least 1.
        split: Dataset split to stream.
        seed: Seed for the first pass; incremented by one for every further
            pass when ``infinite`` is set.
        types: Scenario types to stream (``None`` uses the loader default).
        shuffle_buffer: Shuffle buffer size forwarded to the token stream.
        device: Target device for the tensors (default ``"cuda"``).
        infinite: When true, restart the stream after it is exhausted instead
            of returning.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If ``infinite`` is set and a full pass produced no
            batch at all (restarting would loop forever without yielding).
    """
    if batch_size < 1:
        # `len(chunk) == batch_size` could never match, so blocks would pile
        # up in memory without a single batch being yielded.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    while True:
        gen = packed_token_stream(tokenizer_path, seq_len, split, seed,
                                  types, shuffle_buffer)
        produced = False
        chunk: List[np.ndarray] = []
        for arr in gen:
            chunk.append(arr)
            if len(chunk) == batch_size:
                m = np.stack(chunk, axis=0)  # [B, S+1]
                inp = torch.from_numpy(m[:, :-1]).to(device, non_blocking=True)
                lbl = torch.from_numpy(m[:, 1:]).to(device, non_blocking=True)
                yield inp, lbl
                produced = True
                chunk = []
        if not infinite:
            return
        if not produced:
            raise RuntimeError(
                "token stream produced no complete batch; refusing to restart "
                "an empty epoch forever"
            )
        seed += 1  # reshuffle the interleave on each epoch


if __name__ == "__main__":
    # Smoke test: pull up to 20 CPU batches from the train split, print the
    # shape and id range of the first one, then report the elapsed time.
    import sys
    import time

    tokenizer_file = sys.argv[1] if len(sys.argv) > 1 else "tokenizer.json"
    batches = batch_iterator(tokenizer_file, seq_len=1024, batch_size=4,
                             split="train", shuffle_buffer=0, device="cpu")
    started = time.time()
    seen = 0  # stays 0 when the stream yields nothing
    for seen, (inp, lbl) in enumerate(batches, start=1):
        if seen == 1:
            print("first batch", inp.shape, "min", int(inp.min()), "max", int(inp.max()))
        if seen >= 20:
            break
    print(f"{seen} batches in {time.time()-started:.1f}s  ({seen*4*1024} tokens)")