| """Data pipeline: shard writing/verification + BinStream correctness.""" |
|
|
| import os |
| import json |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from matilda.data import ( |
| ShardWriter, verify_manifest, shard_paths, BinStream, DTYPE, |
| ) |
|
|
|
|
def test_shardwriter_roundtrip_and_manifest(tmp_path):
    """Tokens added in uneven chunks are re-cut into fixed-size shards.

    Two adds (120 + 130 tokens) with ``shard_tokens=100`` must produce shards
    of 100, 100 and 50 tokens. Their concatenation must equal the original
    sequence, the manifest must carry the extra ``meta`` fields, and the
    manifest must pass verification.
    """
    root = str(tmp_path)
    writer = ShardWriter(root, shard_tokens=100)
    # Each add() crosses a shard boundary (at 100 and at 200 tokens).
    for lo, hi in ((0, 120), (120, 250)):
        writer.add(list(range(lo, hi)))
    manifest = writer.close(meta={"tokenizer": "gpt2"})

    assert manifest["total_tokens"] == 250
    shard_sizes = [shard["tokens"] for shard in manifest["shards"]]
    assert shard_sizes == [100, 100, 50]
    assert manifest["tokenizer"] == "gpt2"
    assert verify_manifest(root) is True

    # Reading the shard files back in order reproduces the token stream.
    pieces = [np.fromfile(path, dtype=DTYPE) for path in shard_paths(root)]
    assert np.concatenate(pieces).tolist() == list(range(250))
|
|
|
|
def test_verify_detects_corruption(tmp_path):
    """Flipping one byte of a shard makes verify_manifest raise ValueError."""
    w = ShardWriter(str(tmp_path), shard_tokens=100)
    w.add(list(range(150)))
    w.close()

    p = shard_paths(str(tmp_path))[0]
    # Use context managers so both handles are closed deterministically.
    # The corrupted bytes must be flushed to disk before verify_manifest
    # re-reads the shard; relying on GC to close the write handle is not
    # guaranteed off CPython and trips ResourceWarning checks.
    with open(p, "rb") as f:
        data = bytearray(f.read())
    data[0] ^= 0xFF
    with open(p, "wb") as f:
        f.write(data)

    with pytest.raises(ValueError, match="checksum mismatch"):
        verify_manifest(str(tmp_path))
|
|
|
|
def _write_ramp_bin(path, n):
    """Write the token ramp 0, 1, ..., n-1 to *path* as raw DTYPE values."""
    ramp = np.arange(n, dtype=DTYPE)
    ramp.tofile(path)
|
|
|
|
def test_binstream_shift_is_next_token(tmp_path):
    """Targets are the inputs shifted by one token, checked on a 0..N-1 ramp."""
    bin_path = os.path.join(tmp_path, "ramp.bin")
    _write_ramp_bin(bin_path, 5000)
    batch, seq = 8, 32
    stream = BinStream([bin_path], batch_size=batch, seq_len=seq, seed=0)
    inputs, targets = stream.next()
    assert inputs.shape == (batch, seq)
    assert targets.shape == (batch, seq)
    # On a ramp, the next token is always the current token + 1.
    assert torch.equal(targets, inputs + 1)
|
|
|
|
def test_binstream_resume_is_deterministic(tmp_path):
    """Restoring a mid-stream snapshot replays exactly the same next batch.

    The stream is advanced a few batches before snapshotting. A snapshot of
    the freshly constructed stream cannot distinguish a genuine state restore
    from a load_state_dict that merely re-seeds back to the initial state.
    """
    p = os.path.join(tmp_path, "ramp.bin")
    _write_ramp_bin(p, 5000)
    s = BinStream([p], batch_size=4, seq_len=16, seed=7)
    for _ in range(3):
        s.next()
    snap = s.state_dict()
    x1, y1 = s.next()
    s.load_state_dict(snap)
    x2, y2 = s.next()
    assert torch.equal(x1, x2) and torch.equal(y1, y2)
|
|
|
|
def test_binstream_skips_too_small_shards(tmp_path):
    """A shard too short for one seq_len window is dropped from sampling.

    The 10-token shard cannot hold a 32-token window, so only the 5000-token
    shard should remain in ``s.arrays``, and sampling must still work.
    """
    paths = {
        name: os.path.join(tmp_path, f"{name}.bin") for name in ("big", "tiny")
    }
    _write_ramp_bin(paths["big"], 5000)
    _write_ramp_bin(paths["tiny"], 10)

    stream = BinStream(
        [paths["big"], paths["tiny"]], batch_size=4, seq_len=32, seed=0
    )
    assert len(stream.arrays) == 1

    inputs, targets = stream.next()
    assert torch.equal(targets, inputs + 1)
|
|
|
|
def test_binstream_multishard_weighting(tmp_path):
    """Equal-size shards are both sampled, and no window straddles shards.

    The second shard holds a ramp offset into a disjoint value range, so every
    sampled window can be attributed to exactly one shard. With identical
    ramps in both files, the test could not tell whether more than one shard
    is ever drawn.
    """
    n = 2000
    offset = 20000  # disjoint from 0..n-1; fits in both int16 and uint16
    p1 = os.path.join(tmp_path, "a.bin")
    p2 = os.path.join(tmp_path, "b.bin")
    _write_ramp_bin(p1, n)
    np.arange(offset, offset + n, dtype=DTYPE).tofile(p2)

    s = BinStream([p1, p2], batch_size=4, seq_len=8, seed=1)
    seen = set()
    for _ in range(32):
        x, y = s.next()
        assert torch.equal(y, x + 1)
        for row in x:
            from_a = bool((row < n).all())
            from_b = bool((row >= offset).all())
            # Each window must lie entirely inside a single shard.
            assert from_a or from_b
            seen.add("a" if from_a else "b")
    # Over 32 batches, both equally sized shards should have been drawn.
    assert seen == {"a", "b"}
|
|