# Spaces: Sleeping
# File size: 4,342 Bytes
# 8b83582
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import STL10
import config
def add_noise(image: torch.Tensor, noise_type: str) -> torch.Tensor:
    """Return a noisy copy of ``image`` clamped to [0, 1].

    The input tensor is never modified: every branch builds its result with
    out-of-place operations, so no defensive clone is required.

    Args:
        image: Float tensor with values expected in [0, 1]. Callers in this
            file pass CHW images, but every operation here is element-wise
            and does not depend on the layout.
        noise_type: One of ``"gaussian"``, ``"salt_pepper"`` or
            ``"speckle"``.

    Returns:
        A new tensor with the same shape as ``image`` and values clamped to
        [0, 1].

    Raises:
        ValueError: If ``noise_type`` is not one of the supported names.
    """
    if noise_type == "gaussian":
        # Additive noise: x + N(0, GAUSSIAN_STD^2).
        noisy = image + torch.randn_like(image) * config.GAUSSIAN_STD
    elif noise_type == "salt_pepper":
        half_prob = config.SALT_PEPPER_PROB / 2
        # One uniform draw per element: the lowest `half_prob` fraction
        # becomes pepper (0), the highest `half_prob` fraction becomes salt
        # (1). Draws are per element, so channels are corrupted
        # independently of each other.
        mask = torch.rand_like(image)
        noisy = image.masked_fill(mask < half_prob, 0.0)
        noisy = noisy.masked_fill(mask > 1 - half_prob, 1.0)
    elif noise_type == "speckle":
        # Multiplicative noise: x + x * N(0, SPECKLE_STD^2).
        noise = torch.randn_like(image) * config.SPECKLE_STD
        noisy = image + image * noise
    else:
        raise ValueError(f"Unknown noise type: {noise_type}")
    return torch.clamp(noisy, 0.0, 1.0)
class NoisySTL10(Dataset):
    """STL10 wrapper that yields ``(noisy_image, clean_image)`` pairs.

    Images are kept as a single uint8 tensor (1 byte per value instead of 4
    for float32) and converted to float only in ``__getitem__``. For the
    100,000-image ``unlabeled`` split at 3x96x96 that is roughly 2.76 GB as
    uint8 versus roughly 11 GB as float32.

    Noise is drawn on every access, so reading the same index twice returns
    two different noisy images (this also applies to the test split).
    """

    def __init__(self, split: str, noise_type: str):
        """Load one STL10 split into memory.

        Args:
            split: STL10 split name passed straight to torchvision
                (this file uses ``"unlabeled"`` and ``"test"``).
            noise_type: Noise name forwarded to ``add_noise`` on each access.
        """
        super().__init__()
        self.noise_type = noise_type
        # download=False: the dataset is expected to already exist under
        # config.DATA_DIR.
        raw = STL10(root=config.DATA_DIR, split=split, download=False)
        # from_numpy shares memory with raw.data, so no second copy is made.
        self.images = torch.from_numpy(raw.data)  # (N, 3, 96, 96) uint8

    def __len__(self) -> int:
        """Return the number of images in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` float tensors for image ``idx``.

        ``clean`` is the stored image scaled to [0, 1]; ``noisy`` is the
        result of ``add_noise`` applied to it.
        """
        clean = self.images[idx].float() / 255.0  # uint8 -> float in [0, 1]
        noisy = add_noise(clean, self.noise_type)
        return noisy, clean
def get_dataloaders(noise_type: str = "gaussian"):
    """Build the denoising train and test DataLoaders.

    Training uses the ``unlabeled`` split (100K images), since denoising
    needs no labels; evaluation uses the ``test`` split.

    Args:
        noise_type: Noise name forwarded to both ``NoisySTL10`` datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles.
    """
    # Settings shared by both loaders. Pinned host memory only helps
    # host-to-GPU copies, and requesting it without CUDA just produces a
    # warning, so enable it only when a CUDA device is present.
    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )

    train_dataset = NoisySTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = NoisySTL10(split="test", noise_type=noise_type)

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
# ---------------------------------------------------------------------------
# Super-Resolution Dataset (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------
class SuperResSTL10(Dataset):
    """STL10 wrapper that yields ``(noisy_low_res, clean_full_res)`` pairs.

    Images are kept as a single uint8 tensor (1 byte per value instead of 4
    for float32); conversion, downsampling and noise all happen on the fly
    in ``__getitem__``, so no float or low-resolution copies are stored.
    For the 100,000-image ``unlabeled`` split at 3x96x96 the stored tensor
    is roughly 2.76 GB, versus roughly 11 GB for float32 full-resolution
    copies alone.

    The low-resolution size comes from ``config.SR_INPUT_SIZE``; the 48x48
    figure used in this file's comments assumes that value is 48.
    Noise is drawn on every access, so the same index returns a different
    noisy input each time.
    """

    def __init__(self, split: str, noise_type: str):
        """Load one STL10 split into memory and set up the downsampler.

        Args:
            split: STL10 split name passed straight to torchvision
                (this file uses ``"unlabeled"`` and ``"test"``).
            noise_type: Noise name forwarded to ``add_noise`` on each access.
        """
        super().__init__()
        self.noise_type = noise_type
        # download=False: the dataset is expected to already exist under
        # config.DATA_DIR.
        raw = STL10(root=config.DATA_DIR, split=split, download=False)
        # from_numpy shares memory with raw.data, so no second copy is made.
        self.images = torch.from_numpy(raw.data)  # (N, 3, 96, 96) uint8
        # Anti-aliased bicubic resize to the configured input size.
        self.downsample = transforms.Resize(
            config.SR_INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

    def __len__(self) -> int:
        """Return the number of images in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy_low_res, clean_full_res)`` for image ``idx``.

        The clean target is the stored image scaled to [0, 1]. The input is
        that image downsampled and then passed through ``add_noise``.
        """
        clean_96 = self.images[idx].float() / 255.0  # (3, 96, 96) float
        # Low-res copy; (3, 48, 48) when SR_INPUT_SIZE is 48.
        lr = self.downsample(clean_96)
        noisy_48 = add_noise(lr, self.noise_type)
        return noisy_48, clean_96
def get_sr_dataloaders(noise_type: str = "gaussian"):
    """Build the super-resolution train and test DataLoaders.

    Training uses the ``unlabeled`` split and evaluation uses the ``test``
    split, both wrapped in ``SuperResSTL10``.

    Args:
        noise_type: Noise name forwarded to both ``SuperResSTL10`` datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles.
    """
    # Settings shared by both loaders. Pinned host memory only helps
    # host-to-GPU copies, and requesting it without CUDA just produces a
    # warning, so enable it only when a CUDA device is present.
    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )

    train_dataset = SuperResSTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = SuperResSTL10(split="test", noise_type=noise_type)

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
|