sanps committed on
Commit
e582a2f
·
verified ·
1 Parent(s): f9a257d

Upload data.py

Browse files
Files changed (1) hide show
  1. data.py +792 -0
data.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebDataset-based data loader for foveated VLM training.
3
+
4
+ Reads tar shards produced by video2dataset / the CPU precompute pipeline.
5
+ Each sample in a shard contains EITHER:
6
+ A) Pre-extracted frames:
7
+ - {key}.jpg or {key}_000.jpg, {key}_001.jpg, ... -- JPEG frames (224x224)
8
+ - {key}.json -- metadata: {caption, token_ids, loss_mask, ...}
9
+ B) Raw MP4 from video2dataset:
10
+ - {key}.mp4 -- raw video file
11
+ - {key}.txt -- caption text
12
+ - {key}.json -- metadata: {videoid, duration, url, ...}
13
+
14
+ On-the-fly tokenization: if token_ids/loss_mask are missing from JSON,
15
+ the sample is tokenized at load time using the provided tokenizer.
16
+
17
+ Returns dicts with:
18
+ frames: [T, 3, 224, 224] float32, ImageNet-normalized for DINO
19
+ input_ids: [S] long, token IDs
20
+ loss_mask: [S] float32, 1.0 for answer tokens, 0.0 otherwise
21
+ num_frames: int actual frame count before any padding
22
+ """
23
+
24
+ import io
25
+ import json
26
+ import os
27
+ import re
28
+ import subprocess
29
+ import tempfile
30
+ from typing import Optional
31
+
32
+ import torch
33
+ import torchvision.transforms.functional as TF
34
+ import webdataset as wds
35
+
36
# ImageNet per-channel normalization statistics (RGB order) used for the
# DINOv2 frame inputs (same constants as src/data/llava_video_dataset.py).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Regex to detect multi-frame filenames like "sample_003.jpg":
# group 1 = sample key, group 2 = three-digit frame index, group 3 = extension.
_FRAME_INDEX_RE = re.compile(r"^(.+)_(\d{3})\.(jpg|jpeg|png)$")

# Regex to detect single-frame filenames like "sample.jpg":
# group 1 = sample key, group 2 = extension.
_SINGLE_FRAME_RE = re.compile(r"^(.+)\.(jpg|jpeg|png)$")


# The same statistics as [3, 1, 1] tensors so they broadcast over a
# [3, H, W] image in the torchvision fast path of _load_image_tensor.
_NORM_MEAN = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
_NORM_STD = torch.tensor(IMAGENET_STD).view(3, 1, 1)
49
+
50
+
51
def _load_image_tensor(data: bytes) -> torch.Tensor:
    """Decode encoded image bytes into an ImageNet-normalized float32 tensor.

    The result has shape [3, H, W]. No resizing is performed here, so the
    spatial size is whatever the encoded image has (the shards are expected
    to store 224x224 frames).

    The torchvision JPEG decoder is tried first. Any failure on that path
    (import error, non-JPEG payload, or a channel count that cannot be
    normalized in place against the 3-channel statistics) falls back to PIL,
    which converts the image to RGB before normalizing.
    """
    try:
        # Fast path: torchvision's decoder, no PIL/numpy round trip.
        from torchvision.io import decode_jpeg

        # Copy into a bytearray so the tensor is backed by a writable buffer.
        encoded = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        pixels = decode_jpeg(encoded).float()  # [C, H, W]
        pixels.div_(255.0)
        # In-place ops on purpose: they raise when C != 3, which routes the
        # image to the PIL fallback below instead of silently broadcasting.
        pixels.sub_(_NORM_MEAN)
        pixels.div_(_NORM_STD)
        return pixels
    except Exception:
        # Slow path: PIL handles PNG and the edge cases the fast path rejects.
        from PIL import Image

        rgb_image = Image.open(io.BytesIO(data)).convert("RGB")
        unit_range = TF.to_tensor(rgb_image)  # [3, H, W] float32 in [0, 1]
        return TF.normalize(unit_range, mean=IMAGENET_MEAN, std=IMAGENET_STD)
67
+
68
+
69
+ def _decode_mp4_frames(mp4_bytes: bytes, max_frames: int = 64) -> list[torch.Tensor]:
70
+ """Decode MP4 bytes to a list of [3, 224, 224] tensors at 1 FPS."""
71
+ try:
72
+ import decord
73
+ decord.bridge.set_bridge("torch")
74
+ vr = decord.VideoReader(io.BytesIO(mp4_bytes), width=224, height=224)
75
+ fps = vr.get_avg_fps()
76
+ total = len(vr)
77
+ # Sample at 1 FPS
78
+ step = max(1, int(fps))
79
+ indices = list(range(0, total, step))[:max_frames]
80
+ if not indices:
81
+ return []
82
+ batch = vr.get_batch(indices) # [T, H, W, C] uint8
83
+ frames = []
84
+ for i in range(batch.shape[0]):
85
+ t = batch[i].permute(2, 0, 1).float() / 255.0 # [3, 224, 224]
86
+ t = TF.normalize(t, mean=IMAGENET_MEAN, std=IMAGENET_STD)
87
+ frames.append(t)
88
+ return frames
89
+ except ImportError:
90
+ pass
91
+
92
+ # Fallback: ffmpeg subprocess
93
+ with tempfile.NamedTemporaryFile(suffix=".mp4", dir="/workspace/tmp", delete=True) as f:
94
+ f.write(mp4_bytes)
95
+ f.flush()
96
+ frames_dir = f.name + "_frames"
97
+ os.makedirs(frames_dir, exist_ok=True)
98
+ try:
99
+ subprocess.run(
100
+ ["ffmpeg", "-y", "-i", f.name,
101
+ "-vf", "fps=1,scale=224:224:force_original_aspect_ratio=increase,crop=224:224",
102
+ "-frames:v", str(max_frames), "-q:v", "2",
103
+ os.path.join(frames_dir, "frame_%03d.jpg")],
104
+ capture_output=True, timeout=30,
105
+ )
106
+ from PIL import Image
107
+ frame_files = sorted(os.listdir(frames_dir))
108
+ frames = []
109
+ for fname in frame_files[:max_frames]:
110
+ fp = os.path.join(frames_dir, fname)
111
+ img = Image.open(fp).convert("RGB")
112
+ t = TF.to_tensor(img)
113
+ t = TF.normalize(t, mean=IMAGENET_MEAN, std=IMAGENET_STD)
114
+ frames.append(t)
115
+ return frames
116
+ except Exception:
117
+ return []
118
+ finally:
119
+ import shutil
120
+ shutil.rmtree(frames_dir, ignore_errors=True)
121
+
122
+
123
def decode_sample(sample: dict, max_frames: int = 64,
                  tokenizer=None, stage: int = 1,
                  replicate_image_frames: int = 1) -> Optional[dict]:
    """
    Decode a single webdataset sample dict into training tensors.

    The sample dict has keys like:
        "jpg" or "jpeg" or "png"   -- single frame bytes
        "000.jpg", "001.jpg", ...  -- multi-frame bytes
        "mp4"                      -- raw video bytes (takes precedence)
        "txt"                      -- caption text (legacy format)
        "json"                     -- metadata JSON bytes, str or dict

    Text handling:
        * If the metadata carries both ``token_ids`` and ``loss_mask`` they are
          used as-is and no tokenizer is needed.
        * Otherwise the sample is tokenized on the fly, which requires
          ``tokenizer``; without one the sample is rejected.

    Parameters
    ----------
    sample : dict
        Raw webdataset sample (values are undecoded bytes).
    max_frames : int
        Frames beyond this count are dropped.
    tokenizer : optional
        Tokenizer forwarded to the ``tokenization`` helpers.
    stage : int
        Training stage; selects the tokenization recipe.
    replicate_image_frames : int
        When > 1, a single-frame sample is repeated this many times.

    Returns
    -------
    dict or None
        ``frames`` [T, 3, H, W] float32, ``input_ids`` [S] long,
        ``loss_mask`` [S] float32 and ``num_frames`` (int), or None if the
        sample is malformed (caller should filter).
    """
    # ------------------------------------------------------------------
    # 1. Parse metadata JSON
    # ------------------------------------------------------------------
    meta_raw = sample.get("json")
    if isinstance(meta_raw, (bytes, str)):
        try:
            text = meta_raw.decode("utf-8") if isinstance(meta_raw, bytes) else meta_raw
            meta = json.loads(text)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return None
    else:
        meta = meta_raw

    # Rejects a missing entry, unsupported payload types, and JSON documents
    # whose top level is not an object (e.g. a list), which would otherwise
    # crash on ``meta.get`` below.
    if not isinstance(meta, dict):
        return None

    token_ids = meta.get("token_ids")
    loss_mask = meta.get("loss_mask")

    # On-the-fly tokenization if pre-tokenized data is missing
    if token_ids is None or loss_mask is None:
        # Checked before importing/calling the tokenization helpers: without a
        # tokenizer the sample is rejected on every path anyway.
        if tokenizer is None:
            return None

        from tokenization import (
            tokenize_stage1, tokenize_sft, SOURCE_PROMPTS, DEFAULT_VISUAL_PROMPT,
        )

        # Unified format: user/assistant keys
        user_text = meta.get("user", "")
        assistant_text = meta.get("assistant", "")
        source = meta.get("source", "")

        if user_text or assistant_text:
            is_text_only = meta.get("frame_count", 0) == 0
            if stage == 1 and is_text_only:
                # Stage 1 text retention: keep proper chat format, all-text loss
                tok = tokenize_sft(
                    user_text,
                    assistant_text,
                    stage=stage,
                    tokenizer=tokenizer,
                )
                tok["loss_mask"] = [1] * len(tok["token_ids"])
            else:
                # Use the shard's user field if non-empty, else the per-source
                # default prompt.
                prompt = user_text or SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)
                if stage == 1:
                    # Stage 1 visual data: per-source conditioning prompt
                    tok = tokenize_stage1(assistant_text, tokenizer=tokenizer,
                                          user_prompt=prompt)
                else:
                    # Stage 2-3: answer-only loss on assistant portion
                    tok = tokenize_sft(
                        prompt,
                        assistant_text,
                        stage=stage,
                        tokenizer=tokenizer,
                    )
        else:
            # Legacy format: caption key or .txt file
            caption = meta.get("caption", "")
            if not caption:
                txt_raw = sample.get("txt")
                if isinstance(txt_raw, bytes):
                    caption = txt_raw.decode("utf-8", errors="replace").strip()
                elif isinstance(txt_raw, str):
                    caption = txt_raw.strip()

            if not caption:
                return None

            user_prompt = SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)
            if stage == 1:
                tok = tokenize_stage1(caption, tokenizer=tokenizer, user_prompt=user_prompt)
            else:
                tok = tokenize_sft(user_prompt, caption, stage=stage, tokenizer=tokenizer)

        token_ids = tok["token_ids"]
        loss_mask = tok["loss_mask"]

    # ------------------------------------------------------------------
    # 2. Collect frames (JPEG bytes or decode from MP4)
    # ------------------------------------------------------------------
    frames: list[torch.Tensor] = []

    # Try MP4 first (video2dataset raw output); tiny payloads are ignored.
    mp4_data = sample.get("mp4")
    if isinstance(mp4_data, bytes) and len(mp4_data) > 100:
        frames = _decode_mp4_frames(mp4_data, max_frames=max_frames)
    else:
        # Try numbered image frames (000.jpg, 001.jpg, ...)
        numbered_keys: list[tuple[int, str]] = []
        for key in sample:
            if not isinstance(key, str):
                continue
            m = re.match(r"^(\d{3})\.(jpg|jpeg|png)$", key)
            if m:
                numbered_keys.append((int(m.group(1)), key))

        if numbered_keys:
            numbered_keys.sort(key=lambda x: x[0])
            for _, key in numbered_keys:
                raw = sample[key]
                if isinstance(raw, bytes):
                    try:
                        frames.append(_load_image_tensor(raw))
                    except Exception:
                        # Skip an undecodable frame, keep the rest.
                        continue
        else:
            # Single frame: use the first of jpg / jpeg / png that is present.
            for ext in ("jpg", "jpeg", "png"):
                if ext in sample and isinstance(sample[ext], bytes):
                    try:
                        frames.append(_load_image_tensor(sample[ext]))
                    except Exception:
                        pass
                    break

    if not frames:
        return None

    # Truncate to max_frames
    if len(frames) > max_frames:
        frames = frames[:max_frames]

    # Replicate single-frame images to N frames (A8 ablation: static video)
    if replicate_image_frames > 1 and len(frames) == 1:
        frames = frames * replicate_image_frames

    num_frames = len(frames)
    frames_tensor = torch.stack(frames, dim=0)  # [T, 3, H, W]

    # ------------------------------------------------------------------
    # 3. Build text tensors
    # ------------------------------------------------------------------
    input_ids = torch.tensor(token_ids, dtype=torch.long)
    loss_mask_t = torch.tensor(loss_mask, dtype=torch.float32)

    # Ensure consistent lengths
    min_len = min(len(input_ids), len(loss_mask_t))
    input_ids = input_ids[:min_len]
    loss_mask_t = loss_mask_t[:min_len]

    return {
        "frames": frames_tensor,    # [T, 3, H, W]
        "input_ids": input_ids,     # [S]
        "loss_mask": loss_mask_t,   # [S]
        "num_frames": num_frames,   # int
    }
291
+
292
+
293
def decode_dpo_sample(sample: dict, max_frames: int = 64,
                      tokenizer=None, replicate_image_frames: int = 1) -> Optional[dict]:
    """
    Decode a single DPO webdataset sample into training tensors.

    DPO samples have JSON with keys:
        user: user prompt
        chosen_assistant: preferred response
        rejected_assistant: dispreferred response
        source: dataset source (e.g. "rlaif_v")
        frame_count: number of frames (1 for images)

    Parameters
    ----------
    sample : dict
        Raw webdataset sample (values are undecoded bytes).
    max_frames : int
        Frames beyond this count are dropped.
    tokenizer : optional
        Tokenizer forwarded to ``tokenize_sft``; required.
    replicate_image_frames : int
        When > 1, a single-frame sample is repeated this many times.

    Returns None if the sample is malformed (caller should filter): missing or
    unparsable metadata, metadata that is not a JSON object, an empty chosen
    or rejected response, no tokenizer, or no decodable frames.

    Returns dict with:
        frames: [T, 3, H, W] shared visual input
        chosen_input_ids: [S_c] tokenized user+chosen
        chosen_loss_mask: [S_c] answer-only mask for chosen
        rejected_input_ids: [S_r] tokenized user+rejected
        rejected_loss_mask: [S_r] answer-only mask for rejected
        num_frames: int actual frame count
    """
    # ------------------------------------------------------------------
    # 1. Parse metadata JSON
    # ------------------------------------------------------------------
    meta_raw = sample.get("json")
    if isinstance(meta_raw, (bytes, str)):
        try:
            text = meta_raw.decode("utf-8") if isinstance(meta_raw, bytes) else meta_raw
            meta = json.loads(text)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return None
    else:
        meta = meta_raw

    # Rejects a missing entry, unsupported payload types, and JSON documents
    # whose top level is not an object, which would otherwise crash on
    # ``meta.get`` below.
    if not isinstance(meta, dict):
        return None

    user_text = meta.get("user", "")
    chosen_text = meta.get("chosen_assistant", "")
    rejected_text = meta.get("rejected_assistant", "")

    # A preference pair needs both responses and a tokenizer.
    if not chosen_text or not rejected_text:
        return None
    if tokenizer is None:
        return None

    # ------------------------------------------------------------------
    # 2. Tokenize chosen and rejected with answer-only loss masks
    # ------------------------------------------------------------------
    from tokenization import tokenize_sft, SOURCE_PROMPTS, DEFAULT_VISUAL_PROMPT

    source = meta.get("source", "")
    # Use the shard's user field if non-empty, else the per-source default.
    effective_user = user_text or SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)

    chosen_tok = tokenize_sft(effective_user, chosen_text, stage=3, tokenizer=tokenizer)
    rejected_tok = tokenize_sft(effective_user, rejected_text, stage=3, tokenizer=tokenizer)

    # ------------------------------------------------------------------
    # 3. Collect frames (same logic as decode_sample)
    # ------------------------------------------------------------------
    frames: list[torch.Tensor] = []

    # MP4 takes precedence; tiny payloads are ignored.
    mp4_data = sample.get("mp4")
    if isinstance(mp4_data, bytes) and len(mp4_data) > 100:
        frames = _decode_mp4_frames(mp4_data, max_frames=max_frames)
    else:
        # Numbered image frames (000.jpg, 001.jpg, ...)
        numbered_keys: list[tuple[int, str]] = []
        for key in sample:
            if not isinstance(key, str):
                continue
            m = re.match(r"^(\d{3})\.(jpg|jpeg|png)$", key)
            if m:
                numbered_keys.append((int(m.group(1)), key))

        if numbered_keys:
            numbered_keys.sort(key=lambda x: x[0])
            for _, key in numbered_keys:
                raw = sample[key]
                if isinstance(raw, bytes):
                    try:
                        frames.append(_load_image_tensor(raw))
                    except Exception:
                        # Skip an undecodable frame, keep the rest.
                        continue
        else:
            # Single frame: use the first of jpg / jpeg / png that is present.
            for ext in ("jpg", "jpeg", "png"):
                if ext in sample and isinstance(sample[ext], bytes):
                    try:
                        frames.append(_load_image_tensor(sample[ext]))
                    except Exception:
                        pass
                    break

    if not frames:
        return None

    if len(frames) > max_frames:
        frames = frames[:max_frames]

    # Replicate single-frame images to N frames.
    if replicate_image_frames > 1 and len(frames) == 1:
        frames = frames * replicate_image_frames

    num_frames = len(frames)
    frames_tensor = torch.stack(frames, dim=0)  # [T, 3, H, W]

    # ------------------------------------------------------------------
    # 4. Build text tensors
    # ------------------------------------------------------------------
    chosen_ids = torch.tensor(chosen_tok["token_ids"], dtype=torch.long)
    chosen_mask = torch.tensor(chosen_tok["loss_mask"], dtype=torch.float32)
    rejected_ids = torch.tensor(rejected_tok["token_ids"], dtype=torch.long)
    rejected_mask = torch.tensor(rejected_tok["loss_mask"], dtype=torch.float32)

    # Ensure consistent lengths within each pair
    c_len = min(len(chosen_ids), len(chosen_mask))
    chosen_ids = chosen_ids[:c_len]
    chosen_mask = chosen_mask[:c_len]

    r_len = min(len(rejected_ids), len(rejected_mask))
    rejected_ids = rejected_ids[:r_len]
    rejected_mask = rejected_mask[:r_len]

    return {
        "frames": frames_tensor,              # [T, 3, H, W]
        "chosen_input_ids": chosen_ids,       # [S_c]
        "chosen_loss_mask": chosen_mask,      # [S_c]
        "rejected_input_ids": rejected_ids,   # [S_r]
        "rejected_loss_mask": rejected_mask,  # [S_r]
        "num_frames": num_frames,             # int
    }
427
+
428
+
429
+ def _sample_decoder(max_frames: int, tokenizer=None, stage: int = 1,
430
+ replicate_image_frames: int = 1):
431
+ """Return a map function for use in a webdataset pipeline."""
432
+ def _decode(sample):
433
+ result = decode_sample(sample, max_frames=max_frames,
434
+ tokenizer=tokenizer, stage=stage,
435
+ replicate_image_frames=replicate_image_frames)
436
+ if result is None:
437
+ return None
438
+ return result
439
+ return _decode
440
+
441
+
442
+ def _dpo_sample_decoder(max_frames: int, tokenizer=None,
443
+ replicate_image_frames: int = 1):
444
+ """Return a map function for DPO samples in a webdataset pipeline."""
445
+ def _decode(sample):
446
+ result = decode_dpo_sample(sample, max_frames=max_frames,
447
+ tokenizer=tokenizer,
448
+ replicate_image_frames=replicate_image_frames)
449
+ if result is None:
450
+ return None
451
+ return result
452
+ return _decode
453
+
454
+
455
+ def _is_valid(sample) -> bool:
456
+ """Filter predicate: keep only successfully decoded samples."""
457
+ return sample is not None
458
+
459
+
460
+ def _min_frames_filter(min_frames: int):
461
+ """Filter predicate: keep only samples with >= min_frames frames."""
462
+ def _filter(sample):
463
+ return sample is not None and sample["frames"].shape[0] >= min_frames
464
+ return _filter
465
+
466
+
467
+ def _length_sort_buffer(buffer_size: int = 1000):
468
+ """
469
+ Sort samples by frame count within a rolling buffer.
470
+
471
+ When the DataLoader forms batches from consecutive samples, this ensures
472
+ samples with similar frame counts end up in the same batch — dramatically
473
+ reducing padding waste. A buffer of 1000 samples (default) gives good
474
+ grouping while maintaining enough randomization.
475
+ """
476
+ def _sort(src):
477
+ buf = []
478
+ for sample in src:
479
+ buf.append(sample)
480
+ if len(buf) >= buffer_size:
481
+ buf.sort(key=lambda s: s["frames"].shape[0])
482
+ yield from buf
483
+ buf = []
484
+ if buf:
485
+ buf.sort(key=lambda s: s["frames"].shape[0])
486
+ yield from buf
487
+ return _sort
488
+
489
+
490
def create_webdataset(
    shard_pattern: str,
    tokenizer=None,
    stage: int = 1,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    batch_size: Optional[int] = None,
    shardshuffle: int = 1000,
    replicate_image_frames: int = 1,
) -> wds.WebDataset:
    """
    Create a webdataset pipeline that streams tar shards.

    Parameters
    ----------
    shard_pattern : str
        Brace-expansion pattern for tar shards, e.g.
        "/workspace/webvid_frames/{00000..02999}.tar". A shell glob
        (containing ``*`` or ``?``) or a list of shell globs is also
        accepted and is expanded here, since webdataset does not expand
        globs itself.
    tokenizer : optional
        Tokenizer for on-the-fly tokenization of raw captions.
        If None, samples must have pre-tokenized token_ids in JSON.
    stage : int
        Training stage, forwarded to the sample decoder.
    max_frames : int
        Maximum number of frames per sample (extras truncated). Default 64,
        matching SmolVLM2's frame cap.
    min_frames : int
        When > 0, decoded samples with fewer frames are dropped.
    shuffle : bool
        Whether to shuffle shards and samples. Disable for deterministic
        evaluation.
    seed : int
        Random seed for reproducible shard + sample shuffling.
    epoch : int
        Epoch counter, added to ``seed`` so that each epoch sees a different
        order without losing reproducibility.
    num_workers : int
        Accepted for interface symmetry with the dataloader helpers; it is
        not referenced in this function (worker splitting is configured via
        ``wds.split_by_worker``).
    batch_size : int, optional
        If provided, the pipeline batches internally (rare; usually the
        external DataLoader + collate_foveated handles batching).
    shardshuffle : int
        Buffer size for shard-level shuffle. Larger = better randomisation
        at the cost of memory.
    replicate_image_frames : int
        Forwarded to the sample decoder: single-frame samples are repeated
        this many times when > 1.

    Returns
    -------
    wds.WebDataset
        An iterable dataset that yields dicts:
            frames: [T, 3, 224, 224]
            input_ids: [S]
            loss_mask: [S]
            num_frames: int

    Raises
    ------
    ValueError
        If a glob or list of globs matches no files.
    """
    import glob

    shuffle_seed = seed + epoch

    # Resolve the shard specification. webdataset understands brace expansion
    # ({0000..0999}.tar) but not shell globs (*.tar), so globs are expanded
    # here and everything else is passed through untouched.
    if isinstance(shard_pattern, list):
        urls = [
            path
            for pattern in shard_pattern
            for path in sorted(glob.glob(pattern))
        ]
        if not urls:
            raise ValueError(f"No shards found for patterns: {shard_pattern}")
    elif any(wildcard in shard_pattern for wildcard in ("*", "?")):
        urls = sorted(glob.glob(shard_pattern))
        if not urls:
            raise ValueError(f"No shards found for pattern: {shard_pattern}")
    else:
        urls = shard_pattern

    shard_buffer = shardshuffle if shuffle else False
    shard_seed = shuffle_seed if shuffle else None
    pipeline = wds.WebDataset(
        urls,
        nodesplitter=wds.split_by_worker,
        shardshuffle=shard_buffer,
        seed=shard_seed,
        empty_check=False,  # avoid crash when workers get no valid samples
        handler=wds.warn_and_continue,  # skip corrupted shards instead of crashing
    )

    # Sample-level shuffle within a buffer, on top of the shard-level shuffle.
    if shuffle:
        pipeline = pipeline.shuffle(size=5000, seed=shuffle_seed)

    # wds.decode() is deliberately not used: multi-frame samples need custom
    # logic, so raw bytes go straight to the sample decoder.
    decoder = _sample_decoder(
        max_frames,
        tokenizer=tokenizer,
        stage=stage,
        replicate_image_frames=replicate_image_frames,
    )
    pipeline = pipeline.map(decoder).select(_is_valid)

    if min_frames > 0:
        pipeline = pipeline.select(_min_frames_filter(min_frames))

    # The length-sort buffer (_length_sort_buffer) is intentionally not
    # applied: grouping long videos into the same batch caused GPU OOM
    # cascades, RAM growth from worker backlog during OOM retries, and system
    # OOM crashes. Random batching with bucketed padding is safer and only
    # ~10-15% less efficient.

    if batch_size is not None:
        pipeline = pipeline.batched(batch_size)

    return pipeline
598
+
599
+
600
def create_dpo_webdataset(
    shard_pattern: str,
    tokenizer=None,
    max_frames: int = 64,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    batch_size: Optional[int] = None,
    shardshuffle: int = 1000,
    replicate_image_frames: int = 1,
) -> wds.WebDataset:
    """
    Create a webdataset pipeline for DPO (preference) data.

    Each sample contains chosen and rejected responses for the same visual input.
    Returns dicts with:
        frames: [T, 3, 224, 224]
        chosen_input_ids: [S_c]
        chosen_loss_mask: [S_c]
        rejected_input_ids: [S_r]
        rejected_loss_mask: [S_r]
        num_frames: int

    Parameters
    ----------
    shard_pattern : str
        Brace-expansion pattern for tar shards. A shell glob (containing
        ``*`` or ``?``) or a list of shell globs is also accepted and is
        expanded here.
    tokenizer : optional
        Tokenizer for on-the-fly tokenization.
    max_frames : int
        Maximum number of frames per sample.
    shuffle : bool
        Whether to shuffle shards and samples.
    seed : int
        Random seed for shuffling.
    epoch : int
        Epoch counter, added to ``seed`` for per-epoch shuffling.
    num_workers : int
        Accepted for interface symmetry; not referenced in this function.
    batch_size : int, optional
        If provided, batch internally (rare).
    shardshuffle : int
        Buffer size for shard-level shuffle.
    replicate_image_frames : int
        Replicate single-frame images to N frames.

    Raises
    ------
    ValueError
        If a glob or list of globs matches no files.
    """
    import glob

    shuffle_seed = seed + epoch

    # Shell globs are expanded here; any other string is handed to webdataset
    # as-is (it performs brace expansion itself).
    if isinstance(shard_pattern, list):
        urls = [
            path
            for pattern in shard_pattern
            for path in sorted(glob.glob(pattern))
        ]
        if not urls:
            raise ValueError(f"No shards found for patterns: {shard_pattern}")
    elif any(wildcard in shard_pattern for wildcard in ("*", "?")):
        urls = sorted(glob.glob(shard_pattern))
        if not urls:
            raise ValueError(f"No shards found for pattern: {shard_pattern}")
    else:
        urls = shard_pattern

    shard_buffer = shardshuffle if shuffle else False
    shard_seed = shuffle_seed if shuffle else None
    pipeline = wds.WebDataset(
        urls,
        nodesplitter=wds.split_by_worker,
        shardshuffle=shard_buffer,
        seed=shard_seed,
        empty_check=False,
        handler=wds.warn_and_continue,
    )

    # Sample-level shuffle within a buffer, on top of the shard-level shuffle.
    if shuffle:
        pipeline = pipeline.shuffle(size=5000, seed=shuffle_seed)

    decoder = _dpo_sample_decoder(
        max_frames,
        tokenizer=tokenizer,
        replicate_image_frames=replicate_image_frames,
    )
    # Malformed samples decode to None and are filtered out.
    pipeline = pipeline.map(decoder).select(_is_valid)

    if batch_size is not None:
        pipeline = pipeline.batched(batch_size)

    return pipeline
683
+
684
+
685
def make_dynamic_dataloader(
    shard_pattern: str,
    max_total_frames: int = 512,
    max_batch_size: int = 64,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    pin_memory: bool = True,
    prefetch_factor: int = 4,
    tokenizer=None,
    stage: int = 1,
    replicate_image_frames: int = 1,
) -> torch.utils.data.DataLoader:
    """
    Build a DataLoader whose batch size varies with the frame count.

    Samples from ``create_webdataset`` are grouped by
    ``collate.token_budget_batcher`` so that the total number of frames per
    batch is capped at ``max_total_frames`` (and the sample count at
    ``max_batch_size``): short-video batches get more samples, long-video
    batches get fewer. This keeps GPU work roughly constant across batches
    and avoids one T=64 sample forcing a whole batch to pad to 64 frames.

    The batcher collates internally, so the DataLoader is created with
    ``batch_size=None`` and yields ready-made batch dicts.
    """
    from collate import token_budget_batcher

    samples = create_webdataset(
        shard_pattern=shard_pattern,
        tokenizer=tokenizer,
        stage=stage,
        max_frames=max_frames,
        min_frames=min_frames,
        shuffle=shuffle,
        seed=seed,
        epoch=epoch,
        num_workers=num_workers,
        replicate_image_frames=replicate_image_frames,
    )

    # length_bucket=True sorts by total length within a buffer of
    # 4 * max_batch_size samples to reduce padding waste.
    batcher = token_budget_batcher(
        max_total_frames,
        max_batch_size,
        length_bucket=True,
        bucket_buffer=max_batch_size * 4,
    )
    batches = samples.compose(batcher)

    # prefetch_factor and persistent_workers only apply with worker processes.
    uses_workers = num_workers > 0
    return torch.utils.data.DataLoader(
        batches,
        batch_size=None,  # each dataset item is already a collated batch dict
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if uses_workers else None,
        persistent_workers=uses_workers,
    )
742
+
743
+
744
def make_dataloader(
    shard_pattern: str,
    batch_size: int,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    collate_fn=None,
    pin_memory: bool = True,
    prefetch_factor: int = 4,
    tokenizer=None,
    stage: int = 1,
    replicate_image_frames: int = 1,
) -> torch.utils.data.DataLoader:
    """
    Build a fixed-batch-size DataLoader over the webdataset pipeline.

    Creates the pipeline with ``create_webdataset`` and wraps it in a standard
    PyTorch DataLoader that batches ``batch_size`` samples with ``collate_fn``.

    If collate_fn is None, collate.collate_foveated is used.
    """
    if collate_fn is None:
        # Imported lazily so a caller-supplied collate function does not
        # require the collate module.
        from collate import collate_foveated
        collate_fn = collate_foveated

    samples = create_webdataset(
        shard_pattern=shard_pattern,
        tokenizer=tokenizer,
        stage=stage,
        max_frames=max_frames,
        min_frames=min_frames,
        shuffle=shuffle,
        seed=seed,
        epoch=epoch,
        num_workers=num_workers,
        replicate_image_frames=replicate_image_frames,
    )

    # prefetch_factor and persistent_workers only apply with worker processes.
    uses_workers = num_workers > 0
    return torch.utils.data.DataLoader(
        samples,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if uses_workers else None,
        persistent_workers=uses_workers,
    )