"""
Conditional dataset loader for CAMELS LH 6-parameter layout.

Same layout convention as DDPM_HI_Emulation_improved/dataset_conditional.py
when is_6param is true (that repo enables 6-param mode when the string
'params_6' appears in data_dir):

    data_dir/
        train_LH_6.npy, val_LH_6.npy, test_LH_6.npy
        train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy

Pass data_dir as the directory that directly contains these files (e.g. the
absolute path to params_6 under LH_data, analogous to params_2 for 2 labels).

Images are mapped with ``x * 2 - 1`` (i.e. to [-1, 1], assuming the stored
images are already in [0, 1]); labels are z-scored using train-split
statistics.
"""

import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# Mirrors shell training for 2-label data at .../LH_data/params_2; 6-param lives in params_6.
DEFAULT_DATA_DIR = "/data/LH_data/params_6"


class ConditionalImageDataset(Dataset):
    """Image/label pairs loaded from two ``.npy`` files.

    Args:
        data_path: Path to the ``.npy`` file holding the images.
        label_path: Path to the ``.npy`` file holding the labels; its first
            axis must have the same length as the image array.
        transform: Stored on the instance as ``self.transform``. It is NOT
            applied by ``__getitem__``.
        label_stats: Optional mapping with ``"mean"`` and ``"std"`` entries.
            When given, each label is returned as ``(label - mean) / std``.

    Raises:
        ValueError: If the image and label arrays differ in length.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats

        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let mismatched files load silently.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}"
            )
        print(
            f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | "
            f"Label shape: {self.labels.shape[1:]}"
        )

    def __len__(self):
        """Return the number of samples (length of the image array)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for sample ``idx`` as float32 tensors."""
        img = torch.from_numpy(self.data[idx]).float()
        label = torch.from_numpy(self.labels[idx]).float()

        # Normalize image to [-1, 1] (assumes the stored values are in [0, 1]).
        img = img * 2.0 - 1.0

        # Normalize labels
        if self.label_stats is not None:
            label = (label - self.label_stats["mean"]) / self.label_stats["std"]

        # Give 2-D images an explicit leading channel axis.
        if img.dim() == 2:
            img = img.unsqueeze(0)

        return img, label


def _check_label_columns(labels, label_dim, filename):
    """Raise ValueError unless ``labels`` has ``label_dim`` entries on axis 1."""
    if labels.ndim < 2:
        raise ValueError(
            f"{filename} has shape {labels.shape}; expected a 2-D array with "
            f"label_dim={label_dim} columns"
        )
    if labels.shape[1] != label_dim:
        raise ValueError(
            f"{filename} has {labels.shape[1]} columns; "
            f"expected label_dim={label_dim}"
        )


def get_conditional_dataloaders(
    data_dir=DEFAULT_DATA_DIR,
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True,
    label_dim=6,
):
    """
    Load LH 6-parameter splits.

    Args:
        data_dir: Directory that directly contains the six ``.npy`` files.
        batch_size: Batch size used for all three loaders.
        num_workers: Worker process count passed to each ``DataLoader``.
        pin_memory: ``pin_memory`` flag passed to each ``DataLoader``.
        normalize_labels: When true, z-score the labels of every split with
            the per-column mean/std of the train labels (a zero std is
            replaced by 1.0 to avoid division by zero).
        label_dim: Must match the second axis of ``*_labels_LH.npy``. This is
            checked only when ``normalize_labels`` is true.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader shuffles
        and drops the last incomplete batch; the other two do neither.

    Raises:
        ValueError: With ``normalize_labels`` true, if any split's label file
            is not at least 2-D or its second axis differs from ``label_dim``;
            also if a split's image and label arrays differ in length.
    """
    train_data = os.path.join(data_dir, "train_LH_6.npy")
    val_data = os.path.join(data_dir, "val_LH_6.npy")
    test_data = os.path.join(data_dir, "test_LH_6.npy")
    train_labels = os.path.join(data_dir, "train_labels_LH.npy")
    val_labels = os.path.join(data_dir, "val_labels_LH.npy")
    test_labels = os.path.join(data_dir, "test_labels_LH.npy")
    print(f"Loading 6-parameter LH dataset from {data_dir}")

    label_stats = None
    if normalize_labels:
        train_labels_array = np.load(train_labels)
        _check_label_columns(train_labels_array, label_dim, "train_labels_LH.npy")
        label_mean = train_labels_array.mean(axis=0)
        label_std = train_labels_array.std(axis=0)
        # Constant columns would divide by zero; leave them merely centred.
        label_std = np.where(label_std == 0, 1.0, label_std)
        label_stats = {
            "mean": torch.from_numpy(label_mean).float(),
            "std": torch.from_numpy(label_std).float(),
        }
        print(f"Label normalization -> mean={label_mean}, std={label_std}")

    train_dataset = ConditionalImageDataset(train_data, train_labels, label_stats=label_stats)

    val_dataset = ConditionalImageDataset(val_data, val_labels, label_stats=label_stats)
    if normalize_labels:
        # Train statistics are applied to this split, so a column mismatch
        # would either fail later inside a worker or broadcast silently.
        _check_label_columns(val_dataset.labels, label_dim, "val_labels_LH.npy")

    test_dataset = ConditionalImageDataset(test_data, test_labels, label_stats=label_stats)
    if normalize_labels:
        _check_label_columns(test_dataset.labels, label_dim, "test_labels_LH.npy")

    # Settings shared by all three loaders.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_kwargs)

    return train_loader, val_loader, test_loader