"""Tests for PackBlock: packing tokenized documents into fixed-length sequences."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from grandline.blocks.pack import PackBlock
from grandline.types import PackedSequence, TokenizedDocument
def make_tok_doc(token_ids: list[int], doc_id: str = "0") -> TokenizedDocument:
    """Build a ``TokenizedDocument`` test fixture from raw token ids.

    Args:
        token_ids: Token ids forming the document body. In these tests the
            packer itself is expected to append the EOS token.
        doc_id: Identifier stored on the document; defaults to ``"0"``.

    Returns:
        A ``TokenizedDocument`` whose ``num_tokens`` equals ``len(token_ids)``
        and whose ``shard_id`` is always ``"shard_0"``.
    """
    count = len(token_ids)
    return TokenizedDocument(
        doc_id=doc_id,
        shard_id="shard_0",
        token_ids=token_ids,
        num_tokens=count,
    )
def test_pack_basic():
    """Several small documents share a single padded sequence."""
    packer = PackBlock(max_seq_len=20, eos_id=99, pad_id=0)
    # Each doc costs len(tokens) + 1 for its EOS: 4 + 5 + 3 = 12 <= 20.
    inputs = [
        make_tok_doc([1, 2, 3], "0"),
        make_tok_doc([4, 5, 6, 7], "1"),
        make_tok_doc([8, 9], "2"),
    ]
    packed = list(packer(iter(inputs)))
    assert len(packed) == 1
    only = packed[0]
    ids = only.input_ids
    assert len(ids) == 20
    assert only.seq_lens == [4, 5, 3]
    assert only.total_tokens == 12
    # Expected layout: [1,2,3,99, 4,5,6,7,99, 8,9,99] followed by 8 pad tokens.
    assert ids[:4] == [1, 2, 3, 99]
    assert ids[4:9] == [4, 5, 6, 7, 99]
    assert ids[9:12] == [8, 9, 99]
    assert all(tok == 0 for tok in ids[12:])
def test_pack_overflow():
    """A document that does not fit in the open bin starts a new bin."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    inputs = [
        make_tok_doc([1, 2, 3, 4], "0"),  # 5 with EOS
        make_tok_doc([5, 6, 7, 8], "1"),  # 5 with EOS -> bin is exactly full (10)
        make_tok_doc([9, 10], "2"),  # 3 with EOS -> no room left, opens bin 2
    ]
    packed = list(packer(iter(inputs)))
    assert len(packed) == 2
    first, second = packed
    # Bin 1 is filled exactly, so it carries no padding.
    assert first.input_ids == [1, 2, 3, 4, 99, 5, 6, 7, 8, 99]
    assert first.seq_lens == [5, 5]
    # Bin 2: three real tokens followed by seven pad tokens.
    assert second.input_ids == [9, 10, 99] + [0] * 7
    assert second.seq_lens == [3]
def test_pack_long_document():
    """Documents longer than max_seq_len are split into chunks, not truncated."""
    max_seq_len = 5
    pack = PackBlock(max_seq_len=max_seq_len, eos_id=99, pad_id=0)
    # 8 tokens + 1 EOS = 9 tokens, which cannot fit in one 5-token sequence.
    doc_tokens = [1, 2, 3, 4, 5, 6, 7, 8]
    docs = [make_tok_doc(doc_tokens, "0")]
    results = list(pack(iter(docs)))
    # Splitting (rather than truncating) needs at least ceil(9 / 5) = 2 chunks;
    # the last chunk is not necessarily full and may be padded.
    assert len(results) >= 2
    # Every emitted chunk has the fixed, padded sequence length.
    assert all(len(seq.input_ids) == max_seq_len for seq in results)
    # No document token may be dropped or reordered by the split. EOS (99) and
    # padding (0) are filtered out because where they land is up to the packer.
    kept = [t for seq in results for t in seq.input_ids if t not in (99, 0)]
    assert kept == doc_tokens
def test_pack_empty_input():
    """An exhausted input iterator yields no packed sequences at all."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    assert list(packer(iter([]))) == []
def test_pack_single_doc():
    """A lone document is emitted with its EOS and right-padding."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    # 3 body tokens + 1 EOS = 4 real tokens.
    packed = list(packer(iter([make_tok_doc([1, 2, 3], "0")])))
    assert len(packed) == 1
    seq = packed[0]
    expected = [1, 2, 3, 99] + [0] * 6
    assert seq.input_ids == expected
    assert seq.seq_lens == [4]
    assert seq.total_tokens == 4
def test_pack_exact_fit():
    """A document whose body plus EOS equals max_seq_len needs no padding."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)
    # 4 body tokens + 1 EOS == max_seq_len.
    packed = list(packer(iter([make_tok_doc([1, 2, 3, 4], "0")])))
    assert len(packed) == 1
    seq = packed[0]
    assert seq.input_ids == [1, 2, 3, 4, 99]
    assert seq.total_tokens == 5
def test_pack_determinism():
    """Same input always produces same output.

    Checks both that re-running one PackBlock instance is repeatable (no state
    carried between calls) and that an independently constructed instance
    with the same configuration produces identical sequences.
    """
    def make_pack():
        return PackBlock(max_seq_len=15, eos_id=99, pad_id=0)

    def make_input():
        # Build fresh documents for every run so runs do not share objects.
        return [
            make_tok_doc([1, 2, 3], "0"),
            make_tok_doc([4, 5], "1"),
            make_tok_doc([6, 7, 8, 9], "2"),
        ]

    def run(pack):
        return [
            (s.input_ids, s.seq_lens, s.total_tokens)
            for s in pack(iter(make_input()))
        ]

    pack = make_pack()
    result_1 = run(pack)
    result_2 = run(pack)
    result_fresh = run(make_pack())
    # Guard against a vacuous pass: two empty outputs would compare equal.
    assert result_1, "packer produced no sequences"
    assert result_1 == result_2
    assert result_1 == result_fresh
def test_pack_signature():
    """The signature string names the block and embeds its key settings."""
    sig = PackBlock(max_seq_len=2048, eos_id=151645, pad_id=0).signature
    for fragment in ("pack", "2048", "151645"):
        assert fragment in sig
if __name__ == "__main__":
    # Allow running the suite directly as a script, in definition order.
    for case in (
        test_pack_basic,
        test_pack_overflow,
        test_pack_long_document,
        test_pack_empty_input,
        test_pack_single_doc,
        test_pack_exact_fit,
        test_pack_determinism,
        test_pack_signature,
    ):
        case()
    print("All pack tests passed!")
|