| | """ |
| | WebDataset-based data loader for foveated VLM training. |
| | |
| | Reads tar shards produced by video2dataset / the CPU precompute pipeline. |
| | Each sample in a shard contains EITHER: |
| | A) Pre-extracted frames: |
| | - {key}.jpg or {key}_000.jpg, {key}_001.jpg, ... -- JPEG frames (224x224) |
| | - {key}.json -- metadata: {caption, token_ids, loss_mask, ...} |
| | B) Raw MP4 from video2dataset: |
| | - {key}.mp4 -- raw video file |
| | - {key}.txt -- caption text |
| | - {key}.json -- metadata: {videoid, duration, url, ...} |
| | |
| | On-the-fly tokenization: if token_ids/loss_mask are missing from JSON, |
| | the sample is tokenized at load time using the provided tokenizer. |
| | |
| | Returns dicts with: |
| | frames: [T, 3, 224, 224] float32, ImageNet-normalized for DINO |
| | input_ids: [S] long, token IDs |
| | loss_mask: [S] float32, 1.0 for answer tokens, 0.0 otherwise |
| | num_frames: int actual frame count before any padding |
| | """ |
| |
|
| | import io |
| | import json |
| | import os |
| | import re |
| | import subprocess |
| | import tempfile |
| | from typing import Optional |
| |
|
| | import torch |
| | import torchvision.transforms.functional as TF |
| | import webdataset as wds |
| |
|
| | |
# ImageNet channel statistics -- frames are normalized with these for the
# DINO visual encoder (see module docstring).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Matches multi-frame keys like "clip_000.jpg" -> (stem, index, ext).
# NOTE(review): neither regex below is referenced in this file's visible
# code (decode_sample matches frame keys inline) -- possibly used by other
# modules; verify before removing.
_FRAME_INDEX_RE = re.compile(r"^(.+)_(\d{3})\.(jpg|jpeg|png)$")

# Matches single-frame keys like "clip.jpg" -> (stem, ext).
_SINGLE_FRAME_RE = re.compile(r"^(.+)\.(jpg|jpeg|png)$")

# Pre-shaped [3, 1, 1] tensors for in-place normalization in the
# torchvision fast path of _load_image_tensor.
_NORM_MEAN = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
_NORM_STD = torch.tensor(IMAGENET_STD).view(3, 1, 1)
| |
|
| |
|
def _load_image_tensor(data: bytes) -> torch.Tensor:
    """Decode JPEG/PNG bytes to a [3, H, W] float32 tensor, ImageNet-normalized.

    No resizing is performed here: the precompute pipeline writes 224x224
    frames, and callers that torch.stack() the results rely on a uniform size.

    Raises whatever the underlying decoder raises on undecodable bytes;
    callers wrap this in try/except and skip the frame.
    """
    try:
        # Fast path: torchvision's libjpeg decoder. Forcing RGB output means
        # grayscale/CMYK JPEGs decode here too, instead of failing the
        # 3-channel normalization and falling through to the slow PIL path.
        from torchvision.io import decode_jpeg, ImageReadMode
        raw = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        tensor = decode_jpeg(raw, mode=ImageReadMode.RGB).float().div_(255.0)
        tensor.sub_(_NORM_MEAN).div_(_NORM_STD)
        return tensor
    except Exception:
        # Slow path: PIL handles PNG and anything libjpeg rejects.
        from PIL import Image
        img = Image.open(io.BytesIO(data)).convert("RGB")
        tensor = TF.to_tensor(img)
        tensor = TF.normalize(tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD)
        return tensor
