File size: 3,556 Bytes
9855216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Conditional Dataset loader with labels
"""

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os


class ConditionalImageDataset(Dataset):
    """Map-style dataset pairing images with conditioning labels.

    Both arrays are read eagerly from ``.npy`` files; the first axis of each
    indexes samples. Each item is returned as ``(img, label)``:

    * ``img`` is cast to float32 and rescaled with ``x * 2 - 1`` (which maps
      [0, 1] to [-1, 1] — assumes stored pixels lie in [0, 1]; confirm against
      the data files). A 2-D image gets a leading channel axis. Finally,
      ``transform`` (if given) is applied to the resulting tensor.
    * ``label`` is cast to float32 and, when ``label_stats`` is given,
      standardized as ``(label - mean) / std``.

    Args:
        data_path: Path to the image ``.npy`` file.
        label_path: Path to the label ``.npy`` file.
        transform: Optional callable applied to the image tensor as the last step.
        label_stats: Optional dict with ``'mean'`` and ``'std'`` entries
            (tensors, arrays or scalars) broadcastable against a single label.

    Raises:
        ValueError: If the image and label arrays have different lengths.
    """

    def __init__(self, data_path, label_path, transform=None, label_stats=None):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)
        self.transform = transform
        self.label_stats = label_stats

        # Raise instead of assert: assert statements are stripped under `python -O`.
        if len(self.data) != len(self.labels):
            raise ValueError(f"Data and labels length mismatch! {len(self.data)} vs {len(self.labels)}")

        print(f"Loaded {len(self.data)} images | Image shape: {self.data.shape[1:]} | Label shape: {self.labels.shape[1:]}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # torch.as_tensor (unlike torch.from_numpy) also accepts the numpy
        # scalar produced by indexing a 1-D array, so one-value labels work.
        img = torch.as_tensor(self.data[idx], dtype=torch.float32)
        label = torch.as_tensor(self.labels[idx], dtype=torch.float32)

        # Normalize image to [-1, 1] (assuming stored values are in [0, 1]).
        img = img * 2.0 - 1.0

        # Standardize labels; converting the stats keeps the result float32
        # even when they are supplied as numpy arrays or float64 tensors.
        stats = self.label_stats
        if stats is not None:
            mean = torch.as_tensor(stats['mean'], dtype=label.dtype)
            std = torch.as_tensor(stats['std'], dtype=label.dtype)
            label = (label - mean) / std

        # Give single-channel 2-D images an explicit channel axis.
        if img.dim() == 2:
            img = img.unsqueeze(0)

        # Previously `transform` was stored but never used.
        if self.transform is not None:
            img = self.transform(img)

        return img, label


def _split_paths(data_dir, is_6param):
    """Return ``{split: (data_path, label_path)}`` for the train/val/test splits.

    The 6-param and 2-param layouts differ only in their file-name templates.
    """
    if is_6param:
        data_name, label_name = '{}_LH_6.npy', '{}_labels_LH.npy'
    else:
        data_name, label_name = '{}_LH.npy', '{}_labels_LH_2.npy'
    return {
        split: (
            os.path.join(data_dir, data_name.format(split)),
            os.path.join(data_dir, label_name.format(split)),
        )
        for split in ('train', 'val', 'test')
    }


def _compute_label_stats(label_path):
    """Return ``{'mean', 'std'}`` float32 tensors computed along axis 0 of the labels file."""
    labels = np.load(label_path)
    label_mean = labels.mean(axis=0)
    label_std = labels.std(axis=0)
    label_std = np.where(label_std == 0, 1.0, label_std)  # guard against zero-variance labels
    # torch.as_tensor, unlike torch.from_numpy, also accepts the numpy scalar
    # that .mean(axis=0) returns when the label file is 1-D.
    label_stats = {
        'mean': torch.as_tensor(label_mean, dtype=torch.float32),
        'std': torch.as_tensor(label_std, dtype=torch.float32),
    }
    print(f"Label normalization -> mean={label_mean}, std={label_std}")
    return label_stats


def get_conditional_dataloaders(
    data_dir='./data/params_2',
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    normalize_labels=True
):
    """Build train/val/test DataLoaders of ConditionalImageDataset.

    Args:
        data_dir: Directory holding the split ``.npy`` files (str or path-like).
            A path containing ``'params_6'`` selects the 6-param file names,
            anything else the 2-param ones.
        batch_size: Batch size for all three loaders.
        num_workers: DataLoader worker processes for all three loaders.
        pin_memory: Forwarded to every DataLoader.
        normalize_labels: If True, standardize labels of every split with the
            mean/std of the *training* labels.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader shuffles
        and drops the last incomplete batch (so a training split smaller than
        ``batch_size`` yields no batches); val/test keep order and every sample.
    """
    # Accept pathlib.Path too: the substring test below needs a str.
    data_dir = os.fspath(data_dir)
    # NOTE(review): substring match, so e.g. 'params_60' would also select the
    # 6-param layout.
    is_6param = 'params_6' in data_dir
    paths = _split_paths(data_dir, is_6param)

    print(f"Loading dataset from {data_dir} ({'6-param' if is_6param else '2-param'})")

    # Stats come from the training labels only and are shared by every split,
    # so val/test are normalized consistently without leaking their statistics.
    label_stats = _compute_label_stats(paths['train'][1]) if normalize_labels else None

    # Built in train, val, test order (dict insertion order).
    datasets = {
        split: ConditionalImageDataset(data_path, label_path, label_stats=label_stats)
        for split, (data_path, label_path) in paths.items()
    }

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(datasets['train'], shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(datasets['val'], shuffle=False, drop_last=False, **loader_kwargs)
    test_loader = DataLoader(datasets['test'], shuffle=False, drop_last=False, **loader_kwargs)

    return train_loader, val_loader, test_loader