File size: 9,513 Bytes
cf3d756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import torch
from torch.utils.data import IterableDataset, get_worker_info
import threading
from queue import Queue
from typing import Iterator
import itertools
import random

random.seed(42)  # Fixed seed so the knapsack-group shuffle in ConstantLengthDataset is reproducible across runs

class ConstantLengthDataset(IterableDataset):
    """Streams fixed-length packed training sequences from a multimodal dataset.

    A background producer thread buffers raw samples, groups them with a
    balanced greedy-knapsack heuristic (respecting both a per-sequence token
    budget and a per-sequence image budget), and concatenates each group into
    one packed training example. The main thread consumes the packed examples
    from a bounded queue.
    """

    def __init__(
        self,
        dataset,
        infinite: bool = False,
        max_sample_length: int = 1024,
        seq_length: int = 1024,
        num_of_sequences: int = 1024,
        queue_size: int = 2,
        max_images_per_example: int = 4,
        max_images_per_knapsack: int = 18,
    ):
        """
        Args:
            dataset: Base dataset. Each item is expected to be a dict with 1-D
                tensors ``input_ids``, ``labels``, ``attention_mask`` and a
                list ``images``; the dataset object must expose ``tokenizer``
                and ``mp_image_token_length`` (and ``iter_for_worker()`` /
                ``.dataset`` for the iterable-dataset path).
            infinite: If True, restart iteration when the base dataset is
                exhausted; ``self.epoch`` advances on each restart.
            max_sample_length: Samples with ``len(input_ids)`` at or above this
                are skipped.
            seq_length: Token budget of each packed sequence.
            num_of_sequences: How many sequences' worth of tokens to buffer
                before each knapsack/packing pass.
            queue_size: Capacity of the producer->consumer prefetch queue
                (clamped to at least 1).
            max_images_per_example: Samples with more images than this are skipped.
            max_images_per_knapsack: Image budget of each packed sequence.
        """
        self.dataset = dataset
        self.max_sample_length = max_sample_length
        self.seq_length = seq_length
        self.max_length = seq_length * num_of_sequences
        self.epoch = 0  # only advanced when infinite=True
        self.infinite = infinite
        self.queue_size = max(queue_size, 1)
        self.max_images_per_example = max_images_per_example
        self.max_images_per_knapsack = max_images_per_knapsack
        self._sentinel = object()  # unique end-of-stream marker for the queue
        self._average_length_per_sample = (
            self.dataset.mp_image_token_length + 198
        )  # 198 is the average tokens for the cauldron dataset

    def __len__(self):
        # Rough estimate: expected total token count divided by packed length.
        return int(
            len(self.dataset) * self._average_length_per_sample / self.seq_length
        )

    def __iter__(self) -> Iterator[dict]:
        """Return an iterator over fixed-length packed training sequences.

        Uses a producer-consumer pattern: a background daemon thread
        continuously reads from the base dataset, packs samples, and fills a
        bounded queue, while the main thread consumes from the queue. The
        dataset is automatically sharded across workers when num_workers > 1.

        Returns:
            Iterator[dict]: yields samples with the structure:
                - input_ids: Tensor of token ids of shape (seq_length,)
                - labels: Tensor of labels of shape (seq_length,)
                - attention_mask: Tensor of attention mask of shape (seq_length,)
                - images: List of processed image tensors
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        num_workers = worker_info.num_workers if worker_info else 1

        def make_base_iterator():
            """Return a (sharded) iterator over the underlying dataset."""
            if not hasattr(self.dataset.dataset, "__len__"):
                # Iterable datasets shard themselves: each worker gets different shards.
                return self.dataset.iter_for_worker()

            all_indices = range(len(self.dataset))

            # Shard the *indices* first, before any data is fetched.
            if num_workers > 1:
                worker_indices = itertools.islice(
                    all_indices, worker_id, None, num_workers
                )
            else:
                worker_indices = all_indices

            # Create an iterator that only calls __getitem__ for the assigned indices.
            def sharded_item_iterator():
                for idx in worker_indices:
                    yield self.dataset[idx]

            return sharded_item_iterator()

        queue: Queue = Queue(maxsize=self.queue_size)

        producer = threading.Thread(
            target=self._producer, args=(make_base_iterator, queue), daemon=True
        )
        producer.start()

        while True:
            batch_of_batches = queue.get()
            if batch_of_batches is self._sentinel:
                break
            for batch in batch_of_batches:
                yield batch

    def _producer(
        self,
        make_iterator,  # a zero-arg callable that returns a fresh (possibly sharded) iterator
        queue: Queue,
    ):
        """Run in a daemon thread and keep `queue` full of packed batches.

        Terminates by enqueuing ``self._sentinel`` so the consumer unblocks.
        """
        iterator = make_iterator()
        more_examples = True

        while more_examples:
            # ------------- 1) pull raw samples until we have enough -------- #
            buffer, buffer_len = [], 0
            while buffer_len < self.max_length:
                try:
                    sample = next(iterator)
                except StopIteration:
                    if self.infinite:
                        iterator = make_iterator()
                        self.epoch += 1
                        print(f"Epoch {self.epoch} finished, restarting iterator")
                        continue
                    else:
                        more_examples = False
                        break

                if sample is None:  # Ratings filtered out the sample
                    continue

                if len(sample["input_ids"]) >= self.max_sample_length:
                    continue  # skip overly long samples
                if len(sample["images"]) > self.max_images_per_example:
                    continue  # skip samples that exceed the image constraint

                # Append one pad token (masked out of attention and loss) so
                # consecutive samples stay separated inside a packed sequence.
                sample["input_ids"] = torch.cat(
                    [
                        sample["input_ids"],
                        torch.tensor([self.dataset.tokenizer.pad_token_id]),
                    ]
                )
                sample["attention_mask"] = torch.cat(
                    [sample["attention_mask"], torch.tensor([0])]
                )
                sample["labels"] = torch.cat([sample["labels"], torch.tensor([-100])])

                buffer.append(sample)
                buffer_len += len(sample["input_ids"])

            if not buffer:
                break  # nothing left and not infinite

            # ------------- 2) run greedy knapsack & pack groups ------------ #
            groups = self._balanced_greedy_knapsack(
                buffer,
                self.seq_length,
                delta=5,
                max_images_per_knapsack=self.max_images_per_knapsack,
            )

            packed_group = []
            for g in groups:
                packed = self._pack_one_group(g, buffer, self.seq_length)
                packed_group.append({
                    "input_ids":      packed[0],
                    "labels":         packed[1],
                    "attention_mask": packed[2],
                    "images":         packed[3],
                })

            if packed_group:
                queue.put(packed_group)

        # finished → unblock consumer
        queue.put(self._sentinel)

    def _balanced_greedy_knapsack(
        self, buffer, L, delta=0, max_images_per_knapsack=None
    ):
        """Group buffer indices into knapsacks of at most ``L`` tokens each.

        Items are placed longest-first into the least-loaded knapsack that
        satisfies both the token budget and (if given) the image budget; a new
        knapsack is opened when none fits.

        Args:
            buffer: List of sample dicts with ``input_ids`` and ``images``.
            L: Token capacity of each knapsack.
            delta: Extra knapsacks to pre-allocate beyond the theoretical
                minimum (heuristic slack; empty ones are dropped at the end).
            max_images_per_knapsack: Optional image-count capacity per knapsack.

        Returns:
            List of non-empty groups, each a list of indices into ``buffer``,
            shuffled so downstream batches are not length-ordered.
        """
        # Extract lengths and image counts from buffer
        lengths = [len(x["input_ids"]) for x in buffer]
        image_counts = [len(x["images"]) for x in buffer]

        # keep the position while sorting (longest items placed first)
        items = sorted(
            enumerate(zip(lengths, image_counts)), key=lambda x: x[1][0], reverse=True
        )

        min_knapsacks = (sum(lengths) + L - 1) // L + delta
        knapsack_load = [0] * min_knapsacks
        knapsack_image_counts = [0] * min_knapsacks
        knapsack_groups = [[] for _ in range(min_knapsacks)]

        for idx, (item_len, item_image_count) in items:
            # Find a suitable knapsack that satisfies both length and image count constraints
            suitable_knapsack = None

            # Probe knapsacks least-loaded first so load stays balanced.
            for ks_id in sorted(
                range(len(knapsack_load)), key=knapsack_load.__getitem__
            ):
                length_fits = knapsack_load[ks_id] + item_len <= L
                image_fits = (
                    max_images_per_knapsack is None
                    or knapsack_image_counts[ks_id] + item_image_count
                    <= max_images_per_knapsack
                )

                if length_fits and image_fits:
                    suitable_knapsack = ks_id
                    break

            # If no existing knapsack can fit, create a new one
            if suitable_knapsack is None:
                suitable_knapsack = len(knapsack_load)
                knapsack_load.append(0)
                knapsack_image_counts.append(0)
                knapsack_groups.append([])

            knapsack_groups[suitable_knapsack].append(idx)
            knapsack_load[suitable_knapsack] += item_len
            knapsack_image_counts[suitable_knapsack] += item_image_count

        # remove the completely empty bags that the +delta heuristic created
        random.shuffle(knapsack_groups)  # Knapsacks are semi-ordered after packing, thanks Luis for noticing!
        return [g for g in knapsack_groups if g]

    def _pack_one_group(self, group_indices, batch, max_len):
        """Concatenate the samples at ``group_indices`` into one packed example.

        Args:
            group_indices: Indices into ``batch`` to concatenate, in order.
            batch: List of sample dicts.
            max_len: Maximum allowed packed length.

        Returns:
            Tuple of (input_ids, labels, attention_mask, images) where the
            first three are stacked 1-D tensors and images is a flat list.

        Raises:
            ValueError: If the concatenated length exceeds ``max_len`` (the
                knapsack step should make this impossible).
        """
        ids, lbl, am, ims = [], [], [], []

        for i in group_indices:
            ids.extend(batch[i]["input_ids"])
            lbl.extend(batch[i]["labels"])
            am.extend(batch[i]["attention_mask"])
            ims.extend(batch[i]["images"])

        # safety: assert we never overflow
        if len(ids) > max_len:
            # Report the actual budget (the old message hardcoded a stale "21,302").
            raise ValueError(f"Packed length {len(ids)} > max_len {max_len}")

        return torch.stack(ids), torch.stack(lbl), torch.stack(am), ims