# NOTE(review): removed stray "Spaces / Sleeping" scraper artifact — it was not
# valid Python and carried no content.
| import os | |
| import torch | |
| import random | |
| import numpy as np | |
| import albumentations as A | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from typing import Dict, Optional, Tuple, List, Union | |
| from glob import glob | |
# ---- dataset source directories ----
MOS2_SYNTHETIC = "data/synth-datasets"
MOS2_SAPPHIRE_DIR = "data/raw-data/11-19-24/2. MoS2 on Sapphire"
# NOTE(review): identical to MOS2_SAPPHIRE_DIR — presumably a copy/paste error;
# confirm the intended silicon-substrate directory.
MOS2_SILICON_DIR = "data/raw-data/11-19-24/2. MoS2 on Sapphire"
MOS2_SEF_FULL_RES_SRC_DIR = "data/raw-data/1-23-25"
# NOTE(review): absolute, machine-specific path — consider making it relative
# like the other dataset roots.
MOS2_SEF_MANY_RES_SRC_DIR = "/playpen/mufan/levi/tianlong-chen-lab/sparse-cafm/data/raw-data/2-6-25"
BTO_MANY_RES = "data/raw-data/3-12-25"
# ---- dataset split identifiers ----
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"
# ---- image geometry / normalization ----
# (H, W) of the raw scans on disk
ORIGINAL_IMAGE_SIZE = (512, 512)
# default side length (px) of square crops taken from the raw scans
CROPPED_IMG_SIDE_LENGTH = 64
# physical edge length of one full scan, in micrometers (2um x 2um)
IMG_SIZE_UM = 2.0
# all maps are min-max normalized into this range
NORMALIZED_DATA_RANGE = (0.0, 1.0)
class MOS2SRDataset(Dataset):
    """
    Dataset class for sparse-sampling of MoS2 samples collected on various substrates.

    :Definitions:
        - X: surface height map | (H, W)
        - y: current map        | (H, W)

    Each item is a random square crop of a full-resolution scan, min-max
    normalized to [0, 1], paired with a bicubically-downsampled "sparse"
    counterpart for super-resolution training.
    """

    def __init__(
        self,
        src_dir: str = MOS2_SEF_FULL_RES_SRC_DIR,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        Parameters
        ---
        src_dir : str
            Dataset root; must be one of the supported dataset constants,
            which selects the corresponding loading routine.
        split : str
            Dataset split; one of {'train', 'val', 'test'}.
            - 'train': Uses synthetic downsampling for training samples.
            - 'val': Uses synthetic downsampling for validation samples.
            - 'test': Uses only real downsampled data; supported by `MOS2_SEF_MANY_RES_SRC_DIR` and `BTO_MANY_RES` datasets.
        upsample_factor : int
            Upsampling factor; must be one of {1, 2, 4, 8}.
        steps_per_epoch : int
            Number of batches per epoch. Data is randomly augmented, so the number of samples per epoch is arbitrary.
        original_image_size : tuple of int
            Size of the original images in the dataset, e.g., (512, 512).
        """
        super().__init__()
        self.steps_per_epoch: int = steps_per_epoch
        assert split.lower() in ["train", "val", "test"], \
            f"Error: invalid split: {split}. Expected one of 'train', 'val', 'test'"
        self.split: str = split.lower()
        assert upsample_factor in [1, 2, 4, 8], "Error: expected upsample_factor in: [1, 2, 4, 8]"
        self.upsample_factor = upsample_factor
        # Side length of the square crop taken from the original (512, 512) scans,
        # chosen so the downsampled input is ~64 px per side (48 px for 8x).
        # BUGFIX: the 2x branch previously used `==` (a no-op comparison) instead
        # of `=`; the resulting value (128) is coincidentally unchanged, but the
        # assignment is now explicit.
        if self.upsample_factor == 2:
            # [64, 64] -> [128, 128]
            self.side_length = 64 * 2
        elif self.upsample_factor == 4:
            # [64, 64] -> [256, 256]
            self.side_length = 64 * 4
        elif self.upsample_factor == 8:
            # [48, 48] -> [384, 384]
            self.side_length = 48 * 8
        else:
            # upsample_factor == 1: input and target are the same 128 px crop
            self.side_length = 128
        assert os.path.isdir(src_dir), f"Error: invalid src_dir: {src_dir}"
        self.src_dir = src_dir
        self.original_image_size: Tuple[int, int] = original_image_size
        self.augmentation_pipeline = self._create_augmentation_pipeline()
        # populated by the loader below; lists of (H, W) float64 arrays
        self.current_maps = None
        self.topo_maps = None
        # paths to un-normalized, high-precision current maps
        self._raw_current_fps: Optional[List[str]] = None
        self._raw_topo_fps: Optional[List[str]] = None
        # for normalizing X, y, respectively later
        self.current_maps_mean = 0.0
        self.current_maps_std = 0.0
        # use these vals to normalize all data -> [0, 1]
        self.current_maps_max = 0.0
        self.current_maps_min = 0.0
        self.topo_maps_mean = 0.0
        self.topo_maps_std = 0.0
        # original sample size is 2um
        self.img_size_um = IMG_SIZE_UM
        # all data (current + topo maps) normalized to -> [0, 1]
        self.normalized_data_range: Tuple[float, float] = NORMALIZED_DATA_RANGE
        # load all data from src files; dispatch on the dataset root
        if self.src_dir == MOS2_SEF_FULL_RES_SRC_DIR:
            self._load_imgs_mos2_sef()
        elif src_dir == MOS2_SILICON_DIR or src_dir == MOS2_SAPPHIRE_DIR:
            self._load_imgs_sil_saf()
        elif src_dir == BTO_MANY_RES:
            self._load_bto_many_res()
        elif src_dir == MOS2_SYNTHETIC:
            self._load_imgs_mos2_synth()
        else:
            raise Exception(f"Error: unsupported dataset: {src_dir}")
        # TODO: experiment with this
        # remove L -> R gradients; remove back contact bias
        # self._remove_gradients()
        # find the mean/std of current and topo maps
        self._calculate_mean_std()

    def _load_bto_many_res(self) -> None:
        """
        BTO dataset only contains surface morphology maps (no current maps), so
        it cannot satisfy this dataset's (X, y) contract — use BTOSRDataset.
        - 4x scans @{512, 256, 128, 64}
        """
        raise Exception("BTO dataset is not supported with this dataloader.")

    def _load_imgs_mos2_synth(self) -> None:
        """Load synthetic current/topo map pairs (paired by basename)."""
        # current_map_regex = f"{self.src_dir}/current/{self.split}/current-maps/*.npy"
        # topo_map_regex = f"{self.src_dir}/topology/{self.split}/topo-maps/*.npy"
        # HACK: only load in train samples
        current_map_regex = f"{self.src_dir}/current/train/current-maps/*.npy"
        topo_map_regex = f"{self.src_dir}/topology/train/topo-maps/*.npy"
        # truncate both lists to the same length so pairs stay aligned
        num_samples = min(len(glob(current_map_regex)), len(glob(topo_map_regex)))
        self._raw_current_fps = sorted(glob(current_map_regex))[:num_samples]
        self._raw_topo_fps = sorted(glob(topo_map_regex))[:num_samples]
        assert len(self._raw_current_fps) > 0, \
            f"Error: could not load images using regex: {current_map_regex}"
        # BUGFIX: this message previously reported the *current*-map regex
        assert len(self._raw_topo_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex}"
        # [H, W, C]
        self.current_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_current_fps]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]
        # validate current/topo map pairs are aligned (same basenames, same order)
        _current_fps_basenames = [os.path.basename(fp) for fp in self._raw_current_fps]
        _topo_fps_basenames = [os.path.basename(fp) for fp in self._raw_topo_fps]
        assert _current_fps_basenames == _topo_fps_basenames, \
            f"Error: misalignment of current maps and topo maps during dataloading"
        # convert maps to type -> float64
        self.current_maps = [cm.astype(np.float64) for cm in self.current_maps]
        self.topo_maps = [tm.astype(np.float64) for tm in self.topo_maps]
        # [H, W, C] -> [H, W] by averaging across the channel dimension
        self.current_maps = [np.mean(cm, axis=-1) for cm in self.current_maps]
        self.topo_maps = [np.mean(tm, axis=-1) for tm in self.topo_maps]

    def _load_imgs_mos2_sef(self) -> None:
        """
        Load current-map + topo-map data from source dir.
        """
        current_map_regex = f"{self.src_dir}/*Current*.npy"
        topo_map_regex = f"{self.src_dir}/*Height*.npy"
        self._raw_current_fps = sorted(glob(current_map_regex))
        self._raw_topo_fps = sorted(glob(topo_map_regex))
        assert len(self._raw_current_fps) > 0, \
            f"Error: could not load images using regex: {current_map_regex}"
        # BUGFIX: this message previously reported the *current*-map regex
        assert len(self._raw_topo_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex}"
        # [H, W, C]
        self.current_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_current_fps]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]
        # validate current/topo map pairs are aligned; files are paired by the
        # first 4 characters of their basenames (a shared sample id prefix)
        _current_fps_basenames = [os.path.basename(fp)[:4] for fp in self._raw_current_fps]
        _topo_fps_basenames = [os.path.basename(fp)[:4] for fp in self._raw_topo_fps]
        assert _current_fps_basenames == _topo_fps_basenames, \
            f"Error: misalignment of current maps and topo maps during dataloading"
        # convert maps to type -> float64
        self.current_maps = [cm.astype(np.float64) for cm in self.current_maps]
        self.topo_maps = [tm.astype(np.float64) for tm in self.topo_maps]

    def _load_imgs_sil_saf(self) -> None:
        """
        TODO: make less clunky and hard-coded.
        Load current-map + topo-map data from source dir.
        """
        current_map_regex = f"{self.src_dir}/*/*Current*.npy"
        topo_map_regex = f"{self.src_dir}/*/*Topo*.npy"
        self._raw_current_fps = glob(current_map_regex)
        self._raw_topo_fps = glob(topo_map_regex)
        assert len(self._raw_current_fps) > 0, \
            f"Error: could not load images using regex: {current_map_regex}"
        # BUGFIX: this message previously reported the *current*-map regex
        assert len(self._raw_topo_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex}"
        # (H, W)
        self.current_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_current_fps]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]
        # validate current/topo map pairs are aligned: files are paired by the
        # basename prefix preceding "Current" / "Topo"
        _current_fps_basenames = [
            os.path.basename(fp).split("Current")[0] for fp in self._raw_current_fps
        ]
        _topo_fps_basenames = [
            os.path.basename(fp).split("Topo")[0] for fp in self._raw_topo_fps
        ]
        assert _current_fps_basenames == _topo_fps_basenames, \
            f"Error: misalignment of current maps and topo maps during dataloading"
        # HACK: only use samples: [0, 1, 2, 3]
        self.current_maps = self.current_maps[0:4]
        self.topo_maps = self.topo_maps[0:4]

    def __remove_gradient(self, current_map: np.ndarray) -> np.ndarray:
        """
        Find a line of best fit through the column-wise average current of a sample y.
        This method helps to remove the bias created by the back-contact; a global
        bias a model could not be expected to remove without additional information.

        :param current_map: (H, W) array.
        :returns: new (H, W) array with the fitted per-column trend subtracted.
        """
        H, W = current_map.shape
        # column indices 0..W-1
        x = np.arange(W)
        # 1. column-wise means; shape: (W,)
        column_means = current_map.mean(axis=0)
        # 2. fit a line (degree=1 polynomial) to these means;
        # polyfit returns [slope, intercept]
        slope, intercept = np.polyfit(x, column_means, deg=1)
        # evaluate the fitted line at each column index; shape: (W,)
        best_fit_line = slope * x + intercept
        # 3. subtract the fitted value from every pixel of its column —
        # broadcasting over rows replaces the original per-column Python loop
        return current_map - best_fit_line[np.newaxis, :]

    def _remove_gradients(self):
        """
        Remove column-wise gradients from each map in self.current_maps by:
        1. Computing column-wise mean of each channel.
        2. Fitting a best-fit line to these means.
        3. Subtracting that line (per column) from the original values.
        """
        for i, current_map in enumerate(self.current_maps):
            self.current_maps[i] = self.__remove_gradient(current_map)

    def _calculate_mean_std(self) -> None:
        """
        Calculate the mean/std and min/max of topo/current maps over the whole
        dataset. Saves results as internal vars used for [0, 1] normalization.
        """
        # stack once per modality instead of re-building the array per statistic;
        # NOTE: requires all maps to share the same shape
        current = np.array(self.current_maps)
        topo = np.array(self.topo_maps)
        self.current_maps_mean = np.mean(current)
        self.current_maps_std = np.std(current)
        self.current_maps_max = np.amax(current)
        self.current_maps_min = np.amin(current)
        self.topo_maps_mean = np.mean(topo)
        self.topo_maps_std = np.std(topo)
        self.topo_maps_max = np.amax(topo)
        self.topo_maps_min = np.amin(topo)

    def _create_augmentation_pipeline(self):
        """Build the crop (+ optional, currently disabled, flip/rotate) pipeline."""
        return A.Compose(
            [
                # disabled experiments — kept for easy re-enabling:
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                # A.ElasticTransform(),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            # extra named inputs cropped identically to `image`
            additional_targets={
                "X": "image",
                "X_mask": "mask",
                "y": "mask",
            },
        )

    def __len__(self) -> int:
        """
        len(self) == self.steps_per_epoch
        """
        return self.steps_per_epoch

    def __getitem__(self, index: int) -> Dict:
        """
        Get the next randomly sampled item from the dataset.

        :param index: currently unused, necessary for batch data-loading
        :returns:
            ```
            {
                'X'       : torch.Tensor, topo-map w/ shape [H, W]
                'X_sparse': torch.Tensor, topo-map w/ shape [H / upsample_factor, W / upsample_factor]
                'X_unnorm': torch.Tensor, un-normalized topo-map w/ shape [H, W]
                'y'       : torch.Tensor, current-map w/ shape [H, W]
                'y_sparse': torch.Tensor, current-map w/ shape [H / upsample_factor, W / upsample_factor]
                'y_unnorm': torch.Tensor, un-normalized current-map w/ shape [H, W]
            }
            ```
        """
        # NOTE: we only consider samples: [0, 1, 2, 3];
        # HACK: hard-coded train/val splits — the last sample is held out for val
        if self.split == TRAIN_SPLIT:
            # randint is inclusive: [a, b] -> random sample from self.data[:-1]
            sample_idx = random.randint(0, len(self.current_maps) - 2)
        elif self.split == VAL_SPLIT:
            # select the final data sample: self.data[-1]
            sample_idx = len(self.current_maps) - 1
        else:
            raise Exception(f"Invalid split: {self.split}")
        # [512, 512]; un-normalized, full-sized topography map
        X: np.ndarray = self.topo_maps[sample_idx]
        # [512, 512]; un-normalized, full-sized current map
        y: np.ndarray = self.current_maps[sample_idx]
        # ---- select a [side_length, side_length] subset from full-sample ----
        # HACK: always apply augmentations (both splits go through the pipeline;
        # train/val just read different output keys)
        augmented: np.ndarray = self.augmentation_pipeline(image=y, X=X, X_mask=X, y=y)
        if self.split == "train":
            X: np.ndarray = augmented["X"]
            y: np.ndarray = augmented["image"]
        elif self.split == "val":
            X: np.ndarray = augmented["X_mask"]
            y: np.ndarray = augmented["y"]
        else:
            raise Exception("Something has gone very wrong")
        X: torch.Tensor = torch.Tensor(X).float()
        y: torch.Tensor = torch.Tensor(y).float()
        # keep un-normalized copies for downstream de-normalization / metrics
        X_unnorm = X.clone()
        y_unnorm = y.clone()
        # min-max normalize both maps -> [0, 1] using global dataset stats
        X = (X - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        y = (y - self.current_maps_min) / (self.current_maps_max - self.current_maps_min)
        # ---- bicubic downsampling ----
        # interpolate expects (N, C, H, W) -> add batch + channel dims
        X_unsqueezed = X.unsqueeze(0).unsqueeze(0)
        X_sparse = F.interpolate(
            X_unsqueezed,
            scale_factor=1 / self.upsample_factor,
            mode='bicubic',
            align_corners=False,
        )
        X_sparse = X_sparse.squeeze(0).squeeze(0)
        y_unsqueezed = y.unsqueeze(0).unsqueeze(0)
        y_sparse = F.interpolate(
            y_unsqueezed,
            scale_factor=1 / self.upsample_factor,
            mode='bicubic',
            align_corners=False,
        )
        y_sparse = y_sparse.squeeze(0).squeeze(0)
        # sanity-check the normalization (global min/max bound every crop)
        assert (X.max() <= 1.0 and X.min() >= 0.0), f"Error normalizing X sample: {X.shape}"
        assert (y.max() <= 1.0 and y.min() >= 0.0), f"Error normalizing y sample: {y.shape}"
        return {
            "X": X,
            "X_sparse": X_sparse,
            "X_unnorm": X_unnorm,
            "y": y,
            "y_sparse": y_sparse,
            "y_unnorm": y_unnorm,
        }
class BTOSRDataset(Dataset):
    """
    Dataset class used for sparse-sampling of BTO surface morphology maps.

    :Definitions:
        - X: surface height map | (H, W)

    The BTO dataset ships the same physical region scanned at four native
    resolutions (512/256/128/64); items pair a normalized crop with a
    stride-subsampled "sparse" version plus all four native-resolution scans.
    """

    def __init__(
        self,
        src_dir: str = BTO_MANY_RES,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        Parameters
        ---
        src_dir : str
            Root directory containing the multi-resolution BTO .npy scans.
        split : str
            Dataset split; one of {'train', 'val', 'test'}.
            - 'train': Uses synthetic downsampling for training samples.
            - 'val': Uses synthetic downsampling for validation samples.
            - 'test': Uses only real downsampled data.
        upsample_factor : int
            Upsampling factor; must be one of {2, 4, 8}.
        steps_per_epoch : int
            Number of batches per epoch. Data is randomly augmented, so the number of samples per epoch is arbitrary.
        original_image_size : tuple of int
            Size of the original images in the dataset, e.g., (512, 512).
        """
        super().__init__()
        self.steps_per_epoch: int = steps_per_epoch
        assert split.lower() in ["train", "val", "test"], \
            f"Error: invalid split: {split}. Expected one of 'train', 'val', 'test'"
        self.split: str = split.lower()
        assert upsample_factor in [2, 4, 8], "Error: expected upsample_factor in: [2, 4, 8]"
        self.upsample_factor = upsample_factor
        # Side length of the square crop taken from the original (512, 512) scans.
        # BUGFIX: the 2x branch previously used `==` (a no-op comparison) instead
        # of `=`; the resulting value (128) is coincidentally unchanged, but the
        # assignment is now explicit.
        if self.upsample_factor == 2:
            # [64, 64] -> [128, 128]
            self.side_length = 64 * 2
        elif self.upsample_factor == 4:
            # [64, 64] -> [256, 256]
            self.side_length = 64 * 4
        else:
            # upsample_factor == 8: [48, 48] -> [384, 384]
            self.side_length = 48 * 8
        assert os.path.isdir(src_dir), f"Error: invalid src_dir: {src_dir}"
        self.src_dir = src_dir
        self.original_image_size: Tuple[int, int] = original_image_size
        # dedicated train/val augmentation pipelines
        self.train_augmentation_pipeline = self._create_train_augmentation_pipeline()
        self.val_augmentation_pipeline = self._create_val_augmentation_pipeline()
        # populated by _load_bto_many_res(); lists of (H, W) float64 arrays
        self.topo_maps = None
        # paths to un-normalized, high-precision maps
        self._raw_topo_fps: Optional[List[str]] = None
        # use these vals to normalize all data -> [0, 1]
        self.topo_maps_mean = 0.0
        self.topo_maps_std = 0.0
        # original sample size is 2umx2um (512x512)
        self.img_size_um = IMG_SIZE_UM
        # all data normalized to -> [0, 1]
        self.normalized_data_range: Tuple[float, float] = NORMALIZED_DATA_RANGE
        self._load_bto_many_res()
        # find the mean/std of the topo maps
        self._calculate_mean_std()

    def _load_bto_many_res(self) -> None:
        """
        BTO dataset only contains surface morphology maps.
        - 4x scans @{512, 256, 128, 64}, paired by sorted filename order.
        """
        topo_map_regex_64 = f"{self.src_dir}/*64*.npy"
        topo_map_regex_128 = f"{self.src_dir}/*128*.npy"
        topo_map_regex_256 = f"{self.src_dir}/*256*.npy"
        topo_map_regex_512 = f"{self.src_dir}/*512*.npy"
        self._raw_topo_64_fps = sorted(glob(topo_map_regex_64))
        self._raw_topo_128_fps = sorted(glob(topo_map_regex_128))
        self._raw_topo_256_fps = sorted(glob(topo_map_regex_256))
        self._raw_topo_512_fps = sorted(glob(topo_map_regex_512))
        # ROBUSTNESS: previously only the 64px glob was checked; an empty glob at
        # any other resolution failed later with an opaque IndexError.
        assert len(self._raw_topo_64_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex_64}"
        assert len(self._raw_topo_128_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex_128}"
        assert len(self._raw_topo_256_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex_256}"
        assert len(self._raw_topo_512_fps) > 0, \
            f"Error: could not load images using regex: {topo_map_regex_512}"
        # [H, W], one list per native resolution, cast to float64
        self.topo_maps_64: List[np.ndarray] = [
            np.load(fp).astype(np.float64) for fp in self._raw_topo_64_fps
        ]
        self.topo_maps_128: List[np.ndarray] = [
            np.load(fp).astype(np.float64) for fp in self._raw_topo_128_fps
        ]
        self.topo_maps_256: List[np.ndarray] = [
            np.load(fp).astype(np.float64) for fp in self._raw_topo_256_fps
        ]
        self.topo_maps_512: List[np.ndarray] = [
            np.load(fp).astype(np.float64) for fp in self._raw_topo_512_fps
        ]

    def _calculate_mean_std(self) -> None:
        """
        Calculate the mean/std and min/max of the 512px topo maps.
        Saves results as internal vars used for [0, 1] normalization.
        """
        # NOTE: the held-out validation sample (the last one) is excluded from
        # the global dataset statistics.
        train_maps = np.array(self.topo_maps_512)[:-1]
        self.topo_maps_mean = np.mean(train_maps)
        self.topo_maps_std = np.std(train_maps)
        self.topo_maps_max = np.amax(train_maps)
        self.topo_maps_min = np.amin(train_maps)

    def _create_train_augmentation_pipeline(self):
        """Build the training crop (+ disabled flip/rotate) pipeline."""
        return A.Compose(
            [
                # disabled experiments — kept for easy re-enabling:
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            additional_targets={
                "X": "image",
                "X_mask": "mask",
            },
        )

    def _create_val_augmentation_pipeline(self):
        """Build the validation crop pipeline (currently identical to train)."""
        return A.Compose(
            [
                # disabled experiments — kept for easy re-enabling:
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            additional_targets={
                "X": "image",
                "X_mask": "mask",
            },
        )

    def __len__(self) -> int:
        """
        len(self) == self.steps_per_epoch
        """
        return self.steps_per_epoch

    def __getitem__(self, index: int) -> Dict:
        """
        Get the next randomly sampled item from the dataset.

        :param index: currently unused, necessary for batch data-loading
        :returns:
            ```
            {
                'X'       : torch.Tensor, normalized topo-map crop
                'X_sparse': torch.Tensor, topo-map w/ shape [H / upsample_factor, W / upsample_factor]
                'X_512'   : torch.Tensor, native 512px topo-map
                'X_256'   : torch.Tensor, native 256px topo-map
                'X_128'   : torch.Tensor, native 128px topo-map
                'X_64'    : torch.Tensor, native 64px topo-map
                'X_unnorm': torch.Tensor, un-normalized 512px topo-map
            }
            ```
        """
        # NOTE: we only consider samples: [0, 1, 2, 3];
        # choose a random sample idx; the last sample is held out for val/test
        if self.split == TRAIN_SPLIT:
            # randint is inclusive: [a, b] -> random sample from self.data[:-1]
            sample_idx = random.randint(0, len(self.topo_maps_512) - 2)
        elif self.split == VAL_SPLIT:
            # select the final data sample: self.data[-1]
            sample_idx = len(self.topo_maps_512) - 1
        elif self.split == TEST_SPLIT:
            sample_idx = len(self.topo_maps_512) - 1
        else:
            raise Exception(f"Invalid split: {self.split}")
        # un-normalized, full-sized topography maps at each native resolution
        X_512: np.ndarray = self.topo_maps_512[sample_idx]
        X_256: np.ndarray = self.topo_maps_256[sample_idx]
        X_128: np.ndarray = self.topo_maps_128[sample_idx]
        X_64: np.ndarray = self.topo_maps_64[sample_idx]
        X: np.ndarray = X_512.copy()
        # ---- select a [side_length, side_length] subset from full-sample ----
        # BUGFIX: the 'test' split previously raised an unconditional Exception
        # here, making the dedicated test branch below unreachable; test samples
        # now skip augmentation entirely, as intended.
        if self.split == TRAIN_SPLIT:
            augmented = self.train_augmentation_pipeline(image=X, X=X, X_mask=X)
            X = augmented["X"]
        elif self.split == VAL_SPLIT:
            augmented = self.val_augmentation_pipeline(image=X, X=X, X_mask=X)
            X = augmented["X_mask"]
        elif self.split == TEST_SPLIT:
            # don't apply augs to the test set; keep the full-resolution map
            pass
        else:
            raise Exception("Something has gone very wrong")
        X: torch.Tensor = torch.Tensor(X).float()
        X_512: torch.Tensor = torch.Tensor(X_512).float()
        X_256: torch.Tensor = torch.Tensor(X_256).float()
        X_128: torch.Tensor = torch.Tensor(X_128).float()
        X_64: torch.Tensor = torch.Tensor(X_64).float()
        # un-normalized copy for downstream de-normalization / metrics
        X_unnorm = X_512.clone()
        # min-max normalize everything -> [0, 1] using the 512px train stats
        X = (X - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        X_512 = (X_512 - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        X_256 = (X_256 - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        X_128 = (X_128 - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        X_64 = (X_64 - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        # ---- downsampling ----
        # add batch + channel dims -> [1, 1, H, W]
        X_unsqueezed = X.unsqueeze(0).unsqueeze(0)
        # HACK: nearest/strided subsampling (every `upsample_factor`-th pixel);
        # the bicubic F.interpolate alternative is kept below for reference
        X_sparse = X_unsqueezed[:, :, ::self.upsample_factor, ::self.upsample_factor]
        # X_sparse = F.interpolate(
        #     X_unsqueezed,
        #     scale_factor=1/self.upsample_factor,
        #     mode='bicubic',
        #     align_corners=False,
        # )
        X_sparse = X_sparse.squeeze(0).squeeze(0)
        # sanity-check normalization; only the crop and the 512px map are checked
        # because the train-only stats need not bound the other resolutions
        assert (X.max() <= 1.0 and X.min() >= 0.0), f"Error normalizing X sample: {X.shape}"
        assert (X_512.max() <= 1.0 and X_512.min() >= 0.0), f"Error normalizing X sample: {X_512.shape}"
        # assert (X_256.max() <= 1.0 and X_256.min() >= 0.0), f"Error normalizing X sample: {X_256.shape}"
        # assert (X_128.max() <= 1.0 and X_128.min() >= 0.0), f"Error normalizing X sample: {X_128.shape}"
        # assert (X_64.max() <= 1.0 and X_64.min() >= 0.0), f"Error normalizing X sample: {X_64.shape}"
        return {
            "X": X,
            "X_sparse": X_sparse,
            "X_512": X_512,
            "X_256": X_256,
            "X_128": X_128,
            "X_64": X_64,
            "X_unnorm": X_unnorm,
        }
class UnifiedMOS2SRDataset(Dataset):
    """
    A horrible abomination that contains all datasets in one.

    Wraps the SEF, sapphire, and silicon MOS2SRDataset variants and, on each
    __getitem__, returns an item from one of them chosen at random (currently
    only SEF and sapphire; see the HACK note in __getitem__).
    """

    def __init__(
        self,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        :param split: "train" or "val"
        :param steps_per_epoch: data is sampled using random augmentations, therefore the # sample per epoch is arbitrary
        :param upsample_factor: 1, 2, 4 or 8x
        :param original_image_size: size of the original images in the dataset: e.g., (512, 512)
        """
        super().__init__()
        # BUGFIX: previously referenced the undefined name MOS2_SEF_SRC_DIR,
        # raising a NameError on construction; the intended constant is
        # MOS2_SEF_FULL_RES_SRC_DIR.
        self.mos2_sef_dataset = MOS2SRDataset(
            src_dir=MOS2_SEF_FULL_RES_SRC_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )
        self.sapphire_dataset = MOS2SRDataset(
            src_dir=MOS2_SAPPHIRE_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )
        # NOTE(review): attribute name 'silicon_datset' (sic) is kept for
        # backward compatibility with any external references.
        self.silicon_datset = MOS2SRDataset(
            src_dir=MOS2_SILICON_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )

    def __len__(self):
        # all wrapped datasets share steps_per_epoch; report the SEF one
        return len(self.mos2_sef_dataset)

    def __getitem__(self, index: int) -> dict:
        """
        [HACK]: currently returning items for unconditional ControlNet training.
        Return a random item from one of the wrapped datasets.
        """
        # original three-way split, kept for reference:
        # if choice < .33: item = self.mos2_sef_dataset.__getitem__(index)
        # elif choice < .66: item = self.sapphire_dataset.__getitem__(index)
        # else: item = self.silicon_datset.__getitem__(index)
        # HACK: only use two datasets for transfer learning ablation
        # (the silicon dataset is deliberately excluded for now)
        if random.random() < 0.50:
            return self.mos2_sef_dataset[index]
        return self.sapphire_dataset[index]
if __name__ == "__main__":
    # smoke test: build a validation-split BTO dataset and draw a single item
    ds = BTOSRDataset(src_dir=BTO_MANY_RES, split="val", upsample_factor=4)
    _ = ds[0]