File size: 3,658 Bytes
51a44ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Streaming data pipeline for The Well datasets.
Handles HF streaming and local loading with robust error recovery.
"""
import torch
from torch.utils.data import DataLoader
import logging

logger = logging.getLogger(__name__)


def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Build a ``(DataLoader, WellDataset)`` pair for one Well dataset split.

    Args:
        dataset_name: Name of the Well dataset.
        split: 'train', 'valid', or 'test'.
        batch_size: Batch size.
        n_steps_input: Number of input timesteps.
        n_steps_output: Number of output timesteps.
        num_workers: DataLoader workers (0 for streaming recommended).
        streaming: If True, stream from HuggingFace Hub.
        local_path: Path to local data (used if streaming=False).
        use_normalization: Whether to normalize data.

    Returns:
        (DataLoader, WellDataset)

    Raises:
        ValueError: if ``streaming`` is False and no ``local_path`` is given.
    """
    # Imported lazily so merely importing this module does not require the_well.
    from the_well.data import WellDataset

    if streaming:
        base_path = "hf://datasets/polymathic-ai/"
    else:
        base_path = local_path
    if base_path is None:
        raise ValueError("Must provide local_path when streaming=False")

    logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})")

    dataset = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )

    # Only the training split is shuffled; persistent workers make sense
    # only when there are workers to persist.
    is_train = split == "train"
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=num_workers > 0,
    )

    return loader, dataset


def to_channels_first(x):
    """Convert Well format [B, T, H, W, C] to PyTorch [B, T*C, H, W].

    Also accepts [B, H, W, C] and [H, W, C] layouts; any other rank is
    returned unchanged.
    """
    rank = x.dim()
    if rank == 5:
        # [B, T, H, W, C] -> [B, T, C, H, W] -> fold time into channels.
        b, t, h, w, c = x.shape
        return x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
    if rank == 4:
        # Batched, no time dimension: [B, H, W, C] -> [B, C, H, W].
        return x.permute(0, 3, 1, 2)
    if rank == 3:
        # Single sample: [H, W, C] -> [C, H, W].
        return x.permute(2, 0, 1)
    return x


def prepare_batch(batch, device="cuda"):
    """Convert a Well batch to model-ready tensors.

    Moves the input/output fields to ``device`` and reshapes them to
    channels-first float tensors.

    Returns:
        x_input: [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """
    moved = [
        batch[key].to(device, non_blocking=True)
        for key in ("input_fields", "output_fields")
    ]
    x_input, x_output = (to_channels_first(t).float() for t in moved)
    return x_input, x_output


def get_data_info(dataset):
    """Probe dataset for shapes and channel counts.

    Returns a mapping from each tensor-valued key in the first sample to
    its shape tuple; non-tensor entries are skipped.
    """
    first_sample = dataset[0]
    return {
        name: tuple(value.shape)
        for name, value in first_sample.items()
        if isinstance(value, torch.Tensor)
    }


def get_channel_info(dataset):
    """Get input/output channel counts for model construction.

    Probes the first sample; fields are expected in Well layout
    [T, H, W, C]. The flattened channel counts are ``T * C`` for each of
    the input and output fields.
    """
    first_sample = dataset[0]
    in_steps, height, width, in_chans = first_sample["input_fields"].shape
    out_shape = first_sample["output_fields"].shape  # [To, H, W, C]
    out_steps, out_chans = out_shape[0], out_shape[-1]

    return {
        "input_channels": in_steps * in_chans,
        "output_channels": out_steps * out_chans,
        "raw_channels": in_chans,
        "height": height,
        "width": width,
        "n_steps_input": in_steps,
        "n_steps_output": out_steps,
    }