Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
1f3e7a2 verified | """ | |
| Conditional dataset loader for CAMELS LH 6-parameter layout. | |
| Same layout convention as DDPM_HI_Emulation_improved/dataset_conditional.py when | |
| is_6param is true (that repo enables 6-param mode when the string 'params_6' | |
| appears in data_dir): | |
| data_dir/ | |
| train_LH_6.npy, val_LH_6.npy, test_LH_6.npy | |
| train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy | |
| Pass data_dir as the directory that directly contains these files (e.g. the | |
| absolute path to params_6 under LH_data, analogous to params_2 for 2 labels). | |
| Images are scaled to [-1, 1]; labels are z-scored using train-split statistics. | |
| """ | |
| import os | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| # Mirrors shell training for 2-label data at .../LH_data/params_2; 6-param lives in params_6. | |
| DEFAULT_DATA_DIR = "<DDPM_ROOT>/data/LH_data/params_6" | |
| class ConditionalImageDataset(Dataset): | |
| def __init__(self, data_path, label_path, transform=None, label_stats=None): | |
| self.data = np.load(data_path) | |
| self.labels = np.load(label_path) | |
| self.transform = transform | |
| self.label_stats = label_stats | |
| assert len(self.data) == len(self.labels), ( | |
| f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}" | |
| ) | |
| print( | |
| f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | " | |
| f"Label shape: {self.labels.shape[1:]}" | |
| ) | |
| def __len__(self): | |
| return len(self.data) | |
| def __getitem__(self, idx): | |
| img = torch.from_numpy(self.data[idx]).float() | |
| label = torch.from_numpy(self.labels[idx]).float() | |
| # Normalize image to [-1, 1] | |
| img = img * 2.0 - 1.0 | |
| # Normalize labels | |
| if self.label_stats is not None: | |
| label = (label - self.label_stats["mean"]) / self.label_stats["std"] | |
| if img.dim() == 2: | |
| img = img.unsqueeze(0) | |
| return img, label | |
| def get_conditional_dataloaders( | |
| data_dir=DEFAULT_DATA_DIR, | |
| batch_size=8, | |
| num_workers=4, | |
| pin_memory=True, | |
| normalize_labels=True, | |
| label_dim=6, | |
| ): | |
| """ | |
| Load LH 6-parameter splits. label_dim must match the second axis of *_labels_LH.npy. | |
| """ | |
| train_data = os.path.join(data_dir, "train_LH_6.npy") | |
| val_data = os.path.join(data_dir, "val_LH_6.npy") | |
| test_data = os.path.join(data_dir, "test_LH_6.npy") | |
| train_labels = os.path.join(data_dir, "train_labels_LH.npy") | |
| val_labels = os.path.join(data_dir, "val_labels_LH.npy") | |
| test_labels = os.path.join(data_dir, "test_labels_LH.npy") | |
| print(f"Loading 6-parameter LH dataset from {data_dir}") | |
| label_stats = None | |
| if normalize_labels: | |
| train_labels_array = np.load(train_labels) | |
| if train_labels_array.shape[1] != label_dim: | |
| raise ValueError( | |
| f"train_labels_LH.npy has {train_labels_array.shape[1]} columns; " | |
| f"expected label_dim={label_dim}" | |
| ) | |
| label_mean = train_labels_array.mean(axis=0) | |
| label_std = train_labels_array.std(axis=0) | |
| label_std = np.where(label_std == 0, 1.0, label_std) | |
| label_stats = { | |
| "mean": torch.from_numpy(label_mean).float(), | |
| "std": torch.from_numpy(label_std).float(), | |
| } | |
| print(f"Label normalization -> mean={label_mean}, std={label_std}") | |
| train_dataset = ConditionalImageDataset(train_data, train_labels, label_stats=label_stats) | |
| val_dataset = ConditionalImageDataset(val_data, val_labels, label_stats=label_stats) | |
| test_dataset = ConditionalImageDataset(test_data, test_labels, label_stats=label_stats) | |
| train_loader = DataLoader( | |
| train_dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers, | |
| pin_memory=pin_memory, | |
| drop_last=True, | |
| ) | |
| val_loader = DataLoader( | |
| val_dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| num_workers=num_workers, | |
| pin_memory=pin_memory, | |
| drop_last=False, | |
| ) | |
| test_loader = DataLoader( | |
| test_dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| num_workers=num_workers, | |
| pin_memory=pin_memory, | |
| drop_last=False, | |
| ) | |
| return train_loader, val_loader, test_loader | |