File size: 4,643 Bytes
ed59144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Test packing block."""

import sys
from pathlib import Path

# Put the "src" directory that sits next to this file's parent directory at the
# front of sys.path. This must run before the `grandline` imports below.
# NOTE(review): presumably this lets the tests run against the in-repo sources
# without installing the package - confirm against the project layout.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from grandline.blocks.pack import PackBlock
from grandline.types import PackedSequence, TokenizedDocument


def make_tok_doc(token_ids: list[int], doc_id: str = "0") -> TokenizedDocument:
    """Build a TokenizedDocument fixture from a list of token ids.

    ``num_tokens`` is derived from ``len(token_ids)`` and every document is
    assigned to the fixed shard ``"shard_0"``.
    """
    fields = {
        "token_ids": token_ids,
        "num_tokens": len(token_ids),
        "doc_id": doc_id,
        "shard_id": "shard_0",
    }
    return TokenizedDocument(**fields)


def test_pack_basic():
    """Several short documents are expected to share one packed sequence."""
    packer = PackBlock(max_seq_len=20, eos_id=99, pad_id=0)

    # Expected footprint with one EOS per document: 4 + 5 + 3 = 12 of 20 slots.
    token_lists = ([1, 2, 3], [4, 5, 6, 7], [8, 9])
    documents = [make_tok_doc(ids, str(idx)) for idx, ids in enumerate(token_lists)]

    outputs = list(packer(iter(documents)))
    assert len(outputs) == 1

    packed = outputs[0]
    assert len(packed.input_ids) == 20
    assert packed.seq_lens == [4, 5, 3]
    assert packed.total_tokens == 12

    # Expected layout: [1,2,3,99, 4,5,6,7,99, 8,9,99, 0,0,0,0,0,0,0,0]
    expected_segments = (
        (slice(None, 4), [1, 2, 3, 99]),
        (slice(4, 9), [4, 5, 6, 7, 99]),
        (slice(9, 12), [8, 9, 99]),
    )
    for span, expected in expected_segments:
        assert packed.input_ids[span] == expected

    # Everything after the last EOS must be the pad id.
    padding = packed.input_ids[12:]
    assert all(tok == 0 for tok in padding)


def test_pack_overflow():
    """A document that no longer fits the current bin must start a new one."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)

    # With one EOS per document the first two docs take 5 + 5 = 10 slots and
    # fill a bin exactly; the third (2 + 1 = 3) is expected to spill over.
    token_lists = ([1, 2, 3, 4], [5, 6, 7, 8], [9, 10])
    documents = [make_tok_doc(ids, str(idx)) for idx, ids in enumerate(token_lists)]

    bins = list(packer(iter(documents)))
    assert len(bins) == 2

    # First bin is completely full, so it carries no padding.
    full_bin = bins[0]
    assert full_bin.input_ids == [1, 2, 3, 4, 99, 5, 6, 7, 8, 99]
    assert full_bin.seq_lens == [5, 5]

    # Second bin holds the 3 spilled tokens followed by 7 pad tokens.
    spill_bin = bins[1]
    assert spill_bin.input_ids == [9, 10, 99, 0, 0, 0, 0, 0, 0, 0]
    assert spill_bin.seq_lens == [3]


def test_pack_long_document():
    """A document longer than max_seq_len still yields max_seq_len-sized output."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)

    # 8 tokens plus one EOS is 9, which exceeds the max_seq_len of 5.
    oversized = make_tok_doc(list(range(1, 9)), "0")

    chunks = list(packer(iter([oversized])))

    # NOTE(review): 9 tokens over a limit of 5 suggests ceil(9/5) = 2 chunks,
    # but only "at least one" is asserted here and only the first chunk is
    # inspected - confirm the intended chunk count and tighten if appropriate.
    assert len(chunks) >= 1

    # The first chunk must be exactly max_seq_len long.
    first_chunk = chunks[0]
    assert len(first_chunk.input_ids) == 5


def test_pack_empty_input():
    """Packing an empty document stream must yield nothing."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)
    empty_stream = iter([])
    produced = list(packer(empty_stream))
    assert not produced


def test_pack_single_doc():
    """One short document becomes one padded sequence."""
    packer = PackBlock(max_seq_len=10, eos_id=99, pad_id=0)

    # 3 tokens plus one EOS: 4 real tokens, the remaining 6 slots are padding.
    only_doc = make_tok_doc([1, 2, 3], "0")

    outputs = list(packer(iter([only_doc])))
    assert len(outputs) == 1

    packed = outputs[0]
    assert packed.input_ids == [1, 2, 3, 99, 0, 0, 0, 0, 0, 0]
    assert packed.seq_lens == [4]
    assert packed.total_tokens == 4


def test_pack_exact_fit():
    """A document whose length plus EOS equals max_seq_len needs no padding."""
    packer = PackBlock(max_seq_len=5, eos_id=99, pad_id=0)

    # 4 tokens plus one EOS is exactly 5, the full sequence length.
    exact_doc = make_tok_doc([1, 2, 3, 4], "0")

    outputs = list(packer(iter([exact_doc])))
    assert len(outputs) == 1

    packed = outputs[0]
    assert packed.input_ids == [1, 2, 3, 4, 99]
    assert packed.total_tokens == 5


def test_pack_determinism():
    """Packing identical input twice with one block gives identical output."""
    packer = PackBlock(max_seq_len=15, eos_id=99, pad_id=0)

    def snapshot():
        # Build fresh documents on every run so no state is shared via inputs.
        token_lists = ([1, 2, 3], [4, 5], [6, 7, 8, 9])
        fresh_docs = [
            make_tok_doc(ids, str(idx)) for idx, ids in enumerate(token_lists)
        ]
        return [(seq.input_ids, seq.seq_lens) for seq in packer(iter(fresh_docs))]

    first_run = snapshot()
    second_run = snapshot()

    assert first_run == second_run


def test_pack_signature():
    """The block signature mentions the block name and its key parameters."""
    packer = PackBlock(max_seq_len=2048, eos_id=151645, pad_id=0)
    signature = packer.signature

    # Block name, max_seq_len and eos_id must each appear in the signature.
    for fragment in ("pack", "2048", "151645"):
        assert fragment in signature


if __name__ == "__main__":
    # Allow running this file directly: execute every test in definition
    # order; the first failing assertion aborts the run.
    for run_test in (
        test_pack_basic,
        test_pack_overflow,
        test_pack_long_document,
        test_pack_empty_input,
        test_pack_single_doc,
        test_pack_exact_fit,
        test_pack_determinism,
        test_pack_signature,
    ):
        run_test()
    print("All pack tests passed!")