Spaces:
Running
Running
import random
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from skimage.filters import gaussian
from torch.utils.data import Dataset

from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
from rstor.properties import DEVICE, AUGMENTATION_FLIP, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
class DeadLeavesDataset(Dataset):
    """Dead leaves charts rendered on the CPU, paired with a degraded version.

    Each chart is rendered at ``ds_factor`` times the requested ``size``,
    smoothed with a Gaussian (sigma = 3/5) and subsampled by ``ds_factor``.
    The degraded input is then obtained by an optional blur followed by noise.

    Args:
        size: (height, width) of each chart after downsampling.
        length: number of items in the dataset.
        frozen_seed: if set, item ``idx`` is rendered with seed
            ``frozen_seed + idx`` (useful for a validation set).
        blur_kernel_half_size: forwarded to ``DegradationBlurGauss``;
            defaults to ``[0, 2]`` (presumably a [min, max] range -- confirm
            in ``DegradationBlurGauss``).
        ds_factor: supersampling factor applied before downsampling.
        noise_stddev: forwarded to ``DegradationNoise``; defaults to
            ``[0., 50.]`` (presumably a [min, max] range -- confirm).
        degradation_blur: one of ``DEGRADATION_BLUR_NONE``,
            ``DEGRADATION_BLUR_MAT`` or ``DEGRADATION_BLUR_GAUSS``.
        **config_dead_leaves: extra keyword arguments forwarded to
            ``cpu_dead_leaves_chart`` (e.g. number_of_circles,
            background_color, colored, radius_mean, radius_stddev).

    Raises:
        ValueError: if ``degradation_blur`` is not a known blur type.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: Optional[int] = None,  # useful for validation set!
        blur_kernel_half_size: Optional[List[int]] = None,
        ds_factor: int = 5,
        noise_stddev: Optional[List[float]] = None,
        degradation_blur=DEGRADATION_BLUR_NONE,
        **config_dead_leaves
    ):
        # Fresh lists per instance instead of shared mutable default arguments.
        if blur_kernel_half_size is None:
            blur_kernel_half_size = [0, 2]
        if noise_stddev is None:
            noise_stddev = [0., 50.]
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        # Charts are rendered at high resolution and downsampled in __getitem__.
        self.size = (size[0] * ds_factor, size[1] * ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves
        self.blur_kernel_half_size = blur_kernel_half_size
        self.noise_stddev = noise_stddev
        self.degradation_blur_type = degradation_blur
        if degradation_blur == DEGRADATION_BLUR_GAUSS:
            self.degradation_blur = DegradationBlurGauss(self.length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
            self.blur_deg_str = "blur_kernel_half_size"
        elif degradation_blur == DEGRADATION_BLUR_MAT:
            self.degradation_blur = DegradationBlurMat(self.length,
                                                       frozen_seed)
            self.blur_deg_str = "blur_kernel_id"
        elif degradation_blur == DEGRADATION_BLUR_NONE:
            # No blur stage: __getitem__ skips it entirely.
            pass
        else:
            raise ValueError(f"Unknown degradation blur {degradation_blur}")
        self.degradation_noise = DegradationNoise(self.length,
                                                  noise_stddev,
                                                  frozen_seed)
        # Degradation parameters used by the last __getitem__ call for each index.
        self.current_degradation = {}

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single dead leaves chart and its degraded version.

        Args:
            idx (int): index of the item to retrieve

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart,
            both [c, h, w]. The degradation parameters used are stored in
            ``self.current_degradation[idx]``.
        """
        # TODO: known bug in this CPU version: the dead leaves do not appear to be right.
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None
        chart = cpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor > 1:
            # Anti-alias (no smoothing across the channel axis), then subsample.
            sigma = 3/5
            chart = gaussian(
                chart, sigma=(sigma, sigma, 0), mode='nearest',
                cval=0, preserve_range=True, truncate=4.0)
            chart = chart[::self.ds_factor, ::self.ds_factor]
        # HWC numpy -> [1, c, h, w]; cast explicitly so the dtype does not depend
        # on what numpy/skimage returned (matches DeadLeavesDatasetGPU).
        th_chart = torch.from_numpy(chart).permute(2, 0, 1).unsqueeze(0).to(DEFAULT_TORCH_FLOAT_TYPE)
        degraded_chart = th_chart
        self.current_degradation[idx] = {}
        if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
            degraded_chart = self.degradation_blur(degraded_chart, idx)
            self.current_degradation[idx][self.blur_deg_str] = \
                self.degradation_blur.current_degradation[idx][self.blur_deg_str]
        degraded_chart = self.degradation_noise(degraded_chart, idx)
        self.current_degradation[idx]["noise_stddev"] = \
            self.degradation_noise.current_degradation[idx]["noise_stddev"]
        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)
        return degraded_chart, th_chart
