# Initial GrandLine implementation: deterministic shard-first dataset preprocessing for LLM pretraining
# ed59144 verified
"""Test packing block."""
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) | |
| from grandline.blocks.pack import PackBlock | |
| from grandline.types import PackedSequence, TokenizedDocument | |
def make_tok_doc(token_ids: list[int], doc_id: str = "0") -> TokenizedDocument:
    """Build a TokenizedDocument test fixture from raw token ids.

    Args:
        token_ids: Token ids to store on the document.
        doc_id: Identifier assigned to the document (defaults to "0").

    Returns:
        A TokenizedDocument whose ``num_tokens`` is ``len(token_ids)`` and
        whose ``shard_id`` is always "shard_0".
    """
    fields = {
        "token_ids": token_ids,
        "num_tokens": len(token_ids),
        "doc_id": doc_id,
        "shard_id": "shard_0",
    }
    return TokenizedDocument(**fields)
def test_pack_basic():
    """Several short documents are expected to share one padded sequence."""
    packer = PackBlock(max_seq_len=20, eos_id=99, pad_id=0)
    # Each document is expected to cost len(tokens) + 1 (trailing EOS):
    # 4 + 5 + 3 = 12 tokens in total, which fits within the limit of 20.
    inputs = [
        make_tok_doc([1, 2, 3], "0"),
        make_tok_doc([4, 5, 6, 7], "1"),
        make_tok_doc([8, 9], "2"),
    ]

    packed = list(packer(iter(inputs)))

    assert len(packed) == 1
    only = packed[0]
    assert len(only.input_ids) == 20
    assert only.seq_lens == [4, 5, 3]
    assert only.total_tokens == 12

    # Expected layout: [1,2,3,99, 4,5,6,7,99, 8,9,99] followed by 8 pad ids.
    expected_segments = [
        (slice(None, 4), [1, 2, 3, 99]),
        (slice(4, 9), [4, 5, 6, 7, 99]),
        (slice(9, 12), [8, 9, 99]),
    ]
    for span, expected in expected_segments:
        assert only.input_ids[span] == expected

    # Everything after the last EOS must be the pad id.
    for token in only.input_ids[12:]:
        assert token == 0
def test_pack_overflow():
    """A document that no longer fits is expected to open a second bin."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    inputs = [
        make_tok_doc([1, 2, 3, 4], "0"),  # 5 tokens including EOS
        make_tok_doc([5, 6, 7, 8], "1"),  # 5 more -> bin is exactly full
        make_tok_doc([9, 10], "2"),  # 3 tokens -> expected in a new bin
    ]

    packed = list(packer(iter(inputs)))
    assert len(packed) == 2

    full_bin, spill_bin = packed[0], packed[1]

    # First bin holds both 5-token documents with no padding.
    assert full_bin.input_ids == [1, 2, 3, 4, 99, 5, 6, 7, 8, 99]
    assert full_bin.seq_lens == [5, 5]

    # Second bin holds the 3-token document followed by 7 pad ids.
    assert spill_bin.input_ids == [9, 10, 99, 0, 0, 0, 0, 0, 0, 0]
    assert spill_bin.seq_lens == [3]
def test_pack_long_document():
    """A document longer than max_seq_len still yields max_seq_len-sized output."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)
    # 8 tokens plus EOS is 9, which exceeds the limit of 5.
    oversized = [make_tok_doc([1, 2, 3, 4, 5, 6, 7, 8], "0")]

    chunks = list(packer(iter(oversized)))

    # NOTE(review): 9 tokens over a limit of 5 suggests two chunks, but this
    # test only requires at least one -- confirm PackBlock's remainder policy
    # before tightening the assertion.
    assert len(chunks) >= 1

    # The first chunk must be exactly max_seq_len long.
    first_chunk = chunks[0]
    assert len(first_chunk.input_ids) == 5
def test_pack_empty_input():
    """Packing an empty document stream is expected to yield nothing."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    no_docs = iter([])
    packed = list(packer(no_docs))
    assert packed == []
def test_pack_single_doc():
    """A lone document is expected to be EOS-terminated and padded."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    # Three tokens plus EOS occupy four slots; the rest should be pad ids.
    lone_doc = make_tok_doc([1, 2, 3], "0")

    packed = list(packer(iter([lone_doc])))
    assert len(packed) == 1

    only = packed[0]
    assert only.input_ids == [1, 2, 3, 99, 0, 0, 0, 0, 0, 0]
    assert only.seq_lens == [4]
    assert only.total_tokens == 4
def test_pack_exact_fit():
    """A document whose length plus EOS equals max_seq_len needs no padding."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)
    # Four tokens plus one EOS is exactly the limit of five.
    snug_doc = make_tok_doc([1, 2, 3, 4], "0")

    packed = list(packer(iter([snug_doc])))
    assert len(packed) == 1

    only = packed[0]
    assert only.input_ids == [1, 2, 3, 4, 99]
    assert only.total_tokens == 5
def test_pack_determinism():
    """Packing equal inputs twice with one block must give equal results."""
    packer = PackBlock(max_seq_len=15, eos_id=99, pad_id=0)

    def run_once():
        # Build fresh documents on every run so the two passes share no
        # input objects.
        fresh_docs = [
            make_tok_doc([1, 2, 3], "0"),
            make_tok_doc([4, 5], "1"),
            make_tok_doc([6, 7, 8, 9], "2"),
        ]
        return [
            (seq.input_ids, seq.seq_lens) for seq in packer(iter(fresh_docs))
        ]

    first_pass = run_once()
    second_pass = run_once()
    assert first_pass == second_pass
def test_pack_signature():
    """The block signature must mention its name and key parameters."""
    packer = PackBlock(max_seq_len=2048, eos_id=151645, pad_id=0)
    signature = packer.signature
    # Block name, max_seq_len and eos_id must each appear in the signature.
    for fragment in ("pack", "2048", "151645"):
        assert fragment in signature
if __name__ == "__main__":
    # Allow running this file directly, without pytest: execute every test
    # in declaration order and report success only if none of them raised.
    all_tests = (
        test_pack_basic,
        test_pack_overflow,
        test_pack_long_document,
        test_pack_empty_input,
        test_pack_single_doc,
        test_pack_exact_fit,
        test_pack_determinism,
        test_pack_signature,
    )
    for run_test in all_tests:
        run_test()
    print("All pack tests passed!")