# DDPM-6param / src /dataset_conditional.py
# collins909's picture
# Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
# 1f3e7a2 verified
"""
Conditional dataset loader for CAMELS LH 6-parameter layout.
Same layout convention as DDPM_HI_Emulation_improved/dataset_conditional.py when
is_6param is true (that repo enables 6-param mode when the string 'params_6'
appears in data_dir):
data_dir/
train_LH_6.npy, val_LH_6.npy, test_LH_6.npy
train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy
Pass data_dir as the directory that directly contains these files (e.g. the
absolute path to params_6 under LH_data, analogous to params_2 for 2 labels).
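Images are mapped from [0, 1] to [-1, 1] via x * 2 - 1 (stored images are assumed to already lie in [0, 1]); when normalize_labels is true, labels are z-scored using train-split statistics.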
"""
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
# Default location of the 6-parameter LH splits. The 2-label variant used by the
# shell training scripts lives next to it in .../LH_data/params_2. "<DDPM_ROOT>" is
# a placeholder that must be replaced with the real checkout root (or pass data_dir).
DEFAULT_DATA_DIR = "<DDPM_ROOT>/data/LH_data/params_6"
class ConditionalImageDataset(Dataset):
    """Map-style dataset pairing images with conditioning labels stored as ``.npy`` files.

    Both arrays are loaded fully into memory with ``np.load``; the first axis of each
    indexes samples. Each item is ``(img, label)``:

    * ``img`` -- float32 tensor mapped to ``[-1, 1]`` via ``x * 2 - 1`` (the stored
      images are assumed to lie in ``[0, 1]`` -- TODO confirm against preprocessing).
      A 2-D image gets a leading channel axis, then ``transform`` (if any) is applied.
    * ``label`` -- float32 tensor, z-scored as ``(label - mean) / std`` when
      ``label_stats`` is given.

    Args:
        data_path: Path to the image ``.npy`` array.
        label_path: Path to the label ``.npy`` array.
        transform: Optional callable applied to the normalized image tensor.
        label_stats: Optional dict with ``"mean"`` and ``"std"`` tensors.

    Raises:
        ValueError: If the image and label arrays have different lengths.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}"
            )
        print(
            f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | "
            f"Label shape: {self.labels.shape[1:]}"
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # np.asarray turns numpy scalars (e.g. rows of a 1-D label array) into 0-d
        # arrays, which torch.from_numpy accepts; ndarray rows pass through unchanged.
        img = torch.from_numpy(np.asarray(self.data[idx])).float()
        label = torch.from_numpy(np.asarray(self.labels[idx])).float()

        # Normalize image from [0, 1] to [-1, 1].
        img = img * 2.0 - 1.0

        # Z-score labels with the caller-supplied (train-split) statistics.
        if self.label_stats is not None:
            label = (label - self.label_stats["mean"]) / self.label_stats["std"]

        # Add a channel axis so single-channel maps become (1, H, W).
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Previously accepted but silently ignored; now applied to the final image.
        if self.transform is not None:
            img = self.transform(img)

        return img, label
def _split_paths(data_dir):
    """Return ``{split: (image_path, label_path)}`` for the train/val/test LH files."""
    return {
        split: (
            os.path.join(data_dir, f"{split}_LH_6.npy"),
            os.path.join(data_dir, f"{split}_labels_LH.npy"),
        )
        for split in ("train", "val", "test")
    }


def _check_label_array(labels, file_name, label_dim):
    """Raise ValueError unless *labels* is 2-D with exactly *label_dim* columns."""
    if labels.ndim != 2:
        raise ValueError(
            f"{file_name} must be 2-D (samples, label_dim); got shape {labels.shape}"
        )
    if labels.shape[1] != label_dim:
        raise ValueError(
            f"{file_name} has {labels.shape[1]} columns; "
            f"expected label_dim={label_dim}"
        )


def _label_stats_from(labels):
    """Return per-column mean/std tensors; zero-variance columns get std=1 to avoid /0."""
    label_mean = labels.mean(axis=0)
    label_std = labels.std(axis=0)
    label_std = np.where(label_std == 0, 1.0, label_std)
    print(f"Label normalization -> mean={label_mean}, std={label_std}")
    return {
        "mean": torch.from_numpy(label_mean).float(),
        "std": torch.from_numpy(label_std).float(),
    }


def _make_loader(dataset, *, batch_size, shuffle, drop_last, num_workers, pin_memory):
    """Build a DataLoader with the shared worker/pinning settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )


def get_conditional_dataloaders(
    data_dir=DEFAULT_DATA_DIR,
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True,
    label_dim=6,
):
    """Load LH 6-parameter splits and wrap them in DataLoaders.

    Expects ``{train,val,test}_LH_6.npy`` and ``{train,val,test}_labels_LH.npy``
    directly inside ``data_dir``. ``label_dim`` must match the second axis of every
    ``*_labels_LH.npy``; this is checked for all splits regardless of
    ``normalize_labels``.

    Args:
        data_dir: Directory containing the six ``.npy`` files.
        batch_size: Batch size for all three loaders.
        num_workers: DataLoader worker processes.
        pin_memory: Forwarded to DataLoader.
        normalize_labels: If true, z-score labels of every split with train-split
            mean/std.
        label_dim: Expected number of label columns.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader shuffles and
        drops the last incomplete batch; val/test keep order and all samples.

    Raises:
        ValueError: If a label array is not 2-D, has the wrong number of columns, or
            its length differs from the matching image array.
    """
    paths = _split_paths(data_dir)
    print(f"Loading 6-parameter LH dataset from {data_dir}")

    label_stats = None
    if normalize_labels:
        train_labels_array = np.load(paths["train"][1])
        # Validate before computing stats so a wrong layout fails fast and clearly.
        _check_label_array(train_labels_array, "train_labels_LH.npy", label_dim)
        label_stats = _label_stats_from(train_labels_array)

    datasets = {}
    for split, (image_path, label_path) in paths.items():
        dataset = ConditionalImageDataset(image_path, label_path, label_stats=label_stats)
        _check_label_array(dataset.labels, os.path.basename(label_path), label_dim)
        datasets[split] = dataset

    if len(datasets["train"]) < batch_size:
        import warnings

        # drop_last=True on the train loader would otherwise yield zero batches silently.
        warnings.warn(
            f"Train split has {len(datasets['train'])} samples < batch_size={batch_size}; "
            "train_loader will yield no batches (drop_last=True).",
            stacklevel=2,
        )

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = _make_loader(datasets["train"], shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = _make_loader(datasets["val"], shuffle=False, drop_last=False, **loader_kwargs)
    test_loader = _make_loader(datasets["test"], shuffle=False, drop_last=False, **loader_kwargs)
    return train_loader, val_loader, test_loader