Upload data_pipeline.py with huggingface_hub
Browse files- data_pipeline.py +125 -0
data_pipeline.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streaming data pipeline for The Well datasets.
|
| 3 |
+
Handles HF streaming and local loading with robust error recovery.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Build a (DataLoader, WellDataset) pair for one Well dataset split.

    Args:
        dataset_name: Name of the Well dataset on the polymathic-ai hub.
        split: One of 'train', 'valid', or 'test'.
        batch_size: Samples per batch.
        n_steps_input: Number of conditioning timesteps per sample.
        n_steps_output: Number of target timesteps per sample.
        num_workers: DataLoader worker processes (0 recommended when streaming).
        streaming: If True, read directly from the HuggingFace Hub.
        local_path: Root of a local copy of the data; required when
            streaming=False.
        use_normalization: Whether the dataset applies its normalization.

    Returns:
        Tuple of (torch DataLoader, underlying WellDataset).

    Raises:
        ValueError: If streaming is False and no local_path was given.
    """
    # Imported lazily so merely importing this module does not require the_well.
    from the_well.data import WellDataset

    if streaming:
        base_path = "hf://datasets/polymathic-ai/"
    elif local_path is not None:
        base_path = local_path
    else:
        raise ValueError("Must provide local_path when streaming=False")

    logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})")

    well_ds = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )

    # Shuffle only the training split; keep batches full so shapes stay fixed.
    batch_loader = DataLoader(
        well_ds,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=num_workers > 0,
    )

    return batch_loader, well_ds
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def to_channels_first(x):
    """Convert Well format [B, T, H, W, C] to PyTorch [B, T*C, H, W].

    Also handles the un-batched / un-timed layouts: a 4-D [B, H, W, C]
    tensor becomes [B, C, H, W] and a 3-D [H, W, C] sample becomes
    [C, H, W]. Tensors of any other rank are returned unchanged.
    """
    rank = x.dim()
    if rank == 5:  # [B, T, H, W, C] -> fold time into channels
        b, t, h, w, c = x.shape
        return x.movedim(-1, 2).reshape(b, t * c, h, w)
    if rank == 4:  # [B, H, W, C] -> [B, C, H, W]
        return x.movedim(-1, 1)
    if rank == 3:  # [H, W, C] -> [C, H, W]
        return x.movedim(-1, 0)
    return x  # anything else passes through untouched
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def prepare_batch(batch, device="cuda"):
    """Convert a Well batch to model-ready tensors.

    NOTE(review): the default device is "cuda"; on a CPU-only machine the
    caller must pass device="cpu" explicitly — confirm this matches usage.

    Args:
        batch: Mapping with "input_fields" and "output_fields" tensors in
            channels-last Well layout.
        device: Target device for both tensors.

    Returns:
        x_input: [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """

    def _to_model(key):
        # Move to the target device, reorder to channels-first, force float32.
        fields = batch[key].to(device, non_blocking=True)
        return to_channels_first(fields).float()

    return _to_model("input_fields"), _to_model("output_fields")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_data_info(dataset):
    """Probe dataset for shapes and channel counts.

    Reads the first sample and maps each tensor-valued key to its shape
    tuple; non-tensor entries in the sample are skipped.
    """
    first = dataset[0]
    return {
        name: tuple(value.shape)
        for name, value in first.items()
        if isinstance(value, torch.Tensor)
    }
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_channel_info(dataset):
    """Get input/output channel counts for model construction.

    Inspects the first sample, whose fields are expected in channels-last
    layout: input_fields [Ti, H, W, C] and output_fields [To, H, W, C].

    Returns:
        Dict with flattened time*channel counts ("input_channels",
        "output_channels"), the raw per-frame channel count, spatial size,
        and the input/output step counts.
    """
    first = dataset[0]
    steps_in, height, width, chans_in = first["input_fields"].shape
    steps_out, _, _, chans_out = first["output_fields"].shape

    return {
        # Time steps are folded into the channel axis by to_channels_first,
        # so models see steps * channels input/output planes.
        "input_channels": steps_in * chans_in,
        "output_channels": steps_out * chans_out,
        "raw_channels": chans_in,
        "height": height,
        "width": width,
        "n_steps_input": steps_in,
        "n_steps_output": steps_out,
    }
|