class DeadLeavesDatasetGPU(Dataset):
    """Dead leaves charts rendered on the GPU, paired with a degraded version.

    Each chart is rendered at ``ds_factor`` times the requested ``size``,
    anti-aliased with a Gaussian (sigma = 3/5) and subsampled by ``ds_factor``
    via a strided depthwise convolution. The degraded input is then obtained by
    a blur followed by noise.

    Args:
        size: (height, width) of each chart after downsampling.
        length: number of items in the dataset.
        frozen_seed: if set, item ``idx`` is rendered with seed
            ``frozen_seed + idx`` (useful for a validation set).
        blur_kernel_half_size: forwarded to ``DegradationBlurGauss``;
            defaults to ``[0, 2]`` (presumably a [min, max] range -- confirm).
        ds_factor: supersampling factor applied before downsampling.
        noise_stddev: forwarded to ``DegradationNoise``; defaults to
            ``[0., 50.]`` (presumably a [min, max] range -- confirm).
        use_gaussian_kernel: blur with ``DegradationBlurGauss`` if True,
            otherwise with ``DegradationBlurMat``.
        **config_dead_leaves: extra keyword arguments forwarded to
            ``gpu_dead_leaves_chart`` (e.g. number_of_circles,
            background_color, colored, radius_mean, radius_stddev).
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: Optional[int] = None,  # useful for validation set!
        blur_kernel_half_size: Optional[List[int]] = None,
        ds_factor: int = 5,
        noise_stddev: Optional[List[float]] = None,
        use_gaussian_kernel=True,
        **config_dead_leaves
    ):
        # Fresh lists per instance instead of shared mutable default arguments.
        if blur_kernel_half_size is None:
            blur_kernel_half_size = [0, 2]
        if noise_stddev is None:
            noise_stddev = [0., 50.]
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        # Charts are rendered at high resolution and downsampled in __getitem__.
        self.size = (size[0] * ds_factor, size[1] * ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves
        # Downsample kernel: 5 taps fit sigma = 3/5; the weight at distance 2
        # along an axis is exp(-4 / (2 * 0.36)) ~= 0.0038 (neglectable).
        self.downsample_kernel = self._build_downsample_kernel(
            sigma=3/5, k_size=5, channels=3, device="cuda", dtype=DEFAULT_TORCH_FLOAT_TYPE
        )  # shape [3, 1, k_size, k_size]
        self.use_gaussian_kernel = use_gaussian_kernel
        if use_gaussian_kernel:
            self.degradation_blur = DegradationBlurGauss(length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
        else:
            self.degradation_blur = DegradationBlurMat(length,
                                                       frozen_seed)
        self.degradation_noise = DegradationNoise(length,
                                                  noise_stddev,
                                                  frozen_seed)
        # Degradation parameters used by the last __getitem__ call for each index.
        self.current_degradation = {}

    @staticmethod
    def _build_downsample_kernel(
        sigma: float,
        k_size: int,
        channels: int = 3,
        device="cuda",
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Return a normalized 2D Gaussian kernel of shape [channels, 1, k_size, k_size].

        Weights are exp(-d^2 / (2 * sigma^2)) where d is the distance to the
        kernel center, normalized to sum to 1. The same kernel is repeated for
        each channel so it can be used as a depthwise (grouped) convolution.
        """
        half = k_size // 2
        # Compute in float64 for accuracy, cast to the requested dtype at the end.
        coords = torch.arange(k_size, device=device, dtype=torch.float64) - half
        yy, xx = torch.meshgrid(coords, coords, indexing='ij')
        dist_sq = yy ** 2 + xx ** 2
        kernel = torch.exp(-dist_sq / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        return kernel.to(dtype).repeat(channels, 1, 1, 1)

    @staticmethod
    def _gaussian_downsample(chart: torch.Tensor, kernel: torch.Tensor, ds_factor: int) -> torch.Tensor:
        """Blur ``chart`` [n, c, h, w] with a depthwise ``kernel`` and subsample by ``ds_factor``.

        Replicate padding of ``k_size // 2`` on all four sides keeps the output
        centered on input pixels (ds_factor * i, ds_factor * j), like the CPU
        path's ``chart[::ds_factor, ::ds_factor]``, and yields
        ceil(h / ds_factor) x ceil(w / ds_factor) outputs.
        """
        pad = kernel.shape[-1] // 2
        # F.pad order is (left, right, top, bottom) for the last two dims.
        padded = F.pad(chart, pad=(pad, pad, pad, pad), mode="replicate")
        return F.conv2d(padded,
                        kernel,
                        padding='valid',
                        groups=kernel.shape[0],
                        stride=ds_factor)

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single dead leaves chart and its degraded version.

        Args:
            idx (int): index of the item to retrieve

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart,
            both [c, h, w]. The degradation parameters used are stored in
            ``self.current_degradation[idx]``.
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None
        # Returns a numba device array
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda")[
            None].permute(0, 3, 1, 2)  # [1, c, h, w]
        if self.ds_factor > 1:
            # Downsample using strided gaussian conv (sigma=3/5)
            th_chart = self._gaussian_downsample(th_chart, self.downsample_kernel, self.ds_factor)
        degraded_chart = self.degradation_blur(th_chart, idx)
        degraded_chart = self.degradation_noise(degraded_chart, idx)
        blur_deg_str = "blur_kernel_half_size" if self.use_gaussian_kernel else "blur_kernel_id"
        self.current_degradation[idx] = {
            blur_deg_str: self.degradation_blur.current_degradation[idx][blur_deg_str],
            "noise_stddev": self.degradation_noise.current_degradation[idx]["noise_stddev"]
        }
        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)
        return degraded_chart, th_chart