"""STL10 data pipelines for denoising and noisy super-resolution training."""

from typing import Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import STL10

import config

# Noise models understood by ``add_noise``.
NOISE_TYPES = ("gaussian", "salt_pepper", "speckle")


def _check_noise_type(noise_type: str) -> None:
    """Raise ``ValueError`` if *noise_type* is not one of ``NOISE_TYPES``."""
    if noise_type not in NOISE_TYPES:
        raise ValueError(f"Unknown noise type: {noise_type}")


def _randn(like: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
    """Standard-normal samples shaped like *like*, from *generator* if given."""
    if generator is None:
        return torch.randn_like(like)
    return torch.randn(like.shape, generator=generator, dtype=like.dtype, device=like.device)


def _rand(like: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
    """Uniform [0, 1) samples shaped like *like*, from *generator* if given."""
    if generator is None:
        return torch.rand_like(like)
    return torch.rand(like.shape, generator=generator, dtype=like.dtype, device=like.device)


def add_noise(
    image: torch.Tensor,
    noise_type: str,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Return a noisy copy of a float image tensor with values in [0, 1].

    The input tensor is not modified; the result is clamped back to [0, 1].

    Args:
        image: Float tensor (e.g. CHW) with values in [0, 1].
        noise_type: One of ``NOISE_TYPES``:
            - ``"gaussian"``: additive noise ``N(0, config.GAUSSIAN_STD**2)``.
            - ``"salt_pepper"``: each element is independently set to 0 with
              probability ``config.SALT_PEPPER_PROB / 2`` and to 1 with the
              same probability. The mask is drawn per element, so on a CHW
              image every colour channel is corrupted independently.
            - ``"speckle"``: multiplicative noise ``x + x * n`` with
              ``n ~ N(0, config.SPECKLE_STD**2)``.
        generator: Optional RNG to draw the noise from; it must live on the
            same device as *image*. ``None`` uses torch's global RNG.

    Raises:
        ValueError: If *noise_type* is not recognised.
    """
    noisy = image.clone()
    if noise_type == "gaussian":
        noise = _randn(noisy, generator) * config.GAUSSIAN_STD
        noisy = noisy + noise
    elif noise_type == "salt_pepper":
        mask = _rand(noisy, generator)
        noisy[mask < config.SALT_PEPPER_PROB / 2] = 0.0
        noisy[mask > 1 - config.SALT_PEPPER_PROB / 2] = 1.0
    elif noise_type == "speckle":
        noise = _randn(noisy, generator) * config.SPECKLE_STD
        noisy = noisy + noisy * noise
    else:
        raise ValueError(f"Unknown noise type: {noise_type}")
    return torch.clamp(noisy, 0.0, 1.0)


def _load_stl10_uint8(split: str) -> torch.Tensor:
    """Load an already-downloaded STL10 split as a uint8 image tensor.

    Keeping one byte per value (converting to float only per sample) uses 4x
    less RAM than float32. For the 100K-image ``"unlabeled"`` split that is
    roughly 2.8 GB instead of ~11 GB.
    """
    raw = STL10(root=config.DATA_DIR, split=split, download=False)
    # Wraps the numpy array without copying: (N, 3, 96, 96) uint8.
    return torch.from_numpy(raw.data)


def _sample_generator(seed: Optional[int], idx: int) -> Optional[torch.Generator]:
    """Return a CPU generator seeded with ``seed + idx``, or ``None`` if *seed* is None."""
    if seed is None:
        return None
    return torch.Generator().manual_seed(seed + int(idx))


class NoisySTL10(Dataset):
    """STL10 dataset that returns ``(noisy_image, clean_image)`` pairs.

    Both tensors are float (3, 96, 96) in [0, 1]. Images are held as uint8
    (~2.8 GB for the 100K unlabeled split vs ~11 GB as float32) and converted
    to float only in ``__getitem__``.

    Args:
        split: STL10 split name passed to torchvision (e.g. ``"unlabeled"``,
            ``"test"``).
        noise_type: One of ``NOISE_TYPES``; validated before any data is loaded.
        seed: If given, the noise for sample ``idx`` comes from a generator
            seeded with ``seed + idx``, so every epoch sees the same noisy
            image. ``None`` (default) draws fresh noise on every access.
    """

    def __init__(self, split: str, noise_type: str, seed: Optional[int] = None):
        _check_noise_type(noise_type)  # fail fast, before a multi-GB load
        self.noise_type = noise_type
        self.seed = seed
        self.images = _load_stl10_uint8(split)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        clean = self.images[idx].float() / 255.0  # convert to float only here
        noisy = add_noise(clean, self.noise_type, _sample_generator(self.seed, idx))
        return noisy, clean


def _make_loaders(
    train_dataset: Dataset, test_dataset: Dataset
) -> Tuple[DataLoader, DataLoader]:
    """Wrap datasets in DataLoaders: shuffled for training, ordered for testing."""
    # Pinned host memory only speeds up host->GPU copies; without CUDA the
    # DataLoader merely warns that pinning is unavailable.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=pin,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=pin,
    )
    return train_loader, test_loader


def get_dataloaders(noise_type: str = "gaussian", *, test_seed: Optional[int] = None):
    """Return ``(train_loader, test_loader)`` for denoising.

    Training uses the 100K-image ``"unlabeled"`` split (no labels are needed);
    evaluation uses the ``"test"`` split. Pass *test_seed* to freeze the test
    noise so metrics are comparable across epochs. Training noise is always
    resampled on every access.
    """
    train_dataset = NoisySTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = NoisySTL10(split="test", noise_type=noise_type, seed=test_seed)
    return _make_loaders(train_dataset, test_dataset)


# ---------------------------------------------------------------------------
# Super-Resolution Dataset (noisy 48×48 → clean 96×96)
# ---------------------------------------------------------------------------
class SuperResSTL10(Dataset):
    """STL10 dataset that returns ``(noisy_low_res, clean_96x96)`` pairs for SR.

    Each clean image is bicubically downsampled (with antialiasing) to
    ``config.SR_INPUT_SIZE`` and then corrupted with noise. Images are held as
    uint8 (~2.8 GB for the 100K unlabeled split); float conversion and
    downsampling happen on the fly in ``__getitem__``.

    Args:
        split: STL10 split name passed to torchvision.
        noise_type: One of ``NOISE_TYPES``; validated before any data is loaded.
        seed: If given, the noise for sample ``idx`` comes from a generator
            seeded with ``seed + idx`` (reproducible across epochs). ``None``
            (default) draws fresh noise on every access.
    """

    def __init__(self, split: str, noise_type: str, seed: Optional[int] = None):
        _check_noise_type(noise_type)  # fail fast, before a multi-GB load
        self.noise_type = noise_type
        self.seed = seed
        self.images = _load_stl10_uint8(split)
        self.downsample = transforms.Resize(
            config.SR_INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        clean_96 = self.images[idx].float() / 255.0  # (3, 96, 96) float
        lr = self.downsample(clean_96)  # (3, 48, 48) when SR_INPUT_SIZE == 48
        noisy_48 = add_noise(lr, self.noise_type, _sample_generator(self.seed, idx))
        return noisy_48, clean_96


def get_sr_dataloaders(noise_type: str = "gaussian", *, test_seed: Optional[int] = None):
    """Return ``(train_loader, test_loader)`` for super-resolution training.

    Uses the ``"unlabeled"`` split for training and ``"test"`` for evaluation.
    Pass *test_seed* to freeze the test noise across epochs.
    """
    train_dataset = SuperResSTL10(split="unlabeled", noise_type=noise_type)
    test_dataset = SuperResSTL10(split="test", noise_type=noise_type, seed=test_seed)
    return _make_loaders(train_dataset, test_dataset)