# grandline/tests/test_pack.py
# dignity045's picture
# Initial GrandLine implementation: deterministic shard-first dataset preprocessing for LLM pretraining
# ed59144 verified
"""Test packing block."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from grandline.blocks.pack import PackBlock
from grandline.types import PackedSequence, TokenizedDocument
def make_tok_doc(token_ids: list[int], doc_id: str = "0") -> TokenizedDocument:
    """Build a TokenizedDocument fixture from raw token ids.

    ``num_tokens`` is derived from ``token_ids`` and every document is
    placed in the fixed shard ``"shard_0"``.
    """
    token_count = len(token_ids)
    return TokenizedDocument(
        token_ids=token_ids,
        num_tokens=token_count,
        doc_id=doc_id,
        shard_id="shard_0",
    )
def test_pack_basic():
    """Several short documents are packed into one padded sequence."""
    packer = PackBlock(max_seq_len=20, eos_id=99, pad_id=0)
    # Each document is expected to occupy its tokens plus one EOS (99):
    # 4 + 5 + 3 = 12 tokens, which fits within the 20-token limit.
    documents = [
        make_tok_doc([1, 2, 3], "0"),
        make_tok_doc([4, 5, 6, 7], "1"),
        make_tok_doc([8, 9], "2"),
    ]

    packed = list(packer(iter(documents)))
    assert len(packed) == 1

    only = packed[0]
    assert len(only.input_ids) == 20
    assert only.seq_lens == [4, 5, 3]
    assert only.total_tokens == 12

    # Expected layout: [1,2,3,99, 4,5,6,7,99, 8,9,99, 0,0,0,0,0,0,0,0]
    expected_segments = [
        (0, 4, [1, 2, 3, 99]),
        (4, 9, [4, 5, 6, 7, 99]),
        (9, 12, [8, 9, 99]),
    ]
    for start, stop, expected in expected_segments:
        assert only.input_ids[start:stop] == expected

    # Everything after the packed tokens must be padding (pad_id=0).
    assert not any(tok != 0 for tok in only.input_ids[12:])
def test_pack_overflow():
    """A document that no longer fits in the open bin starts a new bin."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    documents = [
        make_tok_doc([1, 2, 3, 4], "0"),  # 4 + eos = 5
        make_tok_doc([5, 6, 7, 8], "1"),  # 4 + eos = 5, bin now holds exactly 10
        make_tok_doc([9, 10], "2"),  # 2 + eos = 3, must go to a second bin
    ]

    packed = list(packer(iter(documents)))
    assert len(packed) == 2
    full_bin, spill_bin = packed[0], packed[1]

    # First bin is filled to the limit, so it carries no padding.
    assert full_bin.input_ids == [1, 2, 3, 4, 99, 5, 6, 7, 8, 99]
    assert full_bin.seq_lens == [5, 5]

    # Second bin holds the remaining document followed by 7 pad tokens.
    assert spill_bin.input_ids == [9, 10, 99] + [0] * 7
    assert spill_bin.seq_lens == [3]
def test_pack_long_document():
    """A document longer than max_seq_len yields max_seq_len-sized output."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)
    # 8 tokens + eos = 9 tokens, which exceeds the 5-token limit.
    oversized = make_tok_doc([1, 2, 3, 4, 5, 6, 7, 8], "0")

    packed = list(packer(iter([oversized])))

    # NOTE(review): ceil(9 / 5) = 2 chunks would be the natural expectation,
    # but this test only requires that at least one sequence is emitted.
    # Confirm PackBlock's chunking behaviour before tightening this check.
    assert len(packed) > 0
    # The first emitted sequence must be exactly max_seq_len long.
    first_chunk = packed[0]
    assert len(first_chunk.input_ids) == 5
def test_pack_empty_input():
    """Packing an empty document stream yields no sequences."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    no_documents = iter([])
    assert list(packer(no_documents)) == []
def test_pack_single_doc():
    """One short document becomes one padded sequence."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    lone_doc = make_tok_doc([1, 2, 3], "0")  # 3 tokens + eos = 4

    packed = list(packer(iter([lone_doc])))
    assert len(packed) == 1

    only = packed[0]
    # Document tokens, then EOS (99), then pad tokens (0) up to max_seq_len.
    assert only.input_ids == [1, 2, 3, 99, 0, 0, 0, 0, 0, 0]
    assert only.seq_lens == [4]
    assert only.total_tokens == 4
def test_pack_exact_fit():
    """A document whose tokens plus EOS equal max_seq_len needs no padding."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)
    snug_doc = make_tok_doc([1, 2, 3, 4], "0")  # 4 tokens + eos = exactly 5

    packed = list(packer(iter([snug_doc])))
    assert len(packed) == 1

    only = packed[0]
    assert only.input_ids == [1, 2, 3, 4, 99]
    assert only.total_tokens == 5
def test_pack_determinism():
    """Packing identical input twice yields identical sequences."""
    packer = PackBlock(max_seq_len=15, eos_id=99, pad_id=0)

    def run_once():
        # Build fresh documents for every run so the two passes share
        # no document objects; the same packer instance is reused.
        documents = [
            make_tok_doc([1, 2, 3], "0"),
            make_tok_doc([4, 5], "1"),
            make_tok_doc([6, 7, 8, 9], "2"),
        ]
        return [(seq.input_ids, seq.seq_lens) for seq in packer(iter(documents))]

    first_pass = run_once()
    second_pass = run_once()
    assert first_pass == second_pass
def test_pack_signature():
    """The block signature mentions its name and key parameters."""
    packer = PackBlock(max_seq_len=2048, eos_id=151645, pad_id=0)
    signature = packer.signature
    # Block name, max_seq_len and eos_id must all appear in the signature.
    for fragment in ("pack", "2048", "151645"):
        assert fragment in signature
if __name__ == "__main__":
    # Allow running this file directly as a script, without a test runner.
    for run_test in (
        test_pack_basic,
        test_pack_overflow,
        test_pack_long_document,
        test_pack_empty_input,
        test_pack_single_doc,
        test_pack_exact_fit,
        test_pack_determinism,
        test_pack_signature,
    ):
        run_test()
    print("All pack tests passed!")