""" Streaming data pipeline for The Well datasets. Handles HF streaming and local loading with robust error recovery. """ import torch from torch.utils.data import DataLoader import logging logger = logging.getLogger(__name__) def create_dataloader( dataset_name="turbulent_radiative_layer_2D", split="train", batch_size=4, n_steps_input=1, n_steps_output=1, num_workers=0, streaming=True, local_path=None, use_normalization=True, ): """Create a DataLoader for a Well dataset. Args: dataset_name: Name of the Well dataset. split: 'train', 'valid', or 'test'. batch_size: Batch size. n_steps_input: Number of input timesteps. n_steps_output: Number of output timesteps. num_workers: DataLoader workers (0 for streaming recommended). streaming: If True, stream from HuggingFace Hub. local_path: Path to local data (used if streaming=False). use_normalization: Whether to normalize data. Returns: (DataLoader, WellDataset) """ from the_well.data import WellDataset base_path = "hf://datasets/polymathic-ai/" if streaming else local_path if base_path is None: raise ValueError("Must provide local_path when streaming=False") logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})") dataset = WellDataset( well_base_path=base_path, well_dataset_name=dataset_name, well_split_name=split, n_steps_input=n_steps_input, n_steps_output=n_steps_output, use_normalization=use_normalization, flatten_tensors=True, ) loader = DataLoader( dataset, batch_size=batch_size, shuffle=(split == "train"), num_workers=num_workers, pin_memory=True, drop_last=True, persistent_workers=num_workers > 0, ) return loader, dataset def to_channels_first(x): """Convert Well format [B, T, H, W, C] to PyTorch [B, T*C, H, W].""" if x.dim() == 5: # [B, T, H, W, C] B, T, H, W, C = x.shape return x.permute(0, 1, 4, 2, 3).reshape(B, T * C, H, W) elif x.dim() == 4: # [B, H, W, C] (no time dim) return x.permute(0, 3, 1, 2) elif x.dim() == 3: # [H, W, C] single sample return x.permute(2, 0, 1) return x def prepare_batch(batch, device="cuda"): """Convert a Well batch to model-ready tensors. Returns: x_input: [B, Ti*C, H, W] condition frames (channels-first) x_output: [B, To*C, H, W] target frames (channels-first) """ input_fields = batch["input_fields"].to(device, non_blocking=True) output_fields = batch["output_fields"].to(device, non_blocking=True) x_input = to_channels_first(input_fields).float() x_output = to_channels_first(output_fields).float() return x_input, x_output def get_data_info(dataset): """Probe dataset for shapes and channel counts.""" sample = dataset[0] info = {} for key, val in sample.items(): if isinstance(val, torch.Tensor): info[key] = tuple(val.shape) return info def get_channel_info(dataset): """Get input/output channel counts for model construction.""" sample = dataset[0] inp = sample["input_fields"] # [Ti, H, W, C] out = sample["output_fields"] # [To, H, W, C] ti, h, w, c_in = inp.shape to_, _, _, c_out = out.shape return { "input_channels": ti * c_in, "output_channels": to_ * c_out, "raw_channels": c_in, "height": h, "width": w, "n_steps_input": ti, "n_steps_output": to_, }