File size: 4,298 Bytes
c46900a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f3e7a2
c46900a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Conditional dataset loader for the CAMELS LH 6-parameter layout.

Follows the same layout convention as DDPM_HI_Emulation_improved/dataset_conditional.py
in its 6-parameter mode (that repo switches to 6-param mode when the string
'params_6' appears in data_dir):

  data_dir/
    train_LH_6.npy, val_LH_6.npy, test_LH_6.npy
    train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy

Pass data_dir as the directory that directly contains these files (e.g. the
absolute path to params_6 under LH_data, analogous to params_2 for the 2-label
data).

Images are mapped with x -> 2x - 1, which yields [-1, 1] only if the stored
arrays are already scaled to [0, 1] (assumed here, not checked). Labels are
z-scored per column using the mean and std of the train split only; columns
with zero std are divided by 1 instead.
"""

import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# Default directory holding the 6-parameter LH splits (the 2-label training
# data lives in the sibling .../LH_data/params_2). "<DDPM_ROOT>" is never
# expanded in this module, so it presumably is a placeholder that callers must
# replace or override via data_dir — NOTE(review): confirm.
DEFAULT_DATA_DIR = "<DDPM_ROOT>/data/LH_data/params_6"


class ConditionalImageDataset(Dataset):
    """Map-style dataset pairing an image array with a conditioning-label array.

    Both arrays are loaded eagerly with ``np.load`` and indexed along axis 0.

    Args:
        data_path: Path to a ``.npy`` image array. A 2-D item gets a leading
            channel axis; items with more dims are returned as-is. Values are
            mapped with ``2x - 1``, which gives [-1, 1] only if the stored data
            is in [0, 1] — TODO confirm upstream scaling.
        label_path: Path to a ``.npy`` label array with the same length as the
            image array.
        transform: Optional callable applied to the image tensor after the
            ``2x - 1`` mapping and channel-axis insertion.
        label_stats: Optional dict with ``"mean"`` and ``"std"`` entries; when
            given, labels are returned as ``(label - mean) / std``.

    Raises:
        ValueError: If the image and label arrays differ in length.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats

        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}"
            )

        print(
            f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | "
            f"Label shape: {self.labels.shape[1:]}"
        )

    def __len__(self):
        """Return the number of samples (length of axis 0 of the image array)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx* as float32 tensors."""
        img = torch.from_numpy(self.data[idx]).float()
        # With a 1-D label array, indexing yields a NumPy scalar, which
        # torch.from_numpy rejects; atleast_1d turns it into a length-1 vector.
        label = torch.from_numpy(np.atleast_1d(self.labels[idx])).float()

        # Map [0, 1] -> [-1, 1] (assuming the stored images are in [0, 1]).
        img = img * 2.0 - 1.0

        # Z-score labels with the provided statistics.
        if self.label_stats is not None:
            label = (label - self.label_stats["mean"]) / self.label_stats["std"]

        # Single-channel (H, W) images get an explicit channel axis.
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Honour the user-supplied transform on the final image tensor.
        if self.transform is not None:
            img = self.transform(img)

        return img, label


def _check_label_dim(labels, label_dim, name):
    """Raise ValueError unless *labels* is 2-D with exactly *label_dim* columns."""
    if labels.ndim != 2:
        raise ValueError(
            f"{name} must be 2-D (N, label_dim); got shape {labels.shape}"
        )
    if labels.shape[1] != label_dim:
        raise ValueError(
            f"{name} has {labels.shape[1]} columns; "
            f"expected label_dim={label_dim}"
        )


def _compute_label_stats(train_labels_array):
    """Return per-column z-score statistics (float32 tensors) from train labels."""
    label_mean = train_labels_array.mean(axis=0)
    label_std = train_labels_array.std(axis=0)  # population std (ddof=0)
    # A constant column would divide by zero; use std=1 so it is only centred.
    label_std = np.where(label_std == 0, 1.0, label_std)
    print(f"Label normalization -> mean={label_mean}, std={label_std}")
    return {
        "mean": torch.from_numpy(label_mean).float(),
        "std": torch.from_numpy(label_std).float(),
    }


def _make_loader(dataset, *, train, batch_size, num_workers, pin_memory):
    """Build a DataLoader; only the training loader shuffles and drops the last partial batch."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=train,
    )


def get_conditional_dataloaders(
    data_dir=DEFAULT_DATA_DIR,
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True,
    label_dim=6,
):
    """
    Load LH 6-parameter splits and wrap them in train/val/test DataLoaders.

    Args:
        data_dir: Directory directly containing ``{train,val,test}_LH_6.npy``
            and ``{train,val,test}_labels_LH.npy``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to ``DataLoader``.
        normalize_labels: If True, z-score labels per column using the mean and
            population std of the train split (zero-std columns use std 1).
        label_dim: Required number of label columns. It must match the second
            axis of every ``*_labels_LH.npy`` and is checked regardless of
            ``normalize_labels``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader shuffles
        and drops the last incomplete batch; val/test do neither.

    Raises:
        ValueError: If any label array is not 2-D with ``label_dim`` columns.
    """
    train_data = os.path.join(data_dir, "train_LH_6.npy")
    val_data = os.path.join(data_dir, "val_LH_6.npy")
    test_data = os.path.join(data_dir, "test_LH_6.npy")
    train_labels = os.path.join(data_dir, "train_labels_LH.npy")
    val_labels = os.path.join(data_dir, "val_labels_LH.npy")
    test_labels = os.path.join(data_dir, "test_labels_LH.npy")

    print(f"Loading 6-parameter LH dataset from {data_dir}")

    # Validate train labels before loading the (larger) image arrays so a wrong
    # label_dim fails fast — whether or not normalization is requested.
    train_labels_array = np.load(train_labels)
    _check_label_dim(train_labels_array, label_dim, "train_labels_LH.npy")

    # Stats come from the train split only, so val/test carry no leakage.
    label_stats = _compute_label_stats(train_labels_array) if normalize_labels else None

    train_dataset = ConditionalImageDataset(train_data, train_labels, label_stats=label_stats)
    val_dataset = ConditionalImageDataset(val_data, val_labels, label_stats=label_stats)
    test_dataset = ConditionalImageDataset(test_data, test_labels, label_stats=label_stats)

    _check_label_dim(val_dataset.labels, label_dim, "val_labels_LH.npy")
    _check_label_dim(test_dataset.labels, label_dim, "test_labels_LH.npy")

    loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
    )
    train_loader = _make_loader(train_dataset, train=True, **loader_kwargs)
    val_loader = _make_loader(val_dataset, train=False, **loader_kwargs)
    test_loader = _make_loader(test_dataset, train=False, **loader_kwargs)

    return train_loader, val_loader, test_loader