File size: 3,658 Bytes
51a44ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""
Streaming data pipeline for The Well datasets.
Loads data either by streaming from the Hugging Face Hub or from a local path.
"""
import torch
from torch.utils.data import DataLoader
import logging
logger = logging.getLogger(__name__)
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
    drop_last=None,
):
    """Create a DataLoader for a Well dataset.

    Args:
        dataset_name: Name of the Well dataset.
        split: 'train', 'valid', or 'test'. Only the 'train' split is shuffled.
        batch_size: Batch size.
        n_steps_input: Number of input timesteps.
        n_steps_output: Number of output timesteps.
        num_workers: DataLoader workers (0 for streaming recommended).
        streaming: If True, stream from the HuggingFace Hub
            (``hf://datasets/polymathic-ai/``); otherwise read from
            ``local_path``.
        local_path: Path to local data (required if streaming=False).
        use_normalization: Whether to normalize data.
        drop_last: Whether to drop the final incomplete batch. ``None``
            (default) means drop it for the 'train' split only, so that
            evaluation splits see every sample.

    Returns:
        (DataLoader, WellDataset)

    Raises:
        ValueError: If ``streaming`` is False and ``local_path`` is None.
    """
    # Validate arguments before importing the_well so a configuration error
    # is reported as such, even where the package is not installed.
    base_path = "hf://datasets/polymathic-ai/" if streaming else local_path
    if base_path is None:
        raise ValueError("Must provide local_path when streaming=False")

    from the_well.data import WellDataset

    is_train = split == "train"
    if drop_last is None:
        # Dropping the ragged last batch keeps training batch shapes fixed, but
        # on eval splits it would silently skip samples (or produce no batches
        # at all when the split holds fewer than batch_size samples).
        drop_last = is_train

    logger.info(
        "Creating dataset: %s/%s (streaming=%s)", dataset_name, split, streaming
    )
    dataset = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; without CUDA
        # PyTorch just warns and ignores the flag.
        pin_memory=torch.cuda.is_available(),
        drop_last=drop_last,
        # Keeping workers alive across epochs is only meaningful when there
        # are worker processes at all.
        persistent_workers=num_workers > 0,
    )
    return loader, dataset
def to_channels_first(x):
    """Move the trailing channel axis of a Well tensor to PyTorch layout.

    Supported ranks:
        5 -> [B, T, H, W, C] becomes [B, T*C, H, W]; channels are grouped by
             timestep (all channels of step 0, then step 1, ...).
        4 -> [B, H, W, C] becomes [B, C, H, W].
        3 -> [H, W, C] becomes [C, H, W].
    Any other rank is returned unchanged.
    """
    rank = x.dim()
    if rank == 5:
        batch, steps, height, width, chans = x.shape
        # [B, T, H, W, C] -> [B, T, C, H, W], then fold T and C together.
        time_major = x.permute(0, 1, 4, 2, 3)
        return time_major.reshape(batch, steps * chans, height, width)

    # The lower ranks carry no time axis, so a pure axis permutation suffices.
    channel_perm = {4: (0, 3, 1, 2), 3: (2, 0, 1)}.get(rank)
    if channel_perm is None:
        return x
    return x.permute(*channel_perm)
def prepare_batch(batch, device="cuda"):
    """Turn a Well batch dict into float32, channels-first model tensors.

    Reads ``batch["input_fields"]`` and ``batch["output_fields"]``, moves each
    to ``device`` (non-blocking), reshapes it with ``to_channels_first`` and
    casts it to float32.

    Returns:
        x_input: [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """

    def _to_model_tensor(field_name):
        on_device = batch[field_name].to(device, non_blocking=True)
        return to_channels_first(on_device).float()

    x_input = _to_model_tensor("input_fields")
    x_output = _to_model_tensor("output_fields")
    return x_input, x_output
def get_data_info(dataset):
    """Return the shape of every tensor-valued entry of ``dataset[0]``.

    Non-tensor entries of the sample dict are skipped. The result maps each
    key to its shape as a plain tuple, in the sample's key order.
    """
    first_sample = dataset[0]
    return {
        name: tuple(value.shape)
        for name, value in first_sample.items()
        if isinstance(value, torch.Tensor)
    }
def get_channel_info(dataset):
    """Derive channel counts and spatial size from ``dataset[0]``.

    Expects ``input_fields`` and ``output_fields`` of rank 4, laid out as
    [T, H, W, C]; any other rank fails the tuple unpacking with ValueError.
    Input/output channel counts are timesteps * raw channels, matching the
    flattened layout produced by ``to_channels_first``. Height and width are
    taken from the input fields.
    """
    probe = dataset[0]
    steps_in, height, width, chans_in = probe["input_fields"].shape
    steps_out, _, _, chans_out = probe["output_fields"].shape
    return dict(
        input_channels=steps_in * chans_in,
        output_channels=steps_out * chans_out,
        raw_channels=chans_in,
        height=height,
        width=width,
        n_steps_input=steps_in,
        n_steps_output=steps_out,
    )
|