| |
|
| |
|
| | def _decode_mp4_frames(mp4_bytes: bytes, max_frames: int = 64) -> list[torch.Tensor]: |
| | """Decode MP4 bytes to a list of [3, 224, 224] tensors at 1 FPS.""" |
| | try: |
| | import decord |
| | decord.bridge.set_bridge("torch") |
| | vr = decord.VideoReader(io.BytesIO(mp4_bytes), width=224, height=224) |
| | fps = vr.get_avg_fps() |
| | total = len(vr) |
| | |
| | step = max(1, int(fps)) |
| | indices = list(range(0, total, step))[:max_frames] |
| | if not indices: |
| | return [] |
| | batch = vr.get_batch(indices) |
| | frames = [] |
| | for i in range(batch.shape[0]): |
| | t = batch[i].permute(2, 0, 1).float() / 255.0 |
| | t = TF.normalize(t, mean=IMAGENET_MEAN, std=IMAGENET_STD) |
| | frames.append(t) |
| | return frames |
| | except ImportError: |
| | pass |
| |
|
| | |
| | with tempfile.NamedTemporaryFile(suffix=".mp4", dir="/workspace/tmp", delete=True) as f: |
| | f.write(mp4_bytes) |
| | f.flush() |
| | frames_dir = f.name + "_frames" |
| | os.makedirs(frames_dir, exist_ok=True) |
| | try: |
| | subprocess.run( |
| | ["ffmpeg", "-y", "-i", f.name, |
| | "-vf", "fps=1,scale=224:224:force_original_aspect_ratio=increase,crop=224:224", |
| | "-frames:v", str(max_frames), "-q:v", "2", |
| | os.path.join(frames_dir, "frame_%03d.jpg")], |
| | capture_output=True, timeout=30, |
| | ) |
| | from PIL import Image |
| | frame_files = sorted(os.listdir(frames_dir)) |
| | frames = [] |
| | for fname in frame_files[:max_frames]: |
| | fp = os.path.join(frames_dir, fname) |
| | img = Image.open(fp).convert("RGB") |
| | t = TF.to_tensor(img) |
| | t = TF.normalize(t, mean=IMAGENET_MEAN, std=IMAGENET_STD) |
| | frames.append(t) |
| | return frames |
| | except Exception: |
| | return [] |
| | finally: |
| | import shutil |
| | shutil.rmtree(frames_dir, ignore_errors=True) |
| |
|
| |
|
def decode_sample(sample: dict, max_frames: int = 64,
                  tokenizer=None, stage: int = 1,
                  replicate_image_frames: int = 1) -> Optional[dict]:
    """
    Decode a single webdataset sample dict into training tensors.

    The sample dict has keys like:
        "jpg" / "jpeg" / "png"      -- single frame bytes
        "000.jpg", "001.jpg", ...   -- numbered multi-frame bytes
        "mp4"                       -- raw video bytes (video2dataset layout)
        "json"                      -- metadata (bytes, str, or dict)
        "txt"                       -- optional caption text

    Parameters
    ----------
    sample : dict
        Raw webdataset sample.
    max_frames : int
        Frames beyond this count are truncated.
    tokenizer : optional
        Required when the JSON metadata lacks pre-computed
        token_ids/loss_mask; used for on-the-fly tokenization.
    stage : int
        Training stage; stage 1 uses caption-style tokenization.
    replicate_image_frames : int
        Replicate a lone still image to this many frames.

    Returns
    -------
    dict or None
        {"frames": [T, 3, 224, 224], "input_ids": [S] long,
         "loss_mask": [S] float32, "num_frames": int}, or None when the
        sample is malformed (caller should filter).
    """
    # ---- metadata ------------------------------------------------------
    meta_raw = sample.get("json")
    if meta_raw is None:
        return None

    if isinstance(meta_raw, bytes):
        try:
            meta = json.loads(meta_raw.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return None
    elif isinstance(meta_raw, str):
        try:
            meta = json.loads(meta_raw)
        except json.JSONDecodeError:
            return None
    elif isinstance(meta_raw, dict):
        meta = meta_raw
    else:
        return None

    token_ids = meta.get("token_ids")
    loss_mask = meta.get("loss_mask")

    # ---- on-the-fly tokenization (when shard lacks pre-tokenized text) --
    if token_ids is None or loss_mask is None:
        # FIX: bail out before touching any tokenize helper. The original
        # only checked `tokenizer is None` *after* tokenizing, so a None
        # tokenizer reached tokenize_stage1/tokenize_sft first; the net
        # result (return None) is unchanged.
        if tokenizer is None:
            return None

        from tokenization import (
            tokenize_stage1, tokenize_sft, SOURCE_PROMPTS, DEFAULT_VISUAL_PROMPT,
        )

        user_text = meta.get("user", "")
        assistant_text = meta.get("assistant", "")
        source = meta.get("source", "")

        if user_text or assistant_text:
            # SFT-style sample with explicit user/assistant turns.
            is_text_only = meta.get("frame_count", 0) == 0
            if stage == 1 and not is_text_only:
                # Stage-1 visual sample: caption-style tokenization with a
                # per-source prompt fallback.
                user_prompt = user_text if user_text else SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)
                tok = tokenize_stage1(assistant_text, tokenizer=tokenizer, user_prompt=user_prompt)
            elif stage == 1 and is_text_only:
                # Stage-1 text-only sample: train on the full sequence
                # (mask of all ones).
                tok = tokenize_sft(
                    user_text,
                    assistant_text,
                    stage=stage,
                    tokenizer=tokenizer,
                )
                tok["loss_mask"] = [1] * len(tok["token_ids"])
            else:
                # Later stages: standard SFT tokenization, answer-only loss.
                effective_user = user_text if user_text else SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)
                tok = tokenize_sft(
                    effective_user,
                    assistant_text,
                    stage=stage,
                    tokenizer=tokenizer,
                )
        else:
            # Caption-style sample: caption from JSON, else the .txt entry
            # written by video2dataset.
            caption = meta.get("caption", "")
            if not caption:
                txt_raw = sample.get("txt")
                if isinstance(txt_raw, bytes):
                    caption = txt_raw.decode("utf-8", errors="replace").strip()
                elif isinstance(txt_raw, str):
                    caption = txt_raw.strip()

            if not caption:
                return None

            user_prompt = SOURCE_PROMPTS.get(source, DEFAULT_VISUAL_PROMPT)
            if stage == 1:
                tok = tokenize_stage1(caption, tokenizer=tokenizer, user_prompt=user_prompt)
            else:
                tok = tokenize_sft(user_prompt, caption, stage=stage, tokenizer=tokenizer)

        token_ids = tok["token_ids"]
        loss_mask = tok["loss_mask"]

    # ---- frames --------------------------------------------------------
    frames: list[torch.Tensor] = []

    # Raw MP4 (video2dataset layout) takes priority over pre-extracted JPEGs.
    mp4_data = sample.get("mp4")
    if isinstance(mp4_data, bytes) and len(mp4_data) > 100:
        frames = _decode_mp4_frames(mp4_data, max_frames=max_frames)
    else:
        # Numbered frames "000.jpg", "001.jpg", ... decoded in index order.
        numbered_keys: list[tuple[int, str]] = []
        for key in sample:
            m = re.match(r"^(\d{3})\.(jpg|jpeg|png)$", key)
            if m:
                numbered_keys.append((int(m.group(1)), key))

        if numbered_keys:
            numbered_keys.sort(key=lambda x: x[0])
            for _, key in numbered_keys:
                raw = sample[key]
                if isinstance(raw, bytes):
                    try:
                        frames.append(_load_image_tensor(raw))
                    except Exception:
                        continue  # skip the undecodable frame, keep the rest
        else:
            # Single-image sample: first matching extension wins.
            for ext in ("jpg", "jpeg", "png"):
                if ext in sample and isinstance(sample[ext], bytes):
                    try:
                        frames.append(_load_image_tensor(sample[ext]))
                    except Exception:
                        pass
                    break

    if not frames:
        return None

    # Enforce the frame budget.
    if len(frames) > max_frames:
        frames = frames[:max_frames]

    # Optionally turn a lone still image into a short repeated "video".
    if replicate_image_frames > 1 and len(frames) == 1:
        frames = frames * replicate_image_frames

    num_frames = len(frames)
    frames_tensor = torch.stack(frames, dim=0)

    # ---- text tensors --------------------------------------------------
    input_ids = torch.tensor(token_ids, dtype=torch.long)
    loss_mask_t = torch.tensor(loss_mask, dtype=torch.float32)

    # Defensive: keep ids and mask the same length.
    min_len = min(len(input_ids), len(loss_mask_t))
    input_ids = input_ids[:min_len]
    loss_mask_t = loss_mask_t[:min_len]

    return {
        "frames": frames_tensor,
        "input_ids": input_ids,
        "loss_mask": loss_mask_t,
        "num_frames": num_frames,
    }
