AlexWortega committed on
Commit
51a44ad
·
verified ·
1 Parent(s): 8292899

Upload data_pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data_pipeline.py +125 -0
data_pipeline.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streaming data pipeline for The Well datasets.
3
+ Handles HF streaming and local loading with robust error recovery.
4
+ """
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Create a DataLoader for a Well dataset.

    The dataset is built with ``flatten_tensors=True``. The loader shuffles
    only when ``split == "train"`` and always drops the last incomplete batch.

    Args:
        dataset_name: Name of the Well dataset.
        split: 'train', 'valid', or 'test'.
        batch_size: Batch size.
        n_steps_input: Number of input timesteps.
        n_steps_output: Number of output timesteps.
        num_workers: DataLoader workers (0 for streaming recommended).
        streaming: If True, stream from HuggingFace Hub.
        local_path: Path to local data (used if streaming=False).
        use_normalization: Whether to normalize data.

    Returns:
        (DataLoader, WellDataset)

    Raises:
        ValueError: If ``streaming`` is False and ``local_path`` is None.
    """
    # Validate the arguments first so a misconfigured call reports the real
    # problem before the third-party import below is even attempted.
    base_path = "hf://datasets/polymathic-ai/" if streaming else local_path
    if base_path is None:
        raise ValueError("Must provide local_path when streaming=False")

    # Function-scope import: the_well is only needed once a dataset is built.
    from the_well.data import WellDataset

    logger.info(
        "Creating dataset: %s/%s (streaming=%s)", dataset_name, split, streaming
    )

    dataset = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        # Pinned host memory only speeds up copies to a CUDA device; asking
        # for it on a CPU-only host just triggers a PyTorch warning.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
        # persistent_workers is only valid when worker processes exist.
        persistent_workers=num_workers > 0,
    )

    return loader, dataset
68
+
69
+
70
def to_channels_first(x):
    """Move the trailing channel axis of a Well tensor in front of H and W.

    Supported layouts:
        [B, T, H, W, C] -> [B, T*C, H, W]  (time and channels are merged)
        [B, H, W, C]    -> [B, C, H, W]    (no time dim)
        [H, W, C]       -> [C, H, W]       (single sample)

    A tensor of any other rank is returned unchanged.
    """
    rank = x.dim()

    if rank == 3:
        return x.permute(2, 0, 1)
    if rank == 4:
        return x.permute(0, 3, 1, 2)
    if rank != 5:
        return x

    batch, steps, height, width, channels = x.shape
    # Put channels right after time so the reshape folds them together with
    # time as the slower-varying index (merged index = t * C + c).
    time_then_channels = x.permute(0, 1, 4, 2, 3)
    return time_then_channels.reshape(batch, steps * channels, height, width)
80
+
81
+
82
def prepare_batch(batch, device="cuda"):
    """Move a Well batch onto ``device`` as channels-first float tensors.

    Reads ``batch["input_fields"]`` and ``batch["output_fields"]``, transfers
    both with ``non_blocking=True``, then converts each with
    ``to_channels_first`` and casts it to float32.

    Returns:
        x_input:  [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """
    # Transfer both tensors before converting either, input first.
    on_device = [
        batch[key].to(device, non_blocking=True)
        for key in ("input_fields", "output_fields")
    ]
    x_input, x_output = [
        to_channels_first(fields).float() for fields in on_device
    ]
    return x_input, x_output
96
+
97
+
98
def get_data_info(dataset):
    """Report the shape of every tensor field in the dataset's first sample.

    Indexes ``dataset[0]`` and returns a dict mapping each key whose value is
    a ``torch.Tensor`` to that tensor's shape as a plain tuple. Non-tensor
    entries are left out.
    """
    first_sample = dataset[0]
    return {
        name: tuple(value.shape)
        for name, value in first_sample.items()
        if isinstance(value, torch.Tensor)
    }
106
+
107
+
108
def get_channel_info(dataset):
    """Get input/output channel counts for model construction.

    Indexes ``dataset[0]`` and reads its ``"input_fields"`` and
    ``"output_fields"`` tensors, both of which must be 4-D: [T, H, W, C].

    Returns:
        dict with:
            input_channels:  Ti * C of the input tensor.
            output_channels: To * C of the output tensor.
            raw_channels:    C of the input tensor.
            height, width:   H and W of the input tensor.
            n_steps_input:   Ti.
            n_steps_output:  To.

    Raises:
        ValueError: If either tensor is not 4-D.
    """
    sample = dataset[0]
    inp = sample["input_fields"]  # [Ti, H, W, C]
    out = sample["output_fields"]  # [To, H, W, C]

    # Check the rank up front: plain tuple unpacking would also raise
    # ValueError, but without saying which field or shape was at fault.
    for key, shape in (("input_fields", inp.shape), ("output_fields", out.shape)):
        if len(shape) != 4:
            raise ValueError(
                f"{key} must have shape [T, H, W, C], got {tuple(shape)}"
            )

    ti, h, w, c_in = inp.shape
    to_, _, _, c_out = out.shape

    return {
        "input_channels": ti * c_in,
        "output_channels": to_ * c_out,
        "raw_channels": c_in,
        "height": h,
        "width": w,
        "n_steps_input": ti,
        "n_steps_output": to_,
    }