File size: 4,342 Bytes
8b83582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import STL10

import config




def add_noise(image: torch.Tensor, noise_type: str) -> torch.Tensor:
    """Return a corrupted copy of ``image`` with values clamped to [0, 1].

    The input tensor is left untouched; all work is done on a clone.

    Args:
        image: CHW float tensor with values in [0, 1].
        noise_type: ``"gaussian"`` (additive), ``"salt_pepper"`` (elements
            forced to 0 or 1) or ``"speckle"`` (multiplicative).

    Returns:
        A new tensor with the same shape as ``image``, clamped to [0, 1].

    Raises:
        ValueError: If ``noise_type`` is not one of the names listed above.
    """
    corrupted = image.clone()

    if noise_type == "gaussian":
        # Additive zero-mean noise with standard deviation GAUSSIAN_STD.
        corrupted = corrupted + torch.randn_like(corrupted) * config.GAUSSIAN_STD

    elif noise_type == "salt_pepper":
        # One uniform draw per element: the lowest PROB/2 of draws become
        # pepper (0.0) and the highest PROB/2 become salt (1.0).
        draw = torch.rand_like(corrupted)
        half_prob = config.SALT_PEPPER_PROB / 2
        corrupted = corrupted.masked_fill(draw < half_prob, 0.0)
        corrupted = corrupted.masked_fill(draw > 1 - half_prob, 1.0)

    elif noise_type == "speckle":
        # Multiplicative noise: the perturbation scales with pixel intensity.
        grain = torch.randn_like(corrupted) * config.SPECKLE_STD
        corrupted = corrupted + corrupted * grain

    else:
        raise ValueError(f"Unknown noise type: {noise_type}")

    # Gaussian and speckle noise can leave the valid range; bring it back.
    return corrupted.clamp(0.0, 1.0)


class NoisySTL10(Dataset):
    """Denoising dataset over STL10 that yields ``(noisy, clean)`` pairs.

    Pixel data is held as a uint8 tensor, which needs a quarter of the memory
    a float32 copy would. Each sample is converted to float and corrupted only
    when it is requested, so the noise is freshly drawn on every access.
    """

    def __init__(self, split: str, noise_type: str):
        """Load the given STL10 ``split`` from ``config.DATA_DIR``.

        The data is not downloaded here (``download=False``). ``noise_type``
        is stored and handed to ``add_noise`` for every sample.
        """
        source = STL10(root=config.DATA_DIR, split=split, download=False)
        # Stay in uint8 — 4x smaller than float32. Shape: (N, 3, 96, 96).
        self.images = torch.from_numpy(source.data)
        self.noise_type = noise_type

    def __len__(self) -> int:
        """Return the number of images in the loaded split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` for sample ``idx`` as floats in [0, 1]."""
        # uint8 -> float32 happens per sample to keep the resident set small.
        target = self.images[idx].to(torch.float32) / 255.0
        return add_noise(target, self.noise_type), target


def get_dataloaders(noise_type: str = "gaussian"):
    """Build the denoising DataLoaders.

    Args:
        noise_type: Noise model passed to both ``NoisySTL10`` datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles;
        the test loader keeps dataset order.
    """
    # Denoising needs no labels, so train on the unlabeled split (100K images).
    train_set = NoisySTL10(split="unlabeled", noise_type=noise_type)
    test_set = NoisySTL10(split="test", noise_type=noise_type)

    # Settings common to both loaders; only shuffling differs.
    loader_kwargs = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": True,
    }
    return (
        DataLoader(train_set, shuffle=True, **loader_kwargs),
        DataLoader(test_set, shuffle=False, **loader_kwargs),
    )


# ---------------------------------------------------------------------------
# Super-Resolution Dataset (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------

class SuperResSTL10(Dataset):
    """Super-resolution dataset over STL10.

    Yields ``(noisy_low_res, clean_full_res)`` pairs. Pixel data is held as a
    uint8 tensor (a quarter of the float32 footprint), and no low-resolution
    copies are stored: each sample is converted, downsampled and corrupted
    only when it is requested.
    """

    def __init__(self, split: str, noise_type: str):
        """Load the given STL10 ``split`` and prepare the downsampler.

        The data is read from ``config.DATA_DIR`` without downloading. The
        low-resolution size comes from ``config.SR_INPUT_SIZE``.
        """
        source = STL10(root=config.DATA_DIR, split=split, download=False)
        # Stay in uint8 — 4x smaller than float32. Shape: (N, 3, 96, 96).
        self.images = torch.from_numpy(source.data)
        self.noise_type = noise_type
        # Antialiased bicubic resize, built once and reused for every sample.
        self.downsample = transforms.Resize(
            config.SR_INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

    def __len__(self) -> int:
        """Return the number of images in the loaded split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy_low_res, clean_full_res)`` for sample ``idx``."""
        target = self.images[idx].to(torch.float32) / 255.0  # full-res, [0, 1]
        low_res = self.downsample(target)
        # Noise is applied after downsampling, so only the input is corrupted.
        return add_noise(low_res, self.noise_type), target


def get_sr_dataloaders(noise_type: str = "gaussian"):
    """Build the super-resolution DataLoaders.

    Args:
        noise_type: Noise model passed to both ``SuperResSTL10`` datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles;
        the test loader keeps dataset order.
    """
    train_set = SuperResSTL10(split="unlabeled", noise_type=noise_type)
    test_set = SuperResSTL10(split="test", noise_type=noise_type)

    # Settings common to both loaders; only shuffling differs.
    loader_kwargs = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": True,
    }
    return (
        DataLoader(train_set, shuffle=True, **loader_kwargs),
        DataLoader(test_set, shuffle=False, **loader_kwargs),
    )