File size: 14,698 Bytes
0fd26a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
"""
The core class for loading input/output image pairs using WebDataset.
"""

from torch.utils.data import IterableDataset
import numpy as np
import torch
import math
from PIL import Image
import einops

import webdataset as wds
import braceexpand

from data.transforms import center_crop_arr
from data.wds_utils import (
    log_and_continue,
    pytorch_worker_info,
    is_multi_node_environment,
    get_dataset_size,
    handle_reconstruction_task,
    extract_fields_to_tuple,
    identity_function,
    has_input_image,
    WeightedRoundRobinSampler,
    StrictProportionalBatchSampler,
)


class WebDatasetDataset(IterableDataset):
    """
    Load input/output image pairs using WebDataset.

    ``tar_pattern`` is a comma-separated list of brace-expandable shard
    patterns, e.g. ``"a/{000..009}.tar, b/{x,y}.tar"``. Commas inside braces
    belong to the brace expression and do not separate patterns. With more
    than one pattern (and ``force_simple_mode=False``), one pipeline is built
    per pattern. The pipelines are then merged with ``wds.RoundRobin``,
    ``WeightedRoundRobinSampler`` or ``StrictProportionalBatchSampler``.
    Otherwise all shards are merged into a single pipeline.

    Each sample leaving the per-sample stages is the tuple
    ``(pixel_values, output_tensor, input_tensor, sample_type)`` built by
    ``_unpack_input_tuple``. When ``batch_size`` is set, the stream is grouped
    with ``wds.batched`` (or by ``StrictProportionalBatchSampler``).

    Args:
        tar_pattern: comma-separated shard pattern(s).
        resolution: side length passed to ``center_crop_arr``.
        shuffle_buffer: buffer size of the ``wds.shuffle`` stage.
        resampled: draw shards with ``wds.ResampledShards`` instead of
            ``wds.SimpleShardList``.
        handler: error handler passed to the webdataset stages.
        estimated_samples_per_shard: per-shard sample estimate used for sizing.
        split_data_by_node_flag: add ``wds.split_by_node`` when
            ``is_multi_node_environment()`` is true. Also selects
            ``get_dataset_size`` for the total sample count.
        allow_shared_shards: every process sees every shard (no node split).
        vl_chat_processor: optional processor whose ``image_processor``
            produces the ``pixel_values`` of the input image.
        device: stored on ``self.device``. ``None`` selects the current CUDA
            device when available, else the CPU.
        num_workers: DataLoader worker count (``None`` -> 1). Values > 1 add
            ``wds.split_by_worker``. ``0`` (in-process loading) is sized like
            one worker.
        batch_size: when set, batch the stream and size the epoch in batches.
        sampling_weights: one positive weight per pattern (multi-pattern mode).
        force_simple_mode: merge all patterns into one shard list.
        enable_shuffle: add a ``wds.shuffle`` stage.
        partial: forwarded to ``wds.batched``.

    Raises:
        ValueError: if ``tar_pattern`` holds no pattern, if the number of
            weights differs from the number of patterns, if a weight is not
            positive, or if strict proportional batching is requested with
            ``resampled=False``.
    """
    def __init__(self, tar_pattern, resolution=256, shuffle_buffer=300, 
                 resampled=True, handler=log_and_continue, 
                 estimated_samples_per_shard=1000,
                 split_data_by_node_flag=True,
                 allow_shared_shards=False,
                 vl_chat_processor=None,
                 device=None,
                 num_workers=None,
                 batch_size=None,
                 sampling_weights=None,
                 force_simple_mode=False,
                 enable_shuffle=True,
                 partial=False):
        super().__init__()
        self.resolution = resolution
        self.handler = handler
        self.vl_chat_processor = vl_chat_processor
        self.num_workers = num_workers if num_workers is not None else 1
        self.batch_size = batch_size
        self.sampling_weights = sampling_weights
        self.enable_shuffle = enable_shuffle
        self.partial = partial
        self.resampled = resampled
        self.device = self._default_device() if device is None else self._as_device(device)

        patterns = self._split_patterns(tar_pattern)
        if not patterns:
            raise ValueError(f"tar_pattern contains no shard patterns: {tar_pattern!r}")
        self.use_proportional_sampling = len(patterns) > 1 and not force_simple_mode

        init = (self._init_proportional_sampling if self.use_proportional_sampling
                else self._init_simple_mode)
        init(patterns, tar_pattern, shuffle_buffer, resampled, handler,
             estimated_samples_per_shard, split_data_by_node_flag, allow_shared_shards)

    @staticmethod
    def _split_patterns(tar_pattern):
        """Split ``tar_pattern`` on commas that are outside ``{...}``.

        Commas inside braces are braceexpand alternatives (``{a,b}``) and are
        kept. A backslash makes the next character literal, matching
        braceexpand's escaping. Parts are stripped and empty ones dropped.
        """
        parts, current, depth, escaped = [], [], 0, False
        for ch in tar_pattern:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '{':
                depth += 1
            elif ch == '}' and depth > 0:
                depth -= 1
            elif ch == ',' and depth == 0:
                parts.append(''.join(current))
                current = []
                continue
            current.append(ch)
        parts.append(''.join(current))
        return [p.strip() for p in parts if p.strip()]

    @staticmethod
    def _default_device():
        """Return the current CUDA device when CUDA is available, else the CPU."""
        if torch.cuda.is_available():
            return torch.device(f'cuda:{torch.cuda.current_device()}')
        return torch.device('cpu')

    @staticmethod
    def _as_device(device):
        """Convert a device string to ``torch.device``; return anything else unchanged."""
        return torch.device(device) if isinstance(device, str) else device

    def _effective_num_workers(self):
        """Worker count used for sizing: ``num_workers=0`` still means one (in-process) reader."""
        return max(1, self.num_workers)

    def _init_proportional_sampling(self, patterns, tar_pattern, shuffle_buffer,
                                     resampled, handler, estimated_samples_per_shard,
                                     split_data_by_node_flag, allow_shared_shards):
        """Build one pipeline per pattern and merge them into ``self._pipeline``.

        The merge strategy depends on the configuration:
        - weights and ``batch_size``: ``StrictProportionalBatchSampler``
          (already batched);
        - weights only: ``WeightedRoundRobinSampler``;
        - no weights: ``wds.RoundRobin``.
        """
        weights_str = None
        if self.sampling_weights is not None:
            if len(self.sampling_weights) != len(patterns):
                raise ValueError(f"number of sampling weights ({len(self.sampling_weights)}) must be equal to the number of path patterns ({len(patterns)})")
            if any(w <= 0 for w in self.sampling_weights):
                raise ValueError("all sampling weights must be positive")
            # Fail before any per-pattern pipeline is built.
            if self.batch_size is not None and not self.resampled:
                raise ValueError("StrictProportionalBatchSampler only used in resampled=True")
            # Normalize for display so the percentages are correct even when
            # the weights do not sum to 1.
            total_weight = sum(self.sampling_weights)
            weights_str = " : ".join(f"{w / total_weight * 100:.1f}%" for w in self.sampling_weights)
            print(f"Detected {len(patterns)} path patterns, using weighted sampling with ratios ({weights_str})")
        else:
            equal_share = f"{100 / len(patterns):.4g}"
            print(f"Detected {len(patterns)} path patterns, using RoundRobin for proportional sampling "
                  f"({':'.join([equal_share] * len(patterns))})")

        # The node-split decision does not depend on the pattern, so compute it once.
        need_nodesplitter = bool(not allow_shared_shards
                                 and split_data_by_node_flag
                                 and is_multi_node_environment())

        pipelines = []
        total_shards = 0
        for i, pattern in enumerate(patterns, start=1):
            pattern_urls = list(braceexpand.braceexpand(pattern))
            total_shards += len(pattern_urls)
            pipelines.append(self._create_single_pattern_pipeline(
                pattern_urls, shuffle_buffer, resampled, handler,
                need_nodesplitter, allow_shared_shards
            ))
            print(f"Pattern {i}: {len(pattern_urls)} shards")

        self.num_shards = total_shards
        if split_data_by_node_flag:
            self.total_num_samples, _ = get_dataset_size(tar_pattern, estimated_samples_per_shard)
        else:
            self.total_num_samples = total_shards * estimated_samples_per_shard

        self.num_samples = self.num_shards * estimated_samples_per_shard
        with_epoch_size = self._compute_epoch_size()

        batch_already_done = False
        if self.sampling_weights is None:
            merged_source = wds.RoundRobin(*pipelines)
        elif self.batch_size is not None:
            merged_source = StrictProportionalBatchSampler(
                pipelines, self.sampling_weights, self.batch_size
            )
            batch_already_done = True
        else:
            merged_source = WeightedRoundRobinSampler(pipelines, self.sampling_weights)

        pipeline_stages = [merged_source]
        if self.batch_size is not None and not batch_already_done:
            pipeline_stages.append(wds.batched(self.batch_size, partial=self.partial))

        self._pipeline = wds.DataPipeline(*pipeline_stages).with_epoch(with_epoch_size)

        sampling_mode = "weighted sampling" if self.sampling_weights is not None else "RoundRobin"
        weights_info = f" ({weights_str})" if self.sampling_weights is not None else ""
        if self.batch_size is not None:
            print(f"{sampling_mode} mode{weights_info}: {len(patterns)} patterns, {total_shards} total shards, "
                  f"with_epoch({with_epoch_size} batches per worker), "
                  f"num_batches={self.num_batches}, num_samples={self.num_samples}")
        else:
            print(f"{sampling_mode} mode{weights_info}: {len(patterns)} patterns, {total_shards} total shards, "
                  f"with_epoch({with_epoch_size} samples per worker), num_samples={self.num_samples}")

    def _init_simple_mode(self, patterns, tar_pattern, shuffle_buffer,
                           resampled, handler, estimated_samples_per_shard,
                           split_data_by_node_flag, allow_shared_shards):
        """Merge every pattern's shards into one list and build one pipeline from it."""
        if len(patterns) > 1:
            print(f"Simple mode: detected {len(patterns)} path patterns, merging all paths")
            urls = []
            for i, pattern in enumerate(patterns, start=1):
                pattern_urls = list(braceexpand.braceexpand(pattern))
                urls.extend(pattern_urls)
                print(f"  Pattern {i}: {len(pattern_urls)} shards")
            print(f"  Total merged: {len(urls)} shards")
        else:
            # Expand the parsed pattern, not the raw string, so surrounding
            # whitespace or a trailing comma never ends up in a shard URL.
            urls = list(braceexpand.braceexpand(patterns[0]))

        need_nodesplitter = False
        if allow_shared_shards:
            print(f"Shared shards mode: all {len(urls)} shards accessible by all processes")
        elif split_data_by_node_flag and is_multi_node_environment():
            need_nodesplitter = True
            print(f"Multi-process mode: using {len(urls)} shards, will be split by nodesplitter")
        else:
            print(f"Single process mode: using all {len(urls)} shards")

        self.num_shards = len(urls)
        self.num_samples = self.num_shards * estimated_samples_per_shard
        if split_data_by_node_flag:
            self.total_num_samples, _ = get_dataset_size(tar_pattern, estimated_samples_per_shard)
        else:
            self.total_num_samples = self.num_samples

        if not resampled:
            _, world_size, _, _ = pytorch_worker_info()
            total_workers_needed = self._effective_num_workers() * max(world_size, 1)
            if self.num_shards < total_workers_needed:
                print(f"Warning: Only {self.num_shards} shards but need {total_workers_needed} workers "
                      f"(num_workers={self.num_workers}, world_size={world_size}). "
                      f"Some workers may not have data. Consider using resampled=True or increasing shard count.")

        with_epoch_size = self._compute_epoch_size()

        pipeline_stages = self._sample_stages(
            urls, shuffle_buffer, resampled, handler,
            need_nodesplitter, allow_shared_shards, verbose=True
        )
        if self.batch_size is not None:
            pipeline_stages.append(wds.batched(self.batch_size, partial=self.partial))

        self._pipeline = wds.DataPipeline(*pipeline_stages).with_epoch(with_epoch_size)

        if self.batch_size is not None:
            print(f"WebDataset initialized: {self.num_shards} shards for this node, "
                  f"with_epoch({with_epoch_size} batches per worker), "
                  f"num_batches={self.num_batches}, num_samples={self.num_samples} "
                  f"(num_workers={self.num_workers}, batch_size={self.batch_size})")
        else:
            print(f"WebDataset initialized: {self.num_shards} shards for this node, "
                  f"with_epoch({with_epoch_size} samples per worker), "
                  f"num_samples={self.num_samples} (num_workers={self.num_workers})")

    def _compute_epoch_size(self):
        """Return the per-worker ``with_epoch`` length and update the size attributes.

        Batched: the length is in batches, computed from ``total_num_samples``
        split over world size, batch size and workers. It also sets
        ``num_batches`` and ``num_samples``.
        Unbatched: the length is in samples, computed from ``num_samples``
        split over workers. It overwrites ``num_samples`` with that
        per-worker value.
        """
        workers = self._effective_num_workers()
        if self.batch_size is not None:
            _, world_size, _, _ = pytorch_worker_info()
            num_worker_batches = math.ceil(
                self.total_num_samples / (max(world_size, 1) * self.batch_size * workers)
            )
            self.num_batches = num_worker_batches * workers
            self.num_samples = self.num_batches * self.batch_size
            return num_worker_batches
        epoch_size = math.ceil(self.num_samples / workers) if workers > 1 else self.num_samples
        self.num_samples = epoch_size
        return epoch_size

    def _build_processing_stages(self, shuffle_buffer, handler):
        """Return the per-sample stages: optional shuffle, decode, filter, preprocess, unpack."""
        stages = []
        if self.enable_shuffle:
            stages.append(wds.shuffle(shuffle_buffer, handler=handler))

        stages.extend([
            wds.decode("pil", handler=handler),
            wds.map(handle_reconstruction_task, handler=handler),
            wds.select(has_input_image),
            wds.map(extract_fields_to_tuple),
            wds.map_tuple(
                self._preprocess_input,
                self._preprocess_output,
                identity_function
            ),
            wds.map(self._unpack_input_tuple),
        ])
        return stages

    def _sample_stages(self, urls, shuffle_buffer, resampled, handler,
                       need_nodesplitter, allow_shared_shards, verbose=False):
        """Return the stages that turn ``urls`` into preprocessed samples (not batched).

        The stages are: shard source, node split, worker split, tar reading,
        then the processing stages. With ``verbose`` the node/worker split
        decisions are printed.
        """
        stages = [wds.ResampledShards(urls) if resampled else wds.SimpleShardList(urls)]

        if need_nodesplitter and not allow_shared_shards:
            if hasattr(wds, "split_by_node"):
                stages.append(wds.split_by_node)
                if verbose:
                    print("Using wds.split_by_node for multi-process training")
            else:
                print("Warning: webdataset has no split_by_node; every node will read every shard")

        if self.num_workers > 1:
            stages.append(wds.split_by_worker)
            if verbose:
                print(f"Added wds.split_by_worker for {self.num_workers} workers")

        stages.append(wds.tarfile_to_samples(handler=handler))
        stages.extend(self._build_processing_stages(shuffle_buffer, handler))
        return stages

    def _create_single_pattern_pipeline(self, urls, shuffle_buffer, resampled, handler,
                                        need_nodesplitter, allow_shared_shards):
        """Return an unbatched, epoch-less ``wds.DataPipeline`` over ``urls``."""
        return wds.DataPipeline(*self._sample_stages(
            urls, shuffle_buffer, resampled, handler,
            need_nodesplitter, allow_shared_shards
        ))

    @staticmethod
    def _to_rgb(pil_image):
        """Return ``pil_image`` as an RGB PIL image, converting from an array if needed."""
        if not isinstance(pil_image, Image.Image):
            pil_image = Image.fromarray(pil_image)
        return pil_image.convert("RGB")

    def _to_normalized_tensor(self, pil_image):
        """Center-crop to ``self.resolution`` and return a float32 CHW tensor.

        The crop is scaled with ``x / 127.5 - 1``. This assumes
        ``center_crop_arr`` returns HWC values in [0, 255], so the output lies
        in [-1, 1] (to be confirmed against ``center_crop_arr``).
        """
        arr = center_crop_arr(pil_image, image_size=self.resolution)
        arr = (arr / 127.5 - 1.0).astype(np.float32)
        return torch.from_numpy(einops.rearrange(arr, 'h w c -> c h w'))

    def _preprocess_input(self, pil_image):
        """Return ``(pixel_values, input_tensor)`` for the input image.

        ``pixel_values`` comes from ``vl_chat_processor.image_processor`` when
        a processor is set. Otherwise it is the full, uncropped RGB image as a
        uint8 numpy array.
        """
        pil_image = self._to_rgb(pil_image)
        input_tensor = self._to_normalized_tensor(pil_image)

        if self.vl_chat_processor is None:
            return np.array(pil_image, dtype=np.uint8), input_tensor

        images_outputs = self.vl_chat_processor.image_processor(
            [pil_image],
            return_tensors="pt"
        )
        return images_outputs.pixel_values.squeeze(0), input_tensor

    def _preprocess_output(self, pil_image):
        """Return the center-cropped, normalized CHW tensor of the output image."""
        return self._to_normalized_tensor(self._to_rgb(pil_image))

    def _unpack_input_tuple(self, sample):
        """Flatten ``((pixel_values, input_tensor), output_tensor, sample_type)``."""
        input_tuple, output_tensor, sample_type = sample
        pixel_values, input_tensor = input_tuple
        return pixel_values, output_tensor, input_tensor, sample_type

    def __iter__(self):
        """Iterate the pipeline built in ``__init__``."""
        return iter(self._pipeline)

    def set_vl_chat_processor(self, vl_chat_processor):
        """Replace the processor used by ``_preprocess_input``."""
        self.vl_chat_processor = vl_chat_processor

    def set_device(self, device):
        """Set ``self.device``; strings are converted to ``torch.device``."""
        self.device = self._as_device(device)