| |
|
| |
|
def decode_dpo_sample(sample: dict, max_frames: int = 64,
                      tokenizer=None, replicate_image_frames: int = 1) -> Optional[dict]:
    """
    Decode one DPO (preference-pair) webdataset sample into training tensors.

    The JSON metadata carries:
        user               -- user prompt (may be empty)
        chosen_assistant   -- preferred response
        rejected_assistant -- dispreferred response
        source             -- dataset source (e.g. "rlaif_v")

    Both responses share the same visual input. Returns None for malformed
    samples (caller should filter). On success the dict contains:
        frames              [T, 3, 224, 224]
        chosen_input_ids    [S_c] long
        chosen_loss_mask    [S_c] float32 (answer-only)
        rejected_input_ids  [S_r] long
        rejected_loss_mask  [S_r] float32 (answer-only)
        num_frames          int
    """
    raw_meta = sample.get("json")
    if raw_meta is None:
        return None

    # Metadata may arrive as bytes, str, or an already-parsed dict.
    if isinstance(raw_meta, dict):
        meta = raw_meta
    elif isinstance(raw_meta, (bytes, str)):
        try:
            text = raw_meta.decode("utf-8") if isinstance(raw_meta, bytes) else raw_meta
            meta = json.loads(text)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return None
    else:
        return None

    prompt_text = meta.get("user", "")
    chosen = meta.get("chosen_assistant", "")
    rejected = meta.get("rejected_assistant", "")

    # Both sides of the pair and a tokenizer are mandatory.
    if not (chosen and rejected) or tokenizer is None:
        return None

    from tokenization import tokenize_sft, SOURCE_PROMPTS, DEFAULT_VISUAL_PROMPT

    src = meta.get("source", "")
    user_prompt = prompt_text if prompt_text else SOURCE_PROMPTS.get(src, DEFAULT_VISUAL_PROMPT)

    # Tokenize both responses against the same prompt (stage-3 template).
    pair = {}
    for label, answer in (("chosen", chosen), ("rejected", rejected)):
        pair[label] = tokenize_sft(user_prompt, answer, stage=3, tokenizer=tokenizer)

    # ---- shared visual input -------------------------------------------
    frame_list: list[torch.Tensor] = []

    video_bytes = sample.get("mp4")
    if isinstance(video_bytes, bytes) and len(video_bytes) > 100:
        frame_list = _decode_mp4_frames(video_bytes, max_frames=max_frames)
    else:
        indexed = []
        for k in sample:
            hit = re.match(r"^(\d{3})\.(jpg|jpeg|png)$", k)
            if hit:
                indexed.append((int(hit.group(1)), k))

        if indexed:
            # Numbered multi-frame layout, decoded in index order.
            indexed.sort(key=lambda p: p[0])
            for _, k in indexed:
                blob = sample[k]
                if isinstance(blob, bytes):
                    try:
                        frame_list.append(_load_image_tensor(blob))
                    except Exception:
                        continue
        else:
            # Single still image: first matching extension wins.
            for ext in ("jpg", "jpeg", "png"):
                blob = sample.get(ext)
                if isinstance(blob, bytes):
                    try:
                        frame_list.append(_load_image_tensor(blob))
                    except Exception:
                        pass
                    break

    if not frame_list:
        return None

    frame_list = frame_list[:max_frames]

    if replicate_image_frames > 1 and len(frame_list) == 1:
        frame_list = frame_list * replicate_image_frames

    n_frames = len(frame_list)
    stacked = torch.stack(frame_list, dim=0)

    # ---- text tensors (ids and mask trimmed to a common length) --------
    out = {"frames": stacked, "num_frames": n_frames}
    for label, tok in pair.items():
        ids = torch.tensor(tok["token_ids"], dtype=torch.long)
        mask = torch.tensor(tok["loss_mask"], dtype=torch.float32)
        keep = min(len(ids), len(mask))
        out[f"{label}_input_ids"] = ids[:keep]
        out[f"{label}_loss_mask"] = mask[:keep]
    return out
