|
|
""" |
|
|
Streaming data pipeline for The Well datasets. |
|
|
Handles HF streaming and local loading with robust error recovery. |
|
|
""" |
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Create a DataLoader for a Well dataset.

    Args:
        dataset_name: Name of the Well dataset.
        split: 'train', 'valid', or 'test'.
        batch_size: Batch size.
        n_steps_input: Number of input timesteps.
        n_steps_output: Number of output timesteps.
        num_workers: DataLoader workers (0 for streaming recommended).
        streaming: If True, stream from HuggingFace Hub.
        local_path: Path to local data (used if streaming=False).
        use_normalization: Whether to normalize data.

    Returns:
        (DataLoader, WellDataset)

    Raises:
        ValueError: If streaming=False and no local_path is provided.
    """
    # Imported lazily so the module can be loaded without the_well installed.
    from the_well.data import WellDataset

    # Validate before building the path so the error is raised up front.
    if not streaming and local_path is None:
        raise ValueError("Must provide local_path when streaming=False")
    base_path = "hf://datasets/polymathic-ai/" if streaming else local_path

    logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})")

    dataset = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )

    is_train = split == "train"
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # Pinning host memory only helps CUDA transfers; hard-coding True
        # emits warnings / wastes work on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        # Dropping the ragged final batch is fine for training, but for
        # valid/test it would silently discard samples and bias metrics.
        drop_last=is_train,
        persistent_workers=num_workers > 0,
    )

    return loader, dataset
|
|
|
|
|
|
|
|
def to_channels_first(x):
    """Convert Well format [B, T, H, W, C] to PyTorch [B, T*C, H, W].

    4-D inputs are treated as [B, H, W, C] and 3-D as [H, W, C]; any other
    rank is returned unchanged.
    """
    ndim = x.dim()
    if ndim == 5:
        b, t, h, w, c = x.shape
        # Fold time and channel axes together after moving channels forward.
        return x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
    if ndim == 4:
        return x.permute(0, 3, 1, 2)
    if ndim == 3:
        return x.permute(2, 0, 1)
    return x
|
|
|
|
|
|
|
|
def prepare_batch(batch, device="cuda"):
    """Convert a Well batch to model-ready tensors.

    Moves both field tensors to *device*, reorders them channels-first,
    and casts to float32.

    Returns:
        x_input: [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """
    tensors = []
    for key in ("input_fields", "output_fields"):
        fields = batch[key].to(device, non_blocking=True)
        tensors.append(to_channels_first(fields).float())

    x_input, x_output = tensors
    return x_input, x_output
|
|
|
|
|
|
|
|
def get_data_info(dataset):
    """Probe dataset for shapes and channel counts.

    Reads the first sample and maps every tensor-valued key to its shape;
    non-tensor entries are skipped.
    """
    first = dataset[0]
    return {
        name: tuple(tensor.shape)
        for name, tensor in first.items()
        if isinstance(tensor, torch.Tensor)
    }
|
|
|
|
|
|
|
|
def get_channel_info(dataset):
    """Get input/output channel counts for model construction.

    Inspects one sample: fields are assumed [T, H, W, C], so the flattened
    channel count a channels-first model sees is T * C per side.
    """
    first = dataset[0]
    in_fields = first["input_fields"]
    out_fields = first["output_fields"]

    steps_in, height, width, chans_in = in_fields.shape
    steps_out = out_fields.shape[0]
    chans_out = out_fields.shape[-1]

    return {
        "input_channels": steps_in * chans_in,
        "output_channels": steps_out * chans_out,
        "raw_channels": chans_in,
        "height": height,
        "width": width,
        "n_steps_input": steps_in,
        "n_steps_output": steps_out,
    }
|
|
|