| import json |
| import os |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.multiprocessing |
| import torch.nn.functional as F |
| from PIL import Image |
| from torch.utils.data import Dataset |
|
|
| from ..preprocess import normalize_bands |
|
|
# Share tensors between DataLoader workers via the filesystem instead of
# file descriptors, avoiding "too many open files" errors when many small
# tensors are passed across processes.
torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
def split_and_filter_tensors(image_tensor, label_tensor, tile_size=80):
    """
    Split an image/label pair into a grid of square tiles and keep only the
    tiles whose label contains at least one annotated (non-zero) pixel.

    Generalized from the original fixed (13, 240, 240) / 3x3 layout: any
    (C, H, W) image with H and W divisible by ``tile_size`` is accepted.

    Args:
        image_tensor (torch.Tensor): Input tensor of shape (C, H, W).
        label_tensor (torch.Tensor): Label tensor of shape (H, W).
        tile_size (int): Side length of each square tile. Defaults to 80
            (240 / 3, the MADOS scene layout).

    Returns:
        tuple[list, list]: (image_tiles, label_tiles); each image tile has
        shape (C, tile_size, tile_size) and each label tile has shape
        (tile_size, tile_size). Both lists are empty when no tile contains
        an annotated pixel.

    Raises:
        ValueError: If the tensors have inconsistent shapes or the spatial
            dimensions are not divisible into whole tiles.
    """
    if image_tensor.dim() != 3 or label_tensor.dim() != 2:
        raise ValueError("Expected image of shape (C, H, W) and label of shape (H, W)")
    if image_tensor.shape[1:] != label_tensor.shape:
        raise ValueError("Image and label spatial dimensions must match")

    height, width = label_tensor.shape
    if height % tile_size or width % tile_size:
        raise ValueError(f"Spatial dims {(height, width)} not divisible by tile_size={tile_size}")

    tiles = []
    labels = []

    for i in range(height // tile_size):
        for j in range(width // tile_size):
            rows = slice(i * tile_size, (i + 1) * tile_size)
            cols = slice(j * tile_size, (j + 1) * tile_size)

            label_tile = label_tensor[rows, cols]

            # Background-only tiles (all labels == 0) carry no supervision
            # signal, so they are dropped.
            if torch.any(label_tile > 0):
                tiles.append(image_tensor[:, rows, cols])
                labels.append(label_tile)

    return tiles, labels
|
|
|
|
class PrepMADOSDataset(Dataset):
    """Loads raw MADOS Sentinel-2 scenes from disk for preprocessing.

    For each scene the available band GeoTIFFs are loaded at their native
    10/20/60 m resolutions, resampled to the 10 m 240x240 grid, the bands
    MADOS does not ship (B6, B9, B10) are synthesized from neighbors, and
    the scene is split into labelled 80x80 tiles.
    """

    def __init__(self, root_dir, split_file):
        """
        Args:
            root_dir (str): Root of the MADOS directory tree.
            split_file (str): Name of the scene-list file under
                ``<root_dir>/splits`` (one scene name per line).
        """
        self.root_dir = root_dir

        with open(os.path.join(root_dir, "splits", split_file), "r") as f:
            self.scene_list = [line.strip() for line in f]

    def __len__(self):
        # One item per scene listed in the split file.
        return len(self.scene_list)

    def __getitem__(self, idx):
        """Return (image_tiles, label_tiles) for the idx-th scene."""
        scene_name = self.scene_list[idx]
        # Scene names look like "Scene_<num1>_<num2>"; both parts are needed
        # to reconstruct per-band file names.
        scene_num_1 = scene_name.split("_")[1]
        scene_num_2 = scene_name.split("_")[2]

        # Several bands appear under slightly different centre wavelengths
        # (S2A vs S2B), hence the candidate-wavelength lists.
        B1 = self._load_band(scene_num_1, scene_num_2, [442, 443], 60)
        B2 = self._load_band(scene_num_1, scene_num_2, [492], 10)
        B3 = self._load_band(scene_num_1, scene_num_2, [559, 560], 10)
        B4 = self._load_band(scene_num_1, scene_num_2, [665], 10)
        B5 = self._load_band(scene_num_1, scene_num_2, [704], 20)
        B7 = self._load_band(scene_num_1, scene_num_2, [780, 783], 20)
        B8 = self._load_band(scene_num_1, scene_num_2, [833], 10)
        B8A = self._load_band(scene_num_1, scene_num_2, [864, 865], 20)
        B11 = self._load_band(scene_num_1, scene_num_2, [1610, 1614], 20)
        B12 = self._load_band(scene_num_1, scene_num_2, [2186, 2202], 20)

        # Upsample the 20 m / 60 m bands to the common 240x240 (10 m) grid.
        B1 = self._resize(B1)
        B5 = self._resize(B5)
        B7 = self._resize(B7)
        B8A = self._resize(B8A)
        B11 = self._resize(B11)
        B12 = self._resize(B12)

        # MADOS does not ship B6/B9/B10; approximate them from spectrally
        # adjacent bands so the output keeps the full 13-band S2 layout.
        B6 = (B5 + B7) / 2
        B9 = B8A
        B10 = (B8A + B11) / 2

        image = torch.cat(
            [B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B10, B11, B12], dim=1
        ).squeeze(0)
        mask = self._load_mask(scene_num_1, scene_num_2).squeeze(0).squeeze(0)
        images, masks = split_and_filter_tensors(image, mask)

        return images, masks

    def _load_band(self, scene_num_1, scene_num_2, bands, resolution):
        """Load the first existing band file among the candidate wavelengths.

        Returns:
            torch.Tensor: Float tensor of shape (1, 1, H, W).

        Raises:
            FileNotFoundError: If none of the candidate files exist. (The
                previous behavior — print a message and implicitly return
                None — produced an unrelated TypeError later in ``_resize``.)
        """
        for band in bands:
            band_path = f"{self.root_dir}/Scene_{scene_num_1}/{resolution}/Scene_{scene_num_1}_L2R_rhorc_{band}_{scene_num_2}.tif"
            if os.path.exists(band_path):
                return (
                    torch.from_numpy(np.array(Image.open(band_path)))
                    .float()
                    .unsqueeze(0)
                    .unsqueeze(0)
                )
        raise FileNotFoundError(
            f"No band file found for scene {scene_num_1}_{scene_num_2}, "
            f"bands {bands} at {resolution} m resolution"
        )

    def _resize(self, image):
        """Bilinearly resample a (N, C, H, W) tensor to 240x240."""
        return F.interpolate(image, size=240, mode="bilinear", align_corners=False)

    def _load_mask(self, scene_num_1, scene_num_2):
        """Load the 10 m class mask for a scene as a (1, 1, H, W) long tensor."""
        mask_path = (
            f"{self.root_dir}/Scene_{scene_num_1}/10/Scene_{scene_num_1}_L2R_cl_{scene_num_2}.tif"
        )
        return torch.from_numpy(np.array(Image.open(mask_path))).long().unsqueeze(0).unsqueeze(0)
|
|
|
|
def get_mados(save_path, root_dir="MADOS", split_file="test_X.txt"):
    """Preprocess every scene in a MADOS split and save the stacked tiles.

    Iterates the preparation dataset, collects every labelled 80x80 tile
    across all scenes, and serializes them to ``save_path`` as a dict with
    keys "images" and "labels".
    """
    dataset = PrepMADOSDataset(root_dir=root_dir, split_file=split_file)

    image_tiles = []
    mask_tiles = []
    for scene_images, scene_masks in dataset:
        image_tiles.extend(scene_images)
        mask_tiles.extend(scene_masks)

    torch.save(
        obj={"images": torch.stack(image_tiles), "labels": torch.stack(mask_tiles)},
        f=save_path,
    )
|
|
|
|
class MADOSDataset(Dataset):
    """Training-time MADOS dataset over pre-tiled tensors.

    Loads the tiles produced by ``get_mados`` from ``MADOS_<split>.pt``,
    normalizes each image using the band statistics in ``configs/mados.json``,
    and applies the provided augmentation pipeline.
    """

    def __init__(self, path_to_splits: Path, split: str, norm_operation, augmentation, partition):
        """
        Args:
            path_to_splits: Directory containing ``MADOS_<split>.pt`` (and
                optional ``<partition>_partition.json`` subset files).
            split: One of "train", "val", "valid" (alias for "val"), "test".
            norm_operation: Normalization spec forwarded to ``normalize_bands``.
            augmentation: Object exposing ``apply(image, label, task)``.
            partition: "default" to use the full training set; otherwise the
                name of a partition JSON listing subset indices to keep.

        Raises:
            ValueError: If ``split`` is not a recognized split name.
                (Previously an ``assert``, which is silently stripped under
                ``python -O``.)
        """
        with (Path(__file__).parents[0] / Path("configs") / Path("mados.json")).open("r") as f:
            config = json.load(f)

        if split not in ("train", "val", "valid", "test"):
            raise ValueError(f"Unknown split: {split!r}")
        if split == "valid":
            split = "val"

        self.band_info = config["band_info"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation

        torch_obj = torch.load(path_to_splits / f"MADOS_{split}.pt")
        self.images = torch_obj["images"]
        self.labels = torch_obj["labels"]

        # Partitions only subset the training split; val/test always use the
        # full tensors so evaluation stays comparable across partitions.
        if (partition != "default") and (split == "train"):
            with open(path_to_splits / f"{partition}_partition.json", "r") as json_file:
                subset_indices = json.load(json_file)

            self.images = self.images[subset_indices]
            self.labels = self.labels[subset_indices]

    def __len__(self):
        # Number of tiles in this split.
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return a dict with normalized, augmented image ("s2") and mask ("target")."""
        image = self.images[idx]
        label = self.labels[idx]
        # normalize_bands operates on numpy arrays; convert back to a tensor
        # before augmentation.
        image = torch.tensor(normalize_bands(image.numpy(), self.norm_operation, self.band_info))
        image, label = self.augmentation.apply(image, label, "seg")
        return {"s2": image, "target": label}
|
|