image-denoiser / data /dataset.py
Kajuto's picture
Initial commit - image denoiser + SR + MLOps stack
8b83582
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import STL10
import config
def add_noise(image: torch.Tensor, noise_type: str) -> torch.Tensor:
    """Return a noisy copy of ``image``, clamped to [0, 1].

    The input is cloned before any noise is applied, so the caller's tensor
    is never modified.

    Args:
        image: Float image tensor with values in [0, 1]; callers in this
            file pass CHW tensors.
        noise_type: ``"gaussian"`` (additive noise with std
            ``config.GAUSSIAN_STD``), ``"salt_pepper"`` (elements forced to
            0 or to 1, each with probability ``config.SALT_PEPPER_PROB / 2``)
            or ``"speckle"`` (multiplicative noise with std
            ``config.SPECKLE_STD``).

    Returns:
        A new tensor with the same shape as ``image``.

    Raises:
        ValueError: If ``noise_type`` is not one of the names above.
    """
    corrupted = image.clone()

    if noise_type == "gaussian":
        corrupted = corrupted + torch.randn_like(corrupted) * config.GAUSSIAN_STD
        return corrupted.clamp(0.0, 1.0)

    if noise_type == "salt_pepper":
        # A single uniform draw per element selects pepper (low tail),
        # salt (high tail) or leaves the value untouched.
        draws = torch.rand_like(corrupted)
        tail = config.SALT_PEPPER_PROB / 2
        corrupted[draws < tail] = 0.0
        corrupted[draws > 1 - tail] = 1.0
        return corrupted.clamp(0.0, 1.0)

    if noise_type == "speckle":
        # Noise is scaled by the signal itself, so dark regions stay cleaner.
        scale = torch.randn_like(corrupted) * config.SPECKLE_STD
        corrupted = corrupted + corrupted * scale
        return corrupted.clamp(0.0, 1.0)

    raise ValueError(f"Unknown noise type: {noise_type}")
class NoisySTL10(Dataset):
    """Denoising view of STL10 that yields ``(noisy_image, clean_image)`` pairs.

    Pixel data is held as a uint8 tensor, which needs a quarter of the memory
    a float32 copy would. Each sample is converted to float and corrupted with
    freshly drawn noise only when it is requested.
    """

    def __init__(self, split: str, noise_type: str):
        """Load an STL10 split from ``config.DATA_DIR`` without downloading.

        Args:
            split: STL10 split name, forwarded to torchvision's ``STL10``.
            noise_type: Noise kind passed to ``add_noise`` for every sample.
        """
        stl10 = STL10(root=config.DATA_DIR, split=split, download=False)
        # Wrap the raw uint8 array as-is; expected layout is (N, 3, 96, 96).
        self.images = torch.from_numpy(stl10.data)
        self.noise_type = noise_type

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` float tensors in [0, 1] for sample ``idx``."""
        # Float conversion is deferred to here so storage stays uint8.
        clean = self.images[idx].float().div(255.0)
        return add_noise(clean, self.noise_type), clean
def get_dataloaders(noise_type: str = "gaussian"):
    """Build the train and test DataLoaders for the denoising task.

    Training uses STL10's "unlabeled" split (100K images; labels are not
    needed for denoising) and is shuffled. Evaluation uses the "test" split
    in its stored order.

    Args:
        noise_type: Noise kind applied to every sample of both datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    train_dataset = NoisySTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = NoisySTL10(split="test", noise_type=noise_type)

    # Options common to both loaders; only the shuffle flag differs.
    shared = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared)
    return train_loader, test_loader
# ---------------------------------------------------------------------------
# Super-Resolution Dataset (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------
class SuperResSTL10(Dataset):
    """Super-resolution view of STL10: ``(noisy low-res, clean full-res)`` pairs.

    Pixel data is held as a uint8 tensor, which needs a quarter of the memory
    a float32 copy would. Float conversion, downsampling and noise injection
    all happen per sample in ``__getitem__``; no low-res copies are cached.
    """

    def __init__(self, split: str, noise_type: str):
        """Load an STL10 split from ``config.DATA_DIR`` without downloading.

        Args:
            split: STL10 split name, forwarded to torchvision's ``STL10``.
            noise_type: Noise kind passed to ``add_noise`` for every
                low-resolution input.
        """
        stl10 = STL10(root=config.DATA_DIR, split=split, download=False)
        # Wrap the raw uint8 array as-is; expected layout is (N, 3, 96, 96).
        self.images = torch.from_numpy(stl10.data)
        self.noise_type = noise_type
        # Antialiased bicubic resize to config.SR_INPUT_SIZE
        # (intended to give 48x48 inputs).
        self.downsample = transforms.Resize(
            config.SR_INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(noisy_low_res, clean_full_res)`` float tensors for ``idx``."""
        # The full-resolution target stays noise-free, scaled to [0, 1].
        hi_res = self.images[idx].float().div(255.0)
        # Noise is added after downsampling, so only the network input is degraded.
        lo_res = self.downsample(hi_res)
        return add_noise(lo_res, self.noise_type), hi_res
def get_sr_dataloaders(noise_type: str = "gaussian"):
    """Build the train and test DataLoaders for super-resolution training.

    Training uses STL10's "unlabeled" split and is shuffled. Evaluation uses
    the "test" split in its stored order.

    Args:
        noise_type: Noise kind applied to the low-resolution inputs of both
            datasets.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    train_dataset = SuperResSTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = SuperResSTL10(split="test", noise_type=noise_type)

    # Options common to both loaders; only the shuffle flag differs.
    shared = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared)
    return train_loader, test_loader