| |
|
| |
|
| | def _sample_decoder(max_frames: int, tokenizer=None, stage: int = 1, |
| | replicate_image_frames: int = 1): |
| | """Return a map function for use in a webdataset pipeline.""" |
| | def _decode(sample): |
| | result = decode_sample(sample, max_frames=max_frames, |
| | tokenizer=tokenizer, stage=stage, |
| | replicate_image_frames=replicate_image_frames) |
| | if result is None: |
| | return None |
| | return result |
| | return _decode |
| |
|
| |
|
| | def _dpo_sample_decoder(max_frames: int, tokenizer=None, |
| | replicate_image_frames: int = 1): |
| | """Return a map function for DPO samples in a webdataset pipeline.""" |
| | def _decode(sample): |
| | result = decode_dpo_sample(sample, max_frames=max_frames, |
| | tokenizer=tokenizer, |
| | replicate_image_frames=replicate_image_frames) |
| | if result is None: |
| | return None |
| | return result |
| | return _decode |
| |
|
| |
|
| | def _is_valid(sample) -> bool: |
| | """Filter predicate: keep only successfully decoded samples.""" |
| | return sample is not None |
| |
|
| |
|
| | def _min_frames_filter(min_frames: int): |
| | """Filter predicate: keep only samples with >= min_frames frames.""" |
| | def _filter(sample): |
| | return sample is not None and sample["frames"].shape[0] >= min_frames |
| | return _filter |
| |
|
| |
|
| | def _length_sort_buffer(buffer_size: int = 1000): |
| | """ |
| | Sort samples by frame count within a rolling buffer. |
| | |
| | When the DataLoader forms batches from consecutive samples, this ensures |
| | samples with similar frame counts end up in the same batch — dramatically |
| | reducing padding waste. A buffer of 1000 samples (default) gives good |
| | grouping while maintaining enough randomization. |
| | """ |
| | def _sort(src): |
| | buf = [] |
| | for sample in src: |
| | buf.append(sample) |
| | if len(buf) >= buffer_size: |
| | buf.sort(key=lambda s: s["frames"].shape[0]) |
| | yield from buf |
| | buf = [] |
| | if buf: |
| | buf.sort(key=lambda s: s["frames"].shape[0]) |
| | yield from buf |
| | return _sort |
| |
|
| |
|
def create_webdataset(
    shard_pattern: "str | list[str]",
    tokenizer=None,
    stage: int = 1,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    batch_size: Optional[int] = None,
    shardshuffle: int = 1000,
    replicate_image_frames: int = 1,
) -> wds.WebDataset:
    """
    Create a webdataset pipeline that streams tar shards.

    Parameters
    ----------
    shard_pattern : str or list[str]
        Brace-expansion pattern for tar shards, e.g.
        "/workspace/webvid_frames/{00000..02999}.tar", a glob pattern
        (containing '*' or '?'), or a list of glob patterns.
    tokenizer : optional
        Tokenizer for on-the-fly tokenization of raw captions.
        If None, samples must have pre-tokenized token_ids in JSON.
    stage : int
        Training stage, forwarded to the sample decoder (stage 1 uses
        caption-style tokenization).
    max_frames : int
        Maximum number of frames per sample (extras truncated). Default 64,
        matching SmolVLM2's frame cap.
    min_frames : int
        If > 0, drop samples with fewer than this many frames.
    shuffle : bool
        Whether to shuffle shards and samples. Disable for deterministic
        evaluation.
    seed : int
        Random seed for reproducible shard + sample shuffling.
    epoch : int
        Epoch counter -- combined with seed for per-epoch shuffling so that
        each epoch sees a different order without losing reproducibility.
    num_workers : int
        Hint for shard splitting across DataLoader workers. webdataset
        handles the splitting internally via its nodesplitter.
    batch_size : int, optional
        If provided, the pipeline batches internally (rare -- usually the
        external DataLoader + collate_foveated handles batching).
    shardshuffle : int
        Buffer size for shard-level shuffle. Larger = better randomisation
        at the cost of memory. 1000 shards ~= 1M samples for our shard
        size of 1000 samples/shard.
    replicate_image_frames : int
        Replicate single-frame images to N frames.

    Returns
    -------
    wds.WebDataset
        An iterable dataset that yields dicts:
            frames: [T, 3, 224, 224]
            input_ids: [S]
            loss_mask: [S]
            num_frames: int
    """
    effective_seed = seed + epoch

    # Resolve shard URLs: a list of glob patterns, a single glob, or a
    # literal path / brace pattern handed straight to webdataset.
    import glob as globmod
    if isinstance(shard_pattern, list):
        urls = []
        for pat in shard_pattern:
            urls.extend(sorted(globmod.glob(pat)))
        if not urls:
            raise ValueError(f"No shards found for patterns: {shard_pattern}")
    elif '*' in shard_pattern or '?' in shard_pattern:
        urls = sorted(globmod.glob(shard_pattern))
        if not urls:
            raise ValueError(f"No shards found for pattern: {shard_pattern}")
    else:
        urls = shard_pattern

    dataset = wds.WebDataset(
        urls,
        nodesplitter=wds.split_by_worker,
        shardshuffle=shardshuffle if shuffle else False,
        seed=effective_seed if shuffle else None,
        empty_check=False,
        handler=wds.warn_and_continue,
    )

    if shuffle:
        # Sample-level shuffle within a rolling buffer.
        dataset = dataset.shuffle(size=5000, seed=effective_seed)

    # Decode raw samples, then drop the malformed ones (decoder returns None).
    dataset = dataset.map(_sample_decoder(max_frames, tokenizer=tokenizer, stage=stage,
                                          replicate_image_frames=replicate_image_frames))
    dataset = dataset.select(_is_valid)

    if min_frames > 0:
        dataset = dataset.select(_min_frames_filter(min_frames))

    if batch_size is not None:
        dataset = dataset.batched(batch_size)

    return dataset
