import os
import random
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

# ---- dataset source directories ----
MOS2_SYNTHETIC = "data/synth-datasets"
MOS2_SAPPHIRE_DIR = "data/raw-data/11-19-24/2. MoS2 on Sapphire"
# NOTE(review): identical to MOS2_SAPPHIRE_DIR — looks like a copy/paste
# error; confirm the intended silicon-substrate path before relying on it.
MOS2_SILICON_DIR = "data/raw-data/11-19-24/2. MoS2 on Sapphire"
MOS2_SEF_FULL_RES_SRC_DIR = "data/raw-data/1-23-25"
MOS2_SEF_MANY_RES_SRC_DIR = "/playpen/mufan/levi/tianlong-chen-lab/sparse-cafm/data/raw-data/2-6-25"
BTO_MANY_RES = "data/raw-data/3-12-25"

# ---- canonical split names ----
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"

# ---- image geometry / normalization constants ----
ORIGINAL_IMAGE_SIZE = (512, 512)
CROPPED_IMG_SIDE_LENGTH = 64
IMG_SIZE_UM = 2.0
NORMALIZED_DATA_RANGE = (0.0, 1.0)


class MOS2SRDataset(Dataset):
    """
    Dataset class for sparse-sampling of MoS2 samples collected on various
    substrates.

    :Definitions:
        - X: surface height map  | (H, W)
        - y: current map         | (H, W)
    """

    def __init__(
        self,
        src_dir: str = MOS2_SEF_FULL_RES_SRC_DIR,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        Parameters
        ---
        src_dir : str
            Root directory of one of the supported datasets.
        split : str
            Dataset split; one of {'train', 'val', 'test'}.
            - 'train': Uses synthetic downsampling for training samples.
            - 'val': Uses synthetic downsampling for validation samples.
            - 'test': Uses only real downsampled data; supported by
              `MOS2_SEF_MANY_RES_SRC_DIR` and `BTO_MANY_RES` datasets.
        upsample_factor : int
            Upsampling factor; must be one of {1, 2, 4, 8}.
        steps_per_epoch : int
            Number of batches per epoch. Data is randomly augmented, so the
            number of samples per epoch is arbitrary.
        original_image_size : tuple of int
            Size of the original images in the dataset, e.g., (512, 512).
        """
        super().__init__()
        self.steps_per_epoch: int = steps_per_epoch

        assert split.lower() in ["train", "val", "test"], (
            f"Error: invalid split: {split}. Expected 'train', 'val', or 'test'"
        )
        self.split: str = split.lower()

        assert upsample_factor in [1, 2, 4, 8], (
            "Error: expected upsample_factor in: [1, 2, 4, 8]"
        )
        self.upsample_factor = upsample_factor

        # Size of subsamples to crop from the original (512, 512) data;
        # chosen so the downsampled input is 64px (48px for 8x) on a side.
        self.side_length = 128
        if self.upsample_factor == 2:
            # [64, 64] -> [128, 128]
            # BUG FIX: original used `==` (a no-op comparison) instead of `=`;
            # benign only because the default above is already 128.
            self.side_length = 64 * 2
        if self.upsample_factor == 4:
            # [64, 64] -> [256, 256]
            self.side_length = 64 * 4
        if self.upsample_factor == 8:
            # [48, 48] -> [384, 384]
            self.side_length = 48 * 8

        assert os.path.isdir(src_dir), f"Error: invalid src_dir: {src_dir}"
        self.src_dir = src_dir
        self.original_image_size: Tuple[int, int] = original_image_size
        self.augmentation_pipeline = self._create_augmentation_pipeline()

        # (B, C, H, W)
        self.current_maps = None
        self.topo_maps = None
        # paths to un-normalized, high-precision current/topo maps
        self._raw_current_fps: Optional[List[str]] = None
        self._raw_topo_fps: Optional[List[str]] = None

        # statistics for normalizing X, y, respectively, later
        self.current_maps_mean = 0.0
        self.current_maps_std = 0.0
        # use these vals to normalize all data -> [0, 1]
        self.current_maps_max = 0.0
        self.current_maps_min = 0.0
        self.topo_maps_mean = 0.0
        self.topo_maps_std = 0.0
        self.topo_maps_max = 0.0
        self.topo_maps_min = 0.0

        # original sample size is 2um
        self.img_size_um = IMG_SIZE_UM
        # all data (current + topo maps) normalized to -> [0, 1]
        self.normalized_data_range: Tuple[float, float] = NORMALIZED_DATA_RANGE

        # load all data from src files
        if self.src_dir == MOS2_SEF_FULL_RES_SRC_DIR:
            self._load_imgs_mos2_sef()
        elif src_dir == MOS2_SILICON_DIR or src_dir == MOS2_SAPPHIRE_DIR:
            self._load_imgs_sil_saf()
        elif src_dir == BTO_MANY_RES:
            self._load_bto_many_res()
        elif src_dir == MOS2_SYNTHETIC:
            self._load_imgs_mos2_synth()
        else:
            raise Exception(f"Error: unsupported dataset: {src_dir}")

        # TODO: experiment with this
        # remove L -> R gradients; remove back contact bias
        # self._remove_gradients()

        # find the mean/std of current and topo maps
        self._calculate_mean_std()

    def _load_bto_many_res(self) -> None:
        """
        BTO dataset only contains surface morphology maps.
            - 4x scans @{512, 256, 128, 64}

        Raises unconditionally: use `BTOSRDataset` for this data instead.
        """
        raise Exception("BTO dataset is not supported with this dataloader.")

    def _load_imgs_mos2_synth(self) -> None:
        """
        Load synthetic MoS2 current-map + topo-map pairs, averaging the
        channel dimension so each map is 2-D.
        """
        # current_map_regex = f"{self.src_dir}/current/{self.split}/current-maps/*.npy"
        # topo_map_regex = f"{self.src_dir}/topology/{self.split}/topo-maps/*.npy"
        # HACK: only load in train samples
        current_map_regex = f"{self.src_dir}/current/train/current-maps/*.npy"
        topo_map_regex = f"{self.src_dir}/topology/train/topo-maps/*.npy"

        # keep only as many samples as exist in BOTH modalities
        NUM_SAMPLES = min(len(glob(current_map_regex)), len(glob(topo_map_regex)))
        self._raw_current_fps = sorted(glob(current_map_regex))[:NUM_SAMPLES]
        self._raw_topo_fps = sorted(glob(topo_map_regex))[:NUM_SAMPLES]
        assert len(self._raw_current_fps) > 0, (
            f"Error: could not load images using regex: {current_map_regex}"
        )
        # BUG FIX: error message previously interpolated current_map_regex
        assert len(self._raw_topo_fps) > 0, (
            f"Error: could not load images using regex: {topo_map_regex}"
        )

        # [H, W, C]
        self.current_maps: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_current_fps
        ]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]

        # validate current/topo map pairs are aligned (matching basenames)
        _current_fps_basenames = [os.path.basename(fp) for fp in self._raw_current_fps]
        _topo_fps_basenames = [os.path.basename(fp) for fp in self._raw_topo_fps]
        assert _current_fps_basenames == _topo_fps_basenames, (
            "Error: misalignment of current maps and topo maps during dataloading"
        )

        # convert maps to type -> float64
        self.current_maps = [cm.astype(np.float64) for cm in self.current_maps]
        self.topo_maps = [tm.astype(np.float64) for tm in self.topo_maps]

        # [H, W, C] -> [H, W] by averaging across the channel dimension
        self.current_maps = [np.mean(cm, axis=-1) for cm in self.current_maps]
        self.topo_maps = [np.mean(tm, axis=-1) for tm in self.topo_maps]

    def _load_imgs_mos2_sef(self) -> None:
        """
        Load current-map + topo-map data from source dir.
        """
        current_map_regex = f"{self.src_dir}/*Current*.npy"
        topo_map_regex = f"{self.src_dir}/*Height*.npy"
        self._raw_current_fps = sorted(glob(current_map_regex))
        self._raw_topo_fps = sorted(glob(topo_map_regex))
        assert len(self._raw_current_fps) > 0, (
            f"Error: could not load images using regex: {current_map_regex}"
        )
        # BUG FIX: error message previously interpolated current_map_regex
        assert len(self._raw_topo_fps) > 0, (
            f"Error: could not load images using regex: {topo_map_regex}"
        )

        # [H, W, C]
        self.current_maps: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_current_fps
        ]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]

        # validate current/topo map pairs are aligned
        # (first 4 chars of the basename identify the scan)
        _current_fps_basenames = [
            os.path.basename(fp)[:4] for fp in self._raw_current_fps
        ]
        _topo_fps_basenames = [os.path.basename(fp)[:4] for fp in self._raw_topo_fps]
        assert _current_fps_basenames == _topo_fps_basenames, (
            "Error: misalignment of current maps and topo maps during dataloading"
        )

        # convert maps to type -> float64
        self.current_maps = [cm.astype(np.float64) for cm in self.current_maps]
        self.topo_maps = [tm.astype(np.float64) for tm in self.topo_maps]

    def _load_imgs_sil_saf(self) -> None:
        """
        TODO: make less clunky and hard-coded.

        Load current-map + topo-map data from source dir.
        """
        current_map_regex = f"{self.src_dir}/*/*Current*.npy"
        topo_map_regex = f"{self.src_dir}/*/*Topo*.npy"
        self._raw_current_fps = glob(current_map_regex)
        self._raw_topo_fps = glob(topo_map_regex)
        assert len(self._raw_current_fps) > 0, (
            f"Error: could not load images using regex: {current_map_regex}"
        )
        # BUG FIX: error message previously interpolated current_map_regex
        assert len(self._raw_topo_fps) > 0, (
            f"Error: could not load images using regex: {topo_map_regex}"
        )

        # (H, W)
        self.current_maps: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_current_fps
        ]
        self.topo_maps: List[np.ndarray] = [np.load(fp) for fp in self._raw_topo_fps]

        # validate current/topo map pairs are aligned
        # (basename prefix before "Current"/"Topo" identifies the scan)
        _current_fps_basenames = [
            os.path.basename(fp).split("Current")[0] for fp in self._raw_current_fps
        ]
        _topo_fps_basenames = [
            os.path.basename(fp).split("Topo")[0] for fp in self._raw_topo_fps
        ]
        assert _current_fps_basenames == _topo_fps_basenames, (
            "Error: misalignment of current maps and topo maps during dataloading"
        )

        # HACK: only use samples: [0, 1, 2, 3]
        self.current_maps = self.current_maps[0:4]
        self.topo_maps = self.topo_maps[0:4]

    def __remove_gradient(self, current_map: np.ndarray) -> np.ndarray:
        """
        Find a line of best fit through the column-wise average current of a
        sample y, and subtract it from every pixel in each column.

        This helps to remove the bias created by the back-contact; a global
        bias a model could not be expected to remove without additional
        information.

        :param current_map: 2-D (H, W) current map.
        :returns: corrected copy of `current_map` (input is not mutated).
        """
        corrected_map = np.copy(current_map)
        H, W = current_map.shape

        # column indices from 0..W-1; shape: (W,)
        x = np.arange(W)

        # 1. per-column mean; shape: (W,)
        column_means = current_map.mean(axis=0)

        # 2. fit a line (degree=1 polynomial) to these means
        #    polyfit returns [slope, intercept] for a degree=1 polynomial
        slope, intercept = np.polyfit(x, column_means, deg=1)

        # evaluate the fitted line at each column index; shape: (W,)
        best_fit_line = slope * x + intercept

        # 3. subtract the fitted line from each pixel in the column;
        #    for column w, best_fit_line[w] is the "gradient" we remove
        corrected_map -= best_fit_line[np.newaxis, :]
        return corrected_map

    def _remove_gradients(self):
        """
        Remove column-wise gradients from each map in self.current_maps by:
            1. Computing column-wise mean of each channel.
            2. Fitting a best-fit line to these means.
            3. Subtracting that line (per column) from the original values.
        """
        for i, current_map in enumerate(self.current_maps):
            self.current_maps[i] = self.__remove_gradient(current_map)

    def _calculate_mean_std(self) -> None:
        """
        Calculate the mean/std/max/min of topo and current maps.
        Saves results as internal vars.
        """
        # stack once, then derive all statistics
        current_stack = np.array(self.current_maps)
        topo_stack = np.array(self.topo_maps)

        self.current_maps_mean = np.mean(current_stack)
        self.current_maps_std = np.std(current_stack)
        self.current_maps_max = np.amax(current_stack)
        self.current_maps_min = np.amin(current_stack)

        self.topo_maps_mean = np.mean(topo_stack)
        self.topo_maps_std = np.std(topo_stack)
        self.topo_maps_max = np.amax(topo_stack)
        self.topo_maps_min = np.amin(topo_stack)

    def _create_augmentation_pipeline(self):
        """Random-crop-only pipeline applied identically to X, X_mask, y."""
        return A.Compose(
            [
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                # A.ElasticTransform(),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            additional_targets={
                "X": "image",
                "X_mask": "mask",
                "y": "mask",
            },
        )

    def __len__(self) -> int:
        """len(self) == self.steps_per_epoch"""
        return self.steps_per_epoch

    def __getitem__(self, index: int) -> Dict:
        """
        Get the next randomly sampled item from the dataset.

        :param index: currently unused; necessary for batch data-loading
        :returns:
        ```
        {
            'X'       : torch.Tensor, topo-map w/ shape [H, W]
            'X_sparse': torch.Tensor, topo-map w/ shape [H / upsample_factor, W / upsample_factor]
            'X_unnorm': torch.Tensor, topo-map w/ shape [H, W] (un-normalized crop)
            'y'       : torch.Tensor, current-map w/ shape [H, W]
            'y_sparse': torch.Tensor, current-map w/ shape [H / upsample_factor, W / upsample_factor]
            'y_unnorm': torch.Tensor, current-map w/ shape [H, W] (un-normalized crop)
        }
        ```
        """
        # NOTE: we only consider samples: [0, 1, 2, 3];
        # HACK: hard-coded train/val splits

        # choose a random sample idx
        if self.split == TRAIN_SPLIT:
            # randint is inclusive: [a, b]
            # select a random sample from self.data[:-1]
            sample_idx = random.randint(0, len(self.current_maps) - 2)
        elif self.split == VAL_SPLIT:
            # select the final data sample: self.data[-1]
            sample_idx = len(self.current_maps) - 1
        else:
            raise Exception(f"Invalid split: {self.split}")

        # [512, 512]; un-normalized, full-sized topography map
        X: np.ndarray = self.topo_maps[sample_idx]
        # [512, 512]; un-normalized, full-sized current map
        y: np.ndarray = self.current_maps[sample_idx]

        # ---- select a [side_length, side_length] subset from full-sample ----
        # [512, 512] -> [128, 128] + apply augs
        augmented: Dict = self.augmentation_pipeline(image=y, X=X, X_mask=X, y=y)

        # HACK: always apply augmentations
        if self.split == "train":
            X = augmented["X"]
            y = augmented["image"]
        elif self.split == "val":
            X = augmented["X_mask"]
            y = augmented["y"]
        else:
            raise Exception("Something has gone very wrong")

        X: torch.Tensor = torch.Tensor(X).float()
        y: torch.Tensor = torch.Tensor(y).float()

        # [128, 128]; keep un-normalized copies for downstream metrics
        X_unnorm = X.clone()
        y_unnorm = y.clone()

        # -> [0, 1]
        X = (X - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        # -> [0, 1]
        y = (y - self.current_maps_min) / (
            self.current_maps_max - self.current_maps_min
        )

        # ---- bicubic downsampling ----
        # -> [1, 1, 128, 128]; F.interpolate expects a batched, channeled input
        X_unsqueezed = X.unsqueeze(0).unsqueeze(0)
        # -> [H', W']
        X_sparse = F.interpolate(
            X_unsqueezed,
            scale_factor=1 / self.upsample_factor,
            mode="bicubic",
            align_corners=False,
        )
        X_sparse = X_sparse.squeeze(0).squeeze(0)

        # -> [1, 1, 128, 128]
        y_unsqueezed = y.unsqueeze(0).unsqueeze(0)
        # -> [H', W']
        y_sparse = F.interpolate(
            y_unsqueezed,
            scale_factor=1 / self.upsample_factor,
            mode="bicubic",
            align_corners=False,
        )
        y_sparse = y_sparse.squeeze(0).squeeze(0)

        assert X.max() <= 1.0 and X.min() >= 0.0, (
            f"Error normalizing X sample: {X.shape}"
        )
        assert y.max() <= 1.0 and y.min() >= 0.0, (
            f"Error normalizing y sample: {y.shape}"
        )

        return {
            "X": X,
            "X_sparse": X_sparse,
            "X_unnorm": X_unnorm,
            "y": y,
            "y_sparse": y_sparse,
            "y_unnorm": y_unnorm,
        }


class BTOSRDataset(Dataset):
    """
    Dataset class used for sparse-sampling of BTO surface morphology maps.

    :Definitions:
        - X: surface height map | (H, W)
    """

    def __init__(
        self,
        src_dir: str = BTO_MANY_RES,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        Parameters
        ---
        src_dir : str
            Root directory of the BTO multi-resolution dataset.
        split : str
            Dataset split; one of {'train', 'val', 'test'}.
            - 'train': Uses synthetic downsampling for training samples.
            - 'val': Uses synthetic downsampling for validation samples.
            - 'test': Uses only real downsampled data.
        upsample_factor : int
            Upsampling factor; must be one of {2, 4, 8}.
        steps_per_epoch : int
            Number of batches per epoch. Data is randomly augmented, so the
            number of samples per epoch is arbitrary.
        original_image_size : tuple of int
            Size of the original images in the dataset, e.g., (512, 512).
        """
        super().__init__()
        self.steps_per_epoch: int = steps_per_epoch

        assert split.lower() in ["train", "val", "test"], (
            f"Error: invalid split: {split}. Expected 'train', 'val', or 'test'"
        )
        self.split: str = split.lower()

        assert upsample_factor in [2, 4, 8], (
            "Error: expected upsample_factor in: [2, 4, 8]"
        )
        self.upsample_factor = upsample_factor

        # Size of subsamples to crop from the original (512, 512) data;
        # chosen so the downsampled input is 64px (48px for 8x) on a side.
        self.side_length = 128
        if self.upsample_factor == 2:
            # [64, 64] -> [128, 128]
            # BUG FIX: original used `==` (a no-op comparison) instead of `=`;
            # benign only because the default above is already 128.
            self.side_length = 64 * 2
        if self.upsample_factor == 4:
            # [64, 64] -> [256, 256]
            self.side_length = 64 * 4
        if self.upsample_factor == 8:
            # [48, 48] -> [384, 384]
            self.side_length = 48 * 8

        assert os.path.isdir(src_dir), f"Error: invalid src_dir: {src_dir}"
        self.src_dir = src_dir
        self.original_image_size: Tuple[int, int] = original_image_size

        # dedicated train/val augmentation pipelines
        self.train_augmentation_pipeline = self._create_train_augmentation_pipeline()
        self.val_augmentation_pipeline = self._create_val_augmentation_pipeline()

        # (B, H, W)
        self.topo_maps = None
        # paths to un-normalized, high-precision topo maps
        self._raw_topo_fps: Optional[List[str]] = None

        # use these vals to normalize all data -> [0, 1]
        self.topo_maps_mean = 0.0
        self.topo_maps_std = 0.0

        # original sample size is 2umx2um (512x512)
        self.img_size_um = IMG_SIZE_UM
        # all data (current + topo maps) normalized to -> [0, 1]
        self.normalized_data_range: Tuple[float, float] = NORMALIZED_DATA_RANGE

        self._load_bto_many_res()

        # find the mean/std of current and topo maps
        self._calculate_mean_std()

    def _load_bto_many_res(self) -> None:
        """
        BTO dataset only contains surface morphology maps.
            - 4x scans @{512, 256, 128, 64}
        """
        topo_map_regex_64 = f"{self.src_dir}/*64*.npy"
        topo_map_regex_128 = f"{self.src_dir}/*128*.npy"
        topo_map_regex_256 = f"{self.src_dir}/*256*.npy"
        topo_map_regex_512 = f"{self.src_dir}/*512*.npy"
        self._raw_topo_64_fps = sorted(glob(topo_map_regex_64))
        self._raw_topo_128_fps = sorted(glob(topo_map_regex_128))
        self._raw_topo_256_fps = sorted(glob(topo_map_regex_256))
        self._raw_topo_512_fps = sorted(glob(topo_map_regex_512))
        assert len(self._raw_topo_64_fps) > 0, (
            f"Error: could not load images using regex: {topo_map_regex_64}"
        )

        # [H, W]
        self.topo_maps_64: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_topo_64_fps
        ]
        self.topo_maps_128: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_topo_128_fps
        ]
        self.topo_maps_256: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_topo_256_fps
        ]
        self.topo_maps_512: List[np.ndarray] = [
            np.load(fp) for fp in self._raw_topo_512_fps
        ]

        # convert maps to type -> float64
        self.topo_maps_64 = [tm.astype(np.float64) for tm in self.topo_maps_64]
        self.topo_maps_128 = [tm.astype(np.float64) for tm in self.topo_maps_128]
        self.topo_maps_256 = [tm.astype(np.float64) for tm in self.topo_maps_256]
        self.topo_maps_512 = [tm.astype(np.float64) for tm in self.topo_maps_512]

    def _calculate_mean_std(self) -> None:
        """
        Calculate the mean and std of topo maps.
        Saves results as internal vars.
        """
        # NOTE: we exclude the final (validation) sample when computing
        # global dataset statistics; validation data is not used...
        train_stack = np.array(self.topo_maps_512)[:-1]
        self.topo_maps_mean = np.mean(train_stack)
        self.topo_maps_std = np.std(train_stack)
        self.topo_maps_max = np.amax(train_stack)
        self.topo_maps_min = np.amin(train_stack)

    def _create_train_augmentation_pipeline(self):
        """Random-crop-only training pipeline applied to X and X_mask."""
        return A.Compose(
            [
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            additional_targets={
                "X": "image",
                "X_mask": "mask",
            },
        )

    def _create_val_augmentation_pipeline(self):
        """Validation pipeline; currently identical to the training pipeline."""
        return A.Compose(
            [
                # A.HorizontalFlip(p=0.5),
                # A.VerticalFlip(p=0.5),
                # A.RandomRotate90(p=0.5),
                # A.Rotate(limit=15, p=0.5),
                A.RandomCrop(width=self.side_length, height=self.side_length, p=1.0),
            ],
            additional_targets={
                "X": "image",
                "X_mask": "mask",
            },
        )

    def __len__(self) -> int:
        """len(self) == self.steps_per_epoch"""
        return self.steps_per_epoch

    def __getitem__(self, index: int) -> Dict:
        """
        Get the next randomly sampled item from the dataset.

        :param index: currently unused; necessary for batch data-loading
        :returns:
        ```
        {
            'X_64'    : torch.Tensor, topo-map w/ shape [H, W]
            'X_128'   : torch.Tensor, topo-map w/ shape [H, W]
            'X_256'   : torch.Tensor, topo-map w/ shape [H, W]
            'X_512'   : torch.Tensor, topo-map w/ shape [H, W]
            'X_sparse': torch.Tensor, topo-map w/ shape [H / upsample_factor, W / upsample_factor]
            'X_unnorm': torch.Tensor, topo-map w/ shape [H, W] (un-normalized 512px map)
        }
        ```
        """
        # NOTE: we only consider samples: [0, 1, 2, 3];

        # choose a random sample idx
        if self.split == TRAIN_SPLIT:
            # randint is inclusive: [a, b]
            # select a random sample from self.data[:-1]
            sample_idx = random.randint(0, len(self.topo_maps_512) - 2)
        elif self.split == VAL_SPLIT:
            # select the final data sample: self.data[-1]
            sample_idx = len(self.topo_maps_512) - 1
        elif self.split == TEST_SPLIT:
            sample_idx = len(self.topo_maps_512) - 1
        else:
            raise Exception(f"Invalid split: {self.split}")

        # [512, 512]; un-normalized, full-sized topography maps at 4 scan
        # resolutions
        X_512: np.ndarray = self.topo_maps_512[sample_idx]
        X_256: np.ndarray = self.topo_maps_256[sample_idx]
        X_128: np.ndarray = self.topo_maps_128[sample_idx]
        X_64: np.ndarray = self.topo_maps_64[sample_idx]
        X: np.ndarray = X_512.copy()

        # ---- select a [side_length, side_length] subset from full-sample ----
        # [512, 512] -> [128, 128] + apply augs
        # BUG FIX: the original raised for split == 'test' here, making the
        # TEST_SPLIT branch below unreachable; test now skips augmentation.
        if self.split == TRAIN_SPLIT:
            augmented: Dict = self.train_augmentation_pipeline(image=X, X=X, X_mask=X)
            X = augmented["X"]
        elif self.split == VAL_SPLIT:
            augmented: Dict = self.val_augmentation_pipeline(image=X, X=X, X_mask=X)
            X = augmented["X_mask"]
        elif self.split == TEST_SPLIT:
            # don't apply augs to test set
            pass
        else:
            raise Exception("Something has gone very wrong")

        X: torch.Tensor = torch.Tensor(X).float()
        X_512: torch.Tensor = torch.Tensor(X_512).float()
        X_256: torch.Tensor = torch.Tensor(X_256).float()
        X_128: torch.Tensor = torch.Tensor(X_128).float()
        X_64: torch.Tensor = torch.Tensor(X_64).float()

        X_unnorm = X_512.clone()

        # -> [0, 1]; all resolutions share the 512px training statistics
        X = (X - self.topo_maps_min) / (self.topo_maps_max - self.topo_maps_min)
        X_512 = (X_512 - self.topo_maps_min) / (
            self.topo_maps_max - self.topo_maps_min
        )
        X_256 = (X_256 - self.topo_maps_min) / (
            self.topo_maps_max - self.topo_maps_min
        )
        X_128 = (X_128 - self.topo_maps_min) / (
            self.topo_maps_max - self.topo_maps_min
        )
        X_64 = (X_64 - self.topo_maps_min) / (
            self.topo_maps_max - self.topo_maps_min
        )

        # ---- downsampling ----
        # -> [1, 1, 128, 128]
        X_unsqueezed = X.unsqueeze(0).unsqueeze(0)
        # HACK: linear (strided) downsampling
        # -> [H', W']
        X_sparse = X_unsqueezed[:, :, ::self.upsample_factor, ::self.upsample_factor]
        # # -> [H', W']
        # X_sparse = F.interpolate(
        #     X_unsqueezed,
        #     scale_factor=1/self.upsample_factor,
        #     mode='bicubic',
        #     align_corners=False
        # )
        X_sparse = X_sparse.squeeze(0).squeeze(0)

        assert X.max() <= 1.0 and X.min() >= 0.0, (
            f"Error normalizing X sample: {X.shape}"
        )
        assert X_512.max() <= 1.0 and X_512.min() >= 0.0, (
            f"Error normalizing X sample: {X_512.shape}"
        )
        # assert (X_256.max() <= 1.0 and X_256.min() >= 0.0), f"Error normalizing X sample: {X_256.shape}"
        # assert (X_128.max() <= 1.0 and X_128.min() >= 0.0), f"Error normalizing X sample: {X_128.shape}"
        # assert (X_64.max() <= 1.0 and X_64.min() >= 0.0), f"Error normalizing X sample: {X_64.shape}"

        return {
            "X": X,
            "X_sparse": X_sparse,
            "X_512": X_512,
            "X_256": X_256,
            "X_128": X_128,
            "X_64": X_64,
            "X_unnorm": X_unnorm,
        }


class UnifiedMOS2SRDataset(Dataset):
    """
    A horrible abomination that contains all datasets in one.
    """

    def __init__(
        self,
        split: str = "train",
        upsample_factor: int = 2,
        steps_per_epoch: int = 100,
        original_image_size: Tuple[int, int] = ORIGINAL_IMAGE_SIZE,
    ):
        """
        :param split: "train" or "val"
        :param upsample_factor: 1, 2, 4 or 8x
        :param steps_per_epoch: data is sampled using random augmentations,
            therefore the # samples per epoch is arbitrary
        :param original_image_size: size of the original images in the
            dataset: e.g., (512, 512)
        """
        super().__init__()
        # BUG FIX: original referenced undefined name `MOS2_SEF_SRC_DIR`
        # (NameError at construction time)
        self.mos2_sef_dataset = MOS2SRDataset(
            src_dir=MOS2_SEF_FULL_RES_SRC_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )
        self.sapphire_dataset = MOS2SRDataset(
            src_dir=MOS2_SAPPHIRE_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )
        # NOTE(review): attribute name is misspelled ("datset") but kept as-is
        # for backward compatibility with existing callers.
        self.silicon_datset = MOS2SRDataset(
            src_dir=MOS2_SILICON_DIR,
            split=split,
            upsample_factor=upsample_factor,
            steps_per_epoch=steps_per_epoch,
            original_image_size=original_image_size,
        )

    def __len__(self):
        return len(self.mos2_sef_dataset)

    def __getitem__(self, index: int) -> dict:
        """
        [HACK]: currently returning items for unconditional ControlNet
        training. Return a random item from one of the underlying datasets.
        """
        choice = random.random()
        # if choice < .33: item = self.mos2_sef_dataset.__getitem__(index)
        # elif choice < .66: item = self.sapphire_dataset.__getitem__(index)
        # else: item = self.silicon_datset.__getitem__(index)

        # HACK: only use two datasets for transfer learning ablation
        if choice < 0.50:
            item = self.mos2_sef_dataset[index]
        else:
            item = self.sapphire_dataset[index]
        # else: item = self.silicon_datset.__getitem__(index)
        return item


if __name__ == "__main__":
    dataset = BTOSRDataset(
        src_dir=BTO_MANY_RES,
        split="val",
        upsample_factor=4,
    )
    dataset[0]