File size: 3,700 Bytes
6cfe55f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import importlib.util
import pathlib
import unittest
from dataclasses import dataclass

# Load the module under test straight from its file path (two directory
# levels above this file, then app/gpt/request_chunker.py) instead of going
# through a package import.
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT.joinpath("app", "gpt", "request_chunker.py")
spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH)
# Without both a spec and a loader the module cannot be executed below.
if not (spec is not None and spec.loader is not None):
    raise ImportError("request_chunker module spec not found")
request_chunker = importlib.util.module_from_spec(spec)
spec.loader.exec_module(request_chunker)
# Expose the class under a short name for the test cases.
RequestChunker = request_chunker.RequestChunker


@dataclass
class DummySeg:
    """Minimal segment stand-in carrying a start, an end and a text payload.

    NOTE(review): the real segment type consumed by RequestChunker is not
    visible in this file; confirm it exposes the same attributes.
    """

    start: float  # segment start (the tests pass 0, 1, 2)
    end: float  # segment end (the tests pass 1, 2, 3)
    text: str  # text payload; build_messages concatenates these


def build_messages(segments, image_urls, **_):
    """Build a single-user-message payload from segments and image URLs.

    The ``text`` attributes of ``segments`` are concatenated into one text
    part, followed by one ``image_url`` part per entry of ``image_urls`` (in
    order). Extra keyword arguments are accepted and ignored.

    Returns a one-element list holding the ``{"role": "user", ...}`` message.
    """
    text_part = {"type": "text", "text": "".join(seg.text for seg in segments)}
    image_parts = [
        {"type": "image_url", "image_url": {"url": link, "detail": "auto"}}
        for link in image_urls
    ]
    return [{"role": "user", "content": [text_part, *image_parts]}]


def size_estimator(messages):
    """Return a character-count size for the first message's content parts.

    Text parts contribute the length of their ``text``; every other part is
    treated as an image part and contributes the length of its URL.
    """
    return sum(
        len(part["text"]) if part["type"] == "text" else len(part["image_url"]["url"])
        for part in messages[0]["content"]
    )


class TestRequestChunker(unittest.TestCase):
    """Behavioural tests for RequestChunker using character-count size budgets."""

    def test_chunk_segments_preserves_order_and_content(self):
        # Three 4-character segments against max_bytes=8: the concatenated
        # chunk texts must equal the input text and no chunk may be empty.
        pieces = ("aaaa", "bbbb", "cccc")
        segments = [DummySeg(idx, idx + 1, piece) for idx, piece in enumerate(pieces)]
        chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, [])
        texts = []
        for chunk in chunks:
            texts.append("".join(seg.text for seg in chunk.segments))
        self.assertEqual("".join(texts), "aaaabbbbcccc")
        self.assertTrue(all(texts))

    def test_chunk_images_distributed_across_batches(self):
        # 6-character image URLs with max_bytes=10: every image must come
        # back exactly once, in the original order.
        images = [letter * 6 for letter in "ijk"]
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk([DummySeg(0, 1, "aa")], images)
        collected = []
        for chunk in chunks:
            collected.extend(chunk.image_urls)
        self.assertEqual(collected, images)

    def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self):
        # With at least three chunks, the second and third chunks must each
        # carry at least one image, and overall image order is preserved.
        segments = [DummySeg(idx, idx + 1, letter * 6) for idx, letter in enumerate("abc")]
        images = ["11111", "22222", "33333"]
        chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, images)

        self.assertGreaterEqual(len(chunks), 3)
        per_chunk_counts = [len(chunk.image_urls) for chunk in chunks]
        for position in (1, 2):
            self.assertGreater(per_chunk_counts[position], 0)
        collected = []
        for chunk in chunks:
            collected.extend(chunk.image_urls)
        self.assertEqual(collected, images)

    def test_split_oversized_segment(self):
        # A single 25-character segment exceeds max_bytes=10; the text spread
        # over the resulting chunks must reassemble to the original.
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk([DummySeg(0, 1, "x" * 25)], [])
        pieces = []
        for chunk in chunks:
            for seg in chunk.segments:
                pieces.append(seg.text)
        self.assertEqual("".join(pieces), "x" * 25)

    def test_group_texts_by_budget(self):
        # Three 5-character texts with max_bytes=10 are expected to group as
        # two texts followed by one.
        def build_text_messages(texts, *_args, **_kwargs):
            text_part = {"type": "text", "text": "".join(texts)}
            return [{"role": "user", "content": [text_part]}]

        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages)
        self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]])


# Allow running this file directly as a script: unittest.main() collects and
# runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()