| |
|
| |
|
def create_dpo_webdataset(
    shard_pattern: "str | list[str]",
    tokenizer=None,
    max_frames: int = 64,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    batch_size: Optional[int] = None,
    shardshuffle: int = 1000,
    replicate_image_frames: int = 1,
) -> wds.WebDataset:
    """
    Create a webdataset pipeline for DPO (preference) data.

    Each sample contains chosen and rejected responses for the same visual
    input. Returns dicts with:
        frames: [T, 3, 224, 224]
        chosen_input_ids: [S_c]
        chosen_loss_mask: [S_c]
        rejected_input_ids: [S_r]
        rejected_loss_mask: [S_r]
        num_frames: int

    Parameters
    ----------
    shard_pattern : str or list[str]
        Brace-expansion pattern for tar shards, a glob pattern, or a list
        of glob patterns.
    tokenizer : optional
        Tokenizer for on-the-fly tokenization. Required -- DPO samples are
        never pre-tokenized.
    max_frames : int
        Maximum number of frames per sample.
    shuffle : bool
        Whether to shuffle shards and samples.
    seed : int
        Random seed for shuffling.
    epoch : int
        Epoch counter for per-epoch shuffling.
    num_workers : int
        Hint for shard splitting.
    batch_size : int, optional
        If provided, batch internally (rare).
    shardshuffle : int
        Buffer size for shard-level shuffle.
    replicate_image_frames : int
        Replicate single-frame images to N frames.
    """
    effective_seed = seed + epoch

    # Resolve shard URLs: list of globs, single glob, or literal pattern.
    import glob as globmod
    if isinstance(shard_pattern, list):
        urls = []
        for pat in shard_pattern:
            urls.extend(sorted(globmod.glob(pat)))
        if not urls:
            raise ValueError(f"No shards found for patterns: {shard_pattern}")
    elif '*' in shard_pattern or '?' in shard_pattern:
        urls = sorted(globmod.glob(shard_pattern))
        if not urls:
            raise ValueError(f"No shards found for pattern: {shard_pattern}")
    else:
        urls = shard_pattern

    dataset = wds.WebDataset(
        urls,
        nodesplitter=wds.split_by_worker,
        shardshuffle=shardshuffle if shuffle else False,
        seed=effective_seed if shuffle else None,
        empty_check=False,
        handler=wds.warn_and_continue,
    )

    if shuffle:
        dataset = dataset.shuffle(size=5000, seed=effective_seed)

    # Decode preference pairs, then drop malformed samples.
    dataset = dataset.map(_dpo_sample_decoder(max_frames, tokenizer=tokenizer,
                                              replicate_image_frames=replicate_image_frames))
    dataset = dataset.select(_is_valid)

    if batch_size is not None:
        dataset = dataset.batched(batch_size)

    return dataset
