# the-well-diffusion / data_pipeline.py
# Uploaded by AlexWortega via huggingface_hub (commit 51a44ad, verified)
"""
Streaming data pipeline for The Well datasets.
Handles HF streaming and local loading with robust error recovery.
"""
import torch
from torch.utils.data import DataLoader
import logging
logger = logging.getLogger(__name__)
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Build a DataLoader over a Well dataset.

    Args:
        dataset_name: Well dataset identifier.
        split: One of 'train', 'valid', or 'test'.
        batch_size: Samples per batch.
        n_steps_input: Input timesteps per sample.
        n_steps_output: Output timesteps per sample.
        num_workers: DataLoader worker processes (0 recommended when
            streaming).
        streaming: If True, read from the HuggingFace Hub; otherwise read
            from ``local_path``.
        local_path: Local data root; required when ``streaming`` is False.
        use_normalization: Apply dataset normalization when True.

    Returns:
        A ``(DataLoader, WellDataset)`` pair.

    Raises:
        ValueError: If ``streaming`` is False and ``local_path`` is None.
    """
    from the_well.data import WellDataset

    if streaming:
        base_path = "hf://datasets/polymathic-ai/"
    else:
        base_path = local_path
    if base_path is None:
        raise ValueError("Must provide local_path when streaming=False")
    logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})")
    well_ds = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )
    # Shuffle only the training split; drop_last keeps batch shapes fixed.
    is_train = split == "train"
    loader = DataLoader(
        well_ds,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=num_workers > 0,
    )
    return loader, well_ds
def to_channels_first(x):
    """Rearrange a channels-last Well tensor into channels-first layout.

    Handles three ranks:
      * [B, T, H, W, C] -> [B, T*C, H, W] (time folded into channels)
      * [B, H, W, C]    -> [B, C, H, W]
      * [H, W, C]       -> [C, H, W]
    Any other rank is returned unchanged.
    """
    ndim = x.dim()
    if ndim == 5:
        # Move C ahead of the spatial axes, then merge T and C.
        b, t, h, w, c = x.shape
        return x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
    if ndim == 4:
        return x.permute(0, 3, 1, 2)
    if ndim == 3:
        return x.permute(2, 0, 1)
    return x
def prepare_batch(batch, device="cuda"):
    """Move a Well batch onto ``device`` and reshape it for the model.

    Args:
        batch: Mapping containing "input_fields" and "output_fields"
            tensors in channels-last Well layout.
        device: Target device for both tensors.

    Returns:
        Tuple ``(x_input, x_output)`` of float32 channels-first tensors:
        [B, Ti*C, H, W] condition frames and [B, To*C, H, W] targets.
    """
    inputs = batch["input_fields"].to(device, non_blocking=True)
    targets = batch["output_fields"].to(device, non_blocking=True)
    x_input = to_channels_first(inputs).float()
    x_output = to_channels_first(targets).float()
    return x_input, x_output
def get_data_info(dataset):
    """Report the shape of every tensor field in the first sample.

    Args:
        dataset: Indexable dataset whose samples are dict-like.

    Returns:
        Dict mapping each tensor-valued key of ``dataset[0]`` to its shape
        tuple; non-tensor entries are skipped.
    """
    first = dataset[0]
    return {
        key: tuple(value.shape)
        for key, value in first.items()
        if isinstance(value, torch.Tensor)
    }
def get_channel_info(dataset):
    """Derive channel and spatial dimensions for model construction.

    Args:
        dataset: Indexable dataset whose first sample provides
            "input_fields" of shape [Ti, H, W, C_in] and "output_fields"
            of shape [To, H, W, C_out].

    Returns:
        Dict with flattened input/output channel counts (time folded into
        channels), the raw per-frame channel count, spatial size, and the
        input/output step counts.
    """
    first = dataset[0]
    n_in_steps, height, width, c_in = first["input_fields"].shape
    n_out_steps, _, _, c_out = first["output_fields"].shape
    # NOTE(review): "raw_channels" reports the *input* per-frame channel
    # count; input and output counts could differ in principle — confirm.
    return {
        "input_channels": n_in_steps * c_in,
        "output_channels": n_out_steps * c_out,
        "raw_channels": c_in,
        "height": height,
        "width": width,
        "n_steps_input": n_in_steps,
        "n_steps_output": n_out_steps,
    }