Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from torchvision.datasets import STL10 | |
| import config | |
def add_noise(image: torch.Tensor, noise_type: str) -> torch.Tensor:
    """Return a noisy copy of ``image`` with values clamped to [0, 1].

    The input tensor is left untouched; all work happens on a clone.

    Args:
        image: Float image tensor, expected to be CHW with values in [0, 1].
        noise_type: One of ``"gaussian"``, ``"salt_pepper"`` or ``"speckle"``.

    Returns:
        A new tensor with the same shape as ``image``, clamped to [0, 1].

    Raises:
        ValueError: If ``noise_type`` is not one of the supported names.
    """
    corrupted = image.clone()

    if noise_type == "gaussian":
        # Additive noise: standard-normal draws scaled by config.GAUSSIAN_STD.
        corrupted = corrupted + torch.randn_like(corrupted) * config.GAUSSIAN_STD
    elif noise_type == "speckle":
        # Multiplicative noise: each value is perturbed in proportion to itself.
        grain = torch.randn_like(corrupted) * config.SPECKLE_STD
        corrupted = corrupted + corrupted * grain
    elif noise_type == "salt_pepper":
        # One uniform draw per element (i.e. per channel value, not per pixel).
        # The lower tail becomes pepper (0); the upper tail becomes salt (1).
        draws = torch.rand_like(corrupted)
        half_prob = config.SALT_PEPPER_PROB / 2
        corrupted.masked_fill_(draws < half_prob, 0.0)
        corrupted.masked_fill_(draws > 1 - half_prob, 1.0)
    else:
        raise ValueError(f"Unknown noise type: {noise_type}")

    return corrupted.clamp(0.0, 1.0)
class NoisySTL10(Dataset):
    """Denoising dataset over STL10 yielding ``(noisy, clean)`` image pairs.

    The whole split is kept in memory as a uint8 tensor, which takes a
    quarter of the space of a float32 copy (1 byte vs 4 bytes per value;
    27,648 B vs 110,592 B for one 3x96x96 image). Each sample is converted
    to float and scaled by 1/255 only when it is requested, and
    ``add_noise`` is called again on every access.
    """

    def __init__(self, split: str, noise_type: str):
        """Load an STL10 split from ``config.DATA_DIR`` into memory.

        The dataset is opened with ``download=False``, so the files must
        already be present on disk.

        Args:
            split: STL10 split name, passed straight to torchvision's ``STL10``.
            noise_type: Noise name forwarded to ``add_noise`` for each sample.
        """
        stl10 = STL10(root=config.DATA_DIR, split=split, download=False)
        # from_numpy wraps the existing array without copying it;
        # the data is expected to be (N, 3, 96, 96) uint8.
        self.images = torch.from_numpy(stl10.data)
        self.noise_type = noise_type

    def __len__(self):
        """Return the number of images held by the dataset."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` float tensors for the image at ``idx``."""
        # The float conversion is deferred to here so the stored tensor stays uint8.
        target = self.images[idx].to(torch.float32) / 255.0
        return add_noise(target, self.noise_type), target
def get_dataloaders(noise_type: str = "gaussian"):
    """Build the train and test DataLoaders for the denoising task.

    Args:
        noise_type: Noise name handed to both ``NoisySTL10`` datasets.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles the
        STL10 ``"unlabeled"`` split; the test loader walks the ``"test"``
        split in stored order.
    """
    # Labels are never used, so the large unlabeled split serves as training data.
    train_set, test_set = (
        NoisySTL10(split=name, noise_type=noise_type)
        for name in ("unlabeled", "test")
    )

    # Settings common to both loaders; only the shuffle flag differs.
    shared = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": True,
    }
    return (
        DataLoader(train_set, shuffle=True, **shared),
        DataLoader(test_set, shuffle=False, **shared),
    )
| # --------------------------------------------------------------------------- | |
| # Super-Resolution Dataset (noisy 48×48 → clean 96×96) | |
| # --------------------------------------------------------------------------- | |
class SuperResSTL10(Dataset):
    """Super-resolution dataset over STL10 yielding ``(noisy_lr, clean_hr)`` pairs.

    Only the full-resolution images are stored, as a uint8 tensor (a quarter
    of the size of a float32 copy). The low-resolution input is produced on
    the fly in ``__getitem__`` by bicubic, antialiased resizing to
    ``config.SR_INPUT_SIZE`` followed by ``add_noise``. No downsampled copies
    are cached.
    """

    def __init__(self, split: str, noise_type: str):
        """Load an STL10 split from ``config.DATA_DIR`` and set up the resizer.

        The dataset is opened with ``download=False``, so the files must
        already be present on disk.

        Args:
            split: STL10 split name, passed straight to torchvision's ``STL10``.
            noise_type: Noise name forwarded to ``add_noise`` for each sample.
        """
        stl10 = STL10(root=config.DATA_DIR, split=split, download=False)
        # from_numpy wraps the existing array without copying it;
        # the data is expected to be (N, 3, 96, 96) uint8.
        self.images = torch.from_numpy(stl10.data)
        self.noise_type = noise_type
        self.downsample = transforms.Resize(
            config.SR_INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

    def __len__(self):
        """Return the number of images held by the dataset."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(noisy_low_res, clean_high_res)`` for the image at ``idx``."""
        high_res = self.images[idx].to(torch.float32) / 255.0
        # Output size is whatever config.SR_INPUT_SIZE specifies
        # (presumably 48x48, per the section header - confirm in config).
        low_res = self.downsample(high_res)
        # Noise is applied after downsampling, so it is not smoothed by the resize.
        return add_noise(low_res, self.noise_type), high_res
def get_sr_dataloaders(noise_type: str = "gaussian"):
    """Build the train and test DataLoaders for super-resolution training.

    Args:
        noise_type: Noise name handed to both ``SuperResSTL10`` datasets.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles the
        STL10 ``"unlabeled"`` split; the test loader walks the ``"test"``
        split in stored order.
    """
    train_set, test_set = (
        SuperResSTL10(split=name, noise_type=noise_type)
        for name in ("unlabeled", "test")
    )

    # Settings common to both loaders; only the shuffle flag differs.
    shared = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": True,
    }
    return (
        DataLoader(train_set, shuffle=True, **shared),
        DataLoader(test_set, shuffle=False, **shared),
    )