"""
Conditional dataset loader for CAMELS LH 6-parameter layout.
Same layout convention as DDPM_HI_Emulation_improved/dataset_conditional.py when
is_6param is true (that repo enables 6-param mode when the string 'params_6'
appears in data_dir):
    data_dir/
        train_LH_6.npy, val_LH_6.npy, test_LH_6.npy
        train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy
Pass data_dir as the directory that directly contains these files (e.g. the
absolute path to params_6 under LH_data, analogous to params_2 for 2 labels).
Images are scaled to [-1, 1]; labels are z-scored using train-split statistics.
"""
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
# Default value of the ``data_dir`` argument of get_conditional_dataloaders().
# Mirrors shell training for 2-label data at .../LH_data/params_2; 6-param lives in params_6.
# NOTE(review): "<DDPM_ROOT>" looks like an unexpanded placeholder -- presumably it has to be
# replaced by the real project root (or data_dir passed explicitly) before use; confirm.
DEFAULT_DATA_DIR = "<DDPM_ROOT>/data/LH_data/params_6"
class ConditionalImageDataset(Dataset):
    """Paired image/label dataset backed by two ``.npy`` files.

    Both files are loaded fully with ``np.load`` at construction time and are
    indexed along their first axis; item ``i`` pairs ``data[i]`` with
    ``labels[i]``.

    Args:
        data_path: Path to the ``.npy`` file holding the images.
        label_path: Path to the ``.npy`` file holding the labels. Its first
            axis must have the same length as the image array's.
        transform: Optional callable applied to the finished image tensor
            (after rescaling and channel-axis insertion); its return value is
            what ``__getitem__`` yields as the image. ``None`` disables it.
        label_stats: Optional mapping with ``"mean"`` and ``"std"`` entries
            used to standardize each label as ``(label - mean) / std``.
            ``None`` leaves labels untouched.

    Raises:
        ValueError: If the image and label arrays differ in length.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats

        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let mismatched files through.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}"
            )

        print(
            f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | "
            f"Label shape: {self.labels.shape[1:]}"
        )

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` as float32 tensors."""
        # ``as_tensor`` accepts both array rows and NumPy scalars, so label
        # files with one scalar per sample (1-D arrays) work as well.
        img = torch.as_tensor(self.data[idx]).float()
        label = torch.as_tensor(self.labels[idx]).float()

        # Normalize image to [-1, 1]
        # NOTE(review): this mapping assumes stored pixel values lie in [0, 1] -- confirm.
        img = img * 2.0 - 1.0

        # Normalize labels
        if self.label_stats is not None:
            label = (label - self.label_stats["mean"]) / self.label_stats["std"]

        # Single-channel 2-D images get an explicit leading channel axis.
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Honor the user-supplied transform; it was previously stored but never used.
        if self.transform is not None:
            img = self.transform(img)

        return img, label
def get_conditional_dataloaders(
    data_dir=DEFAULT_DATA_DIR,
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True,
    label_dim=6,
):
    """Build train/val/test DataLoaders for the LH 6-parameter splits.

    Args:
        data_dir: Directory that directly contains the ``*_LH_6.npy`` image
            files and the ``*_labels_LH.npy`` label files.
        batch_size: Batch size shared by all three loaders.
        num_workers: Worker count shared by all three loaders.
        pin_memory: Forwarded to every DataLoader.
        normalize_labels: When true, per-column mean/std are computed from the
            train labels and handed to all three datasets; columns with zero
            std use 1.0 instead. When false, labels are left unnormalized and
            ``label_dim`` is not checked.
        label_dim: Expected size of the second axis of the train label array.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader
        shuffles and drops the last incomplete batch.

    Raises:
        ValueError: If ``normalize_labels`` is true and the train label array
            does not have ``label_dim`` columns.
    """
    # (image file, label file) per split, in train / val / test order.
    split_files = (
        ("train_LH_6.npy", "train_labels_LH.npy"),
        ("val_LH_6.npy", "val_labels_LH.npy"),
        ("test_LH_6.npy", "test_labels_LH.npy"),
    )
    split_paths = [
        (os.path.join(data_dir, image_file), os.path.join(data_dir, label_file))
        for image_file, label_file in split_files
    ]

    print(f"Loading 6-parameter LH dataset from {data_dir}")

    label_stats = None
    if normalize_labels:
        # Statistics come from the train split only and are reused for val/test.
        reference_labels = np.load(split_paths[0][1])
        if reference_labels.shape[1] != label_dim:
            raise ValueError(
                f"train_labels_LH.npy has {reference_labels.shape[1]} columns; "
                f"expected label_dim={label_dim}"
            )

        column_mean = reference_labels.mean(axis=0)
        column_std = reference_labels.std(axis=0)
        # Constant columns would divide by zero; leave them unscaled instead.
        column_std = np.where(column_std == 0, 1.0, column_std)

        label_stats = {
            "mean": torch.from_numpy(column_mean).float(),
            "std": torch.from_numpy(column_std).float(),
        }
        print(f"Label normalization -> mean={column_mean}, std={column_std}")

    train_set, val_set, test_set = [
        ConditionalImageDataset(image_path, label_path, label_stats=label_stats)
        for image_path, label_path in split_paths
    ]

    # Settings common to every split; only shuffling / drop_last differ.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **shared_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **shared_kwargs)

    return train_loader, val_loader, test_loader
|