DDPM-2param / src /dataset_conditional.py
collins909's picture
Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200)
0d05ab1 verified
"""
Conditional dataset and DataLoader helpers that pair images with their conditioning labels.
"""
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
class ConditionalImageDataset(Dataset):
    """Map-style dataset pairing images with conditioning labels.

    Both arrays are read fully into memory with ``np.load`` and must have the
    same length along their first axis.

    Args:
        data_path: Path to a ``.npy`` file of images. A 2-D sample gets a
            leading channel axis added; any other sample is returned with its
            stored shape. Pixel values are assumed to lie in [0, 1] (the
            ``x * 2 - 1`` mapping targets [-1, 1]) -- TODO confirm against the
            preprocessing that produced the file.
        label_path: Path to a ``.npy`` file of labels: either one row of
            parameters per sample (2-D) or one scalar per sample (1-D, returned
            as shape ``(1,)``).
        transform: Optional callable applied to the image tensor after
            normalization and channel expansion.
        label_stats: Optional dict with ``'mean'`` and ``'std'`` entries; when
            given, each label is standardized as ``(label - mean) / std``.

    Raises:
        ValueError: If the image and label arrays differ in length.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}"
            )
        print(f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | Label shape: {self.labels.shape[1:]}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` as float32 tensors."""
        img = torch.from_numpy(self.data[idx]).float()
        # With a 1-D label file, self.labels[idx] is a NumPy scalar, which
        # torch.from_numpy rejects; atleast_1d turns it into a shape-(1,) array
        # and leaves per-sample label vectors untouched.
        label = torch.from_numpy(np.atleast_1d(self.labels[idx])).float()
        # Map image values from [0, 1] to [-1, 1].
        img = img * 2.0 - 1.0
        # Standardize labels with the caller-supplied statistics.
        if self.label_stats is not None:
            label = (label - self.label_stats['mean']) / self.label_stats['std']
        # Single-channel 2-D images get an explicit channel axis.
        if img.dim() == 2:
            img = img.unsqueeze(0)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
def _resolve_split_paths(data_dir, is_6param):
    """Return ``{split: (image_path, label_path)}`` for train/val/test, in that order."""
    if is_6param:
        image_tmpl, label_tmpl = '{}_LH_6.npy', '{}_labels_LH.npy'
    else:
        image_tmpl, label_tmpl = '{}_LH.npy', '{}_labels_LH_2.npy'
    return {
        split: (
            os.path.join(data_dir, image_tmpl.format(split)),
            os.path.join(data_dir, label_tmpl.format(split)),
        )
        for split in ('train', 'val', 'test')
    }


def _label_mean_std(label_path):
    """Return per-column mean and std (zero std replaced by 1) of a label file."""
    labels = np.load(label_path)
    # For a 1-D label file the reductions return NumPy scalars, which
    # torch.from_numpy rejects; atleast_1d keeps them as shape-(1,) arrays.
    mean = np.atleast_1d(labels.mean(axis=0))
    std = np.atleast_1d(labels.std(axis=0))
    std = np.where(std == 0, 1.0, std)  # guard against zero-variance labels
    return mean, std


def get_conditional_dataloaders(
    data_dir='./data/params_2',
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True,
):
    """Build train/val/test DataLoaders over the conditional image datasets.

    If ``'params_6'`` appears anywhere in ``data_dir``, the 6-parameter file
    names are used (``{split}_LH_6.npy`` / ``{split}_labels_LH.npy``).
    Otherwise the 2-parameter names are used (``{split}_LH.npy`` /
    ``{split}_labels_LH_2.npy``).

    Args:
        data_dir: Directory holding the ``.npy`` split files.
        batch_size: Batch size for all three loaders.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to each ``DataLoader``.
        normalize_labels: If True, labels in every split are standardized with
            the mean/std of the *training* labels.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles and uses ``drop_last=True``; note that it yields no batches if
        the training set is smaller than ``batch_size``.
    """
    # Substring match: any path containing 'params_6' selects the 6-param layout.
    is_6param = 'params_6' in data_dir
    paths = _resolve_split_paths(data_dir, is_6param)
    print(f"Loading dataset from {data_dir} ({'6-param' if is_6param else '2-param'})")

    # Statistics come from the training split only, so val/test are scaled
    # consistently with what the model sees during training.
    label_stats = None
    if normalize_labels:
        label_mean, label_std = _label_mean_std(paths['train'][1])
        label_stats = {
            'mean': torch.from_numpy(label_mean).float(),
            'std': torch.from_numpy(label_std).float(),
        }
        print(f"Label normalization -> mean={label_mean}, std={label_std}")

    loaders = []
    for split, (image_path, label_path) in paths.items():
        is_train = split == 'train'
        dataset = ConditionalImageDataset(image_path, label_path, label_stats=label_stats)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=is_train,
            )
        )
    return tuple(loaders)