| |
|
| |
|
def make_dynamic_dataloader(
    shard_pattern: str,
    max_total_frames: int = 512,
    max_batch_size: int = 64,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    pin_memory: bool = True,
    prefetch_factor: int = 4,
    tokenizer=None,
    stage: int = 1,
    replicate_image_frames: int = 1,
) -> torch.utils.data.DataLoader:
    """
    Build a DataLoader whose batch size adapts to total frame count.

    Batches of short videos hold more samples; batches of long videos hold
    fewer, with the total frames per batch capped at max_total_frames. This
    keeps GPU work roughly constant per step and avoids one T=64 sample
    forcing an entire batch to pad out to 64 frames.
    """
    from collate import token_budget_batcher

    pipeline = create_webdataset(
        shard_pattern=shard_pattern,
        tokenizer=tokenizer,
        stage=stage,
        max_frames=max_frames,
        min_frames=min_frames,
        shuffle=shuffle,
        seed=seed,
        epoch=epoch,
        num_workers=num_workers,
        replicate_image_frames=replicate_image_frames,
    )

    # Frame-budget batching (with length bucketing) happens inside the
    # pipeline, so the DataLoader below must not re-batch.
    pipeline = pipeline.compose(token_budget_batcher(
        max_total_frames, max_batch_size,
        length_bucket=True, bucket_buffer=max_batch_size * 4,
    ))

    use_workers = num_workers > 0
    return torch.utils.data.DataLoader(
        pipeline,
        batch_size=None,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=use_workers,
    )
| |
|
| |
|
def make_dataloader(
    shard_pattern: str,
    batch_size: int,
    max_frames: int = 64,
    min_frames: int = 0,
    shuffle: bool = True,
    seed: int = 42,
    epoch: int = 0,
    num_workers: int = 4,
    collate_fn=None,
    pin_memory: bool = True,
    prefetch_factor: int = 4,
    tokenizer=None,
    stage: int = 1,
    replicate_image_frames: int = 1,
) -> torch.utils.data.DataLoader:
    """
    Build a fixed-batch-size PyTorch DataLoader over the webdataset pipeline.

    When collate_fn is None, collate.collate_foveated is used.
    """
    if collate_fn is None:
        from collate import collate_foveated
        collate_fn = collate_foveated

    pipeline = create_webdataset(
        shard_pattern=shard_pattern,
        tokenizer=tokenizer,
        stage=stage,
        max_frames=max_frames,
        min_frames=min_frames,
        shuffle=shuffle,
        seed=seed,
        epoch=epoch,
        num_workers=num_workers,
        replicate_image_frames=replicate_image_frames,
    )

    use_workers = num_workers > 0
    return torch.utils.data.DataLoader(
        pipeline,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=use_workers,
    )
| |
|