| """ |
| Dataset definitions for CASWiT training and evaluation. |
| |
| This module provides dataset classes for semantic segmentation with |
| HR/LR dual-branch processing. |
| """ |
|
|
| import os |
| import math |
| from pathlib import Path |
| from typing import Optional, Union, Tuple, List, Dict |
| import numpy as np |
| import torch |
| from torch import Tensor |
| from torch.utils.data import Dataset |
| from PIL import Image |
| from tifffile import imread as tiff_imread |
| from torchvision import transforms |
| from dataclasses import dataclass |
| import rasterio |
| from rasterio.windows import Window |
|
|
|
|
class SemanticSegmentationDatasetFusion(Dataset):
    """
    Dataset for HR/LR fusion training on FLAIRHub.

    Yields 4-tuples ``(image_hr, mask_hr, image_lr, mask_lr)``:
      - ``image_hr``: 512x512 crop whose top-left corner is at (256, 256);
      - ``image_lr``: full image subsampled by a factor of 2;
      - mask values >= 15 are mapped to 255 (ignore label);
      - images go through ``transform`` (ToTensor + Normalize) when given,
        masks become ``LongTensor``;
      - optional JOINT augmentations are applied consistently to both
        branches and their masks.
    """

    def __init__(self, image_dir: Path, mask_dir: Path,
                 transform: Optional[transforms.Compose] = None, augment=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_filenames = sorted(os.listdir(self.image_dir))
        self.mask_filenames = sorted(os.listdir(self.mask_dir))
        assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
        self.transform = transform
        self.augment = augment

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img = load_image(self.image_dir / self.image_filenames[idx])
        msk = load_mask(self.mask_dir / self.mask_filenames[idx])
        msk[msk >= 15] = 255  # collapse out-of-range labels to ignore

        # HR branch: fixed 512x512 window at offset (256, 256).
        side = 512
        y0 = x0 = 256
        img_hr = img[y0:y0 + side, x0:x0 + side]
        msk_hr = msk[y0:y0 + side, x0:x0 + side]

        # LR branch: the whole scene subsampled by 2 along both axes.
        img_lr = img[::2, ::2, :]
        msk_lr = msk[::2, ::2]

        hr_pil = to_pil_uint8(img_hr)
        lr_pil = to_pil_uint8(img_lr)
        msk_hr_pil = Image.fromarray(msk_hr.astype(np.uint8), mode="L")
        msk_lr_pil = Image.fromarray(msk_lr.astype(np.uint8), mode="L")

        # Joint augmentation keeps both branches geometrically consistent.
        if self.augment is not None:
            hr_pil, msk_hr_pil, lr_pil, msk_lr_pil = self.augment(
                hr_pil, msk_hr_pil, lr_pil, msk_lr_pil
            )

        if self.transform:
            image_hr = self.transform(hr_pil)
            image_lr = self.transform(lr_pil)
        else:
            image_hr = to_tensor_img(np.array(hr_pil))
            image_lr = to_tensor_img(np.array(lr_pil))

        mask_hr = torch.as_tensor(np.array(msk_hr_pil, dtype=np.uint8), dtype=torch.long)
        mask_lr = torch.as_tensor(np.array(msk_lr_pil, dtype=np.uint8), dtype=torch.long)
        return image_hr, mask_hr, image_lr, mask_lr
|
|
|
|
class SemanticSegmentationDatasetHR(Dataset):
    """
    Single-branch HR dataset (no LR stream).

    Yields ``(image_hr, mask_hr)`` where ``image_hr`` is the 512x512 crop
    starting at pixel (256, 256) and mask values >= 15 become 255 (ignore).
    """

    def __init__(self, image_dir: Path, mask_dir: Path,
                 transform: Optional[transforms.Compose] = None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_filenames = sorted(os.listdir(self.image_dir))
        self.mask_filenames = sorted(os.listdir(self.mask_dir))
        assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img = load_image(self.image_dir / self.image_filenames[idx])
        msk = load_mask(self.mask_dir / self.mask_filenames[idx])
        msk[msk >= 15] = 255  # collapse out-of-range labels to ignore

        # Fixed 512x512 window whose top-left corner sits at (256, 256).
        y0 = x0 = 256
        img_hr = img[y0:y0 + 512, x0:x0 + 512]
        msk_hr = msk[y0:y0 + 512, x0:x0 + 512]

        if self.transform:
            image_hr = self.transform(to_pil_uint8(img_hr))
        else:
            image_hr = to_tensor_img(img_hr)
        mask_hr = torch.tensor(msk_hr, dtype=torch.long)
        return image_hr, mask_hr
|
|
|
|
| |
| |
| |
|
|
def load_image(path: Union[str, Path]) -> np.ndarray:
    """
    Load an image as an HxWx3 RGB float32 array scaled to [0, 1].

    Tries tifffile first (PIL cannot decode many multi-band / 16-bit
    TIFFs), then falls back to PIL. Grayscale inputs are replicated to
    3 channels, alpha channels are dropped, and CHW layouts (tifffile)
    are moved to HWC. uint8 data is divided by 255; other dtypes are
    normalized by their per-image maximum when that maximum exceeds 1.0.

    Args:
        path: File path of the image (TIFF, PNG, ...).

    Returns:
        np.ndarray of shape (H, W, 3), dtype float32, values in [0, 1].
    """
    p = str(path)
    arr = None

    # Best-effort TIFF decode; any failure falls through to PIL.
    if tiff_imread is not None:
        try:
            arr = tiff_imread(p)
        except Exception:
            arr = None

    if arr is None:
        with Image.open(p) as im:
            arr = np.array(im.convert("RGB"))

    # Grayscale -> 3 channels; channel-first (tifffile) -> channel-last.
    if arr.ndim == 2:
        arr = np.stack((arr, arr, arr), axis=-1)
    elif arr.ndim == 3 and arr.shape[0] in (3, 4) and arr.shape[-1] not in (3, 4):
        arr = np.moveaxis(arr, 0, -1)

    # Drop alpha / expand single channel.
    c = arr.shape[-1]
    if c == 4:
        arr = arr[..., :3]
    elif c == 1:
        arr = np.repeat(arr, 3, axis=-1)

    # FIX: compare dtypes with '==' rather than 'is' — identity of dtype
    # objects is an implementation detail of NumPy's caching; equality is
    # the documented comparison.
    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    else:
        arr = arr.astype(np.float32, copy=False)
        m = arr.max()
        if m > 1.0:
            arr = arr / m

    return arr
|
|
|
|
def load_mask(path: Union[str, Path]) -> np.ndarray:
    """Load a label mask as an HxW int64 array (tifffile first, PIL fallback)."""
    filename = str(path)
    arr = None
    # Best-effort TIFF decode; any failure falls through to PIL.
    if tiff_imread is not None:
        try:
            arr = tiff_imread(filename)
        except Exception:
            arr = None
    if arr is None:
        with Image.open(filename) as im:
            arr = np.array(im)

    # Multi-band masks: keep only the first band.
    if arr.ndim == 3:
        arr = arr[..., 0]
    return arr.astype(np.int64, copy=False)
|
|
|
|
| |
| |
| |
|
|
def crop_with_pad(img: np.ndarray, y0: int, x0: int, h: int, w: int, pad_val=0) -> np.ndarray:
    """Return an h x w crop of ``img`` (HW or HWC), padding out-of-bounds areas.

    The requested window may extend past any edge (including negative
    offsets); missing pixels are filled with ``pad_val``.
    """
    H, W = img.shape[:2]

    # Clip the requested window to the image, remembering how much was lost.
    y_lo, y_hi = max(y0, 0), min(y0 + h, H)
    x_lo, x_hi = max(x0, 0), min(x0 + w, W)
    top, bottom = y_lo - y0, (y0 + h) - y_hi
    left, right = x_lo - x0, (x0 + w) - x_hi

    pads = [(top, bottom), (left, right)]
    if img.ndim == 3:
        pads.append((0, 0))  # never pad the channel axis

    window = img[y_lo:y_hi, x_lo:x_hi]
    return np.pad(window, pads, mode="constant", constant_values=pad_val)
|
|
|
|
def resize_np_img(img_hwc_float01: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
    """Bilinear-resize an HWC float32 [0,1] image to ``size_hw`` via uint8 PIL."""
    target_h, target_w = size_hw
    as_u8 = (np.clip(img_hwc_float01, 0.0, 1.0) * 255.0).astype(np.uint8)
    resized = Image.fromarray(as_u8).resize((target_w, target_h), resample=Image.BILINEAR)
    out = np.asarray(resized, dtype=np.uint8).astype(np.float32) / 255.0
    # PIL collapses single-band inputs to 2D; restore 3 channels.
    if out.ndim == 2:
        out = np.stack((out, out, out), axis=-1)
    return out
|
|
|
|
def resize_np_mask(mask_hw_int: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour resize of an HxW label mask; returns int64."""
    target_h, target_w = size_hw
    mask = np.ascontiguousarray(mask_hw_int)
    if mask.ndim != 2:
        raise ValueError(f"resize_np_mask expects 2D mask HW, received shape={mask.shape}")

    # Pick the PIL mode: 8-bit "L" for uint8 labels, 32-bit signed "I"
    # (after an int32 cast) for every other dtype.
    if mask.dtype == np.uint8:
        src, mode = mask, "L"
    else:
        src, mode = mask.astype(np.int32, copy=False), "I"

    resized = Image.fromarray(src, mode=mode).resize((target_w, target_h), resample=Image.NEAREST)
    return np.asarray(resized).astype(np.int64, copy=False)
|
|
|
|
def to_tensor_img(x: np.ndarray) -> Tensor:
    """Convert an HWC float32 [0,1] array into a CHW torch tensor (copied)."""
    chw = x.transpose(2, 0, 1).copy()  # copy: detach from caller's buffer
    return torch.from_numpy(chw)
|
|
|
|
def to_pil_uint8(img_float01_hwc: np.ndarray) -> Image.Image:
    """Convert an HWC float32 [0,1] array into an 8-bit RGB PIL image."""
    scaled = np.clip(img_float01_hwc, 0.0, 1.0) * 255.0
    return Image.fromarray(np.rint(scaled).astype(np.uint8), mode="RGB")
|
|
|
|
| |
| |
| |
|
|
class SWISSIMAGEINFERENCEDataset(Dataset):
    """
    Inference dataset with HR/LR dual-branch tiling over large images.

    In "test"/"val" mode each image is covered by tiles of size ``hr_size``
    advanced by ``stride``; each worker caches the current image in RAM so
    consecutive tiles of the same image are not re-read from disk.

    This dataset has no mask directory, so only the tiled inference path
    ("test"/"val") is usable; "train" mode raises NotImplementedError.
    """
    def __init__(
        self,
        image_dir: Union[str, Path],
        num_classes: int,
        mode: str = "train",
        ignore_index: int = 255,
        hr_size: int = 1024,
        lr_side: int = 2048,
        transform: Optional = None,
        limit: Optional[int] = None,
        stride: Optional[int] = None,
    ) -> None:
        assert mode in {"train", "val", "test"}
        self.image_dir = Path(image_dir)
        self.mode = mode
        self.num_classes = int(num_classes)
        self.ignore_index = int(ignore_index)
        self.HR = int(hr_size)
        self.LR_WIN = int(lr_side)
        self.transform = transform

        # Non-overlapping tiles by default; a smaller stride gives overlap.
        self.stride = int(stride) if stride is not None else self.HR

        imgs = sorted([p for p in self.image_dir.iterdir() if p.is_file()])
        # FIX: 'limit' was accepted but never applied.
        if limit is not None:
            imgs = imgs[:limit]
        self.images: List[Path] = imgs

        # (img_id, y0, x0) for every tile, plus per-image (H, W).
        self._test_index: List[Tuple[int, int, int]] = []
        self._sizes: List[Tuple[int, int]] = []

        if self.mode in ("test", "val"):
            for img_id, ip in enumerate(self.images):
                # Header-only read: PIL gives the size without decoding pixels.
                with Image.open(ip) as im:
                    W, H = im.size
                self._sizes.append((H, W))

                for y0 in range(0, H, self.stride):
                    for x0 in range(0, W, self.stride):
                        self._test_index.append((img_id, y0, x0))

        # Per-worker cache of the most recently loaded image.
        self._cache_img_id: Optional[int] = None
        self._cache_img: Optional[np.ndarray] = None

    def __len__(self) -> int:
        # FIX: "val" also iterates tile-by-tile in __getitem__; previously
        # only "test" returned the tile count, so __len__ disagreed with
        # __getitem__ in val mode.
        if self.mode in ("test", "val"):
            return len(self._test_index)
        return len(self.images)

    def _extract_pair_np(
        self, img: np.ndarray, y0: int, x0: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract the (HR, LR) image pair as HWC float32 [0,1] arrays.

        FIX: the return annotation previously promised a 4-tuple including
        masks, but this mask-less dataset produces only the two images.
        """
        # HR branch: HRxHR crop at (y0, x0), then downsample to HR/2.
        img_hr = crop_with_pad(img, y0, x0, self.HR, self.HR, pad_val=0)
        img_hr = resize_np_img(img_hr, (self.HR // 2, self.HR // 2))

        # LR branch: LR_WIN context window centred on the HR tile,
        # resized to the same HR/2 resolution.
        cy, cx = y0 + self.HR // 2, x0 + self.HR // 2
        half = self.LR_WIN // 2
        img_lr = crop_with_pad(img, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=0)
        img_lr_512 = resize_np_img(img_lr, (self.HR // 2, self.HR // 2))
        return img_hr, img_lr_512

    def __getitem__(self, idx: int):
        if self.mode in ("test", "val"):
            img_id, y0, x0 = self._test_index[idx]
            ip = self.images[img_id]
            H, W = self._sizes[img_id]

            # Re-read the image only when the tile belongs to a new image.
            if self._cache_img_id != img_id:
                self._cache_img = load_image(ip)
                self._cache_img_id = img_id
            img = self._cache_img

            img_hr_np, img_lr_np = self._extract_pair_np(img, y0, x0)

            if self.transform:
                image_hr = self.transform(to_pil_uint8(img_hr_np))
                image_lr = self.transform(to_pil_uint8(img_lr_np))
            else:
                image_hr = to_tensor_img(img_hr_np)
                image_lr = to_tensor_img(img_lr_np)

            meta: Dict[str, object] = {
                "img_path": str(ip),
                "tile": (int(y0), int(x0), self.HR, self.HR),
                "img_hw": (int(H), int(W)),
                "tile_index": int(idx),
            }
            return image_hr, image_lr, meta

        # FIX: the previous train branch referenced self.masks (never set)
        # and called _extract_pair_np with a mask argument it does not take,
        # so it could never run. Fail fast with a clear message instead.
        raise NotImplementedError(
            "SWISSIMAGEINFERENCEDataset has no masks; use mode='test' or mode='val'."
        )
|
|
class URURHRLRDataset(Dataset):
    """
    URUR dataset with HR/LR dual-branch processing and tiling support.

    Train mode: one random HR crop per image plus a centred LR context
    window, with optional JOINT augmentation of both branches and masks.
    Test/val mode: each image is covered by non-overlapping HR tiles and
    each worker caches the current image+mask in RAM so consecutive tiles
    of the same image are not re-read from disk.
    """
    def __init__(
        self,
        image_dir: Union[str, Path],
        mask_dir: Union[str, Path],
        num_classes: int,
        mode: str = "train",
        ignore_index: int = 255,
        hr_size: int = 1024,
        lr_side: int = 2048,
        transform: Optional = None,
        augment=None,
        limit: Optional[int] = None
    ) -> None:
        assert mode in {"train", "val", "test"}
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.mode = mode
        self.num_classes = int(num_classes)
        self.ignore_index = int(ignore_index)
        self.HR = int(hr_size)
        self.LR_WIN = int(lr_side)
        self.transform = transform
        self.augment = augment

        imgs = sorted([p for p in self.image_dir.iterdir() if p.is_file()])
        msks = sorted([p for p in self.mask_dir.iterdir() if p.is_file()])
        if limit is not None:
            imgs, msks = imgs[:limit], msks[:limit]
        assert len(imgs) == len(msks) and len(imgs) > 0, "Images/Masks missing or misaligned"

        self.images: List[Path] = imgs
        self.masks: List[Path] = msks

        # (img_id, y0, x0) for every tile, plus per-image (H, W).
        self._test_index: List[Tuple[int, int, int]] = []
        self._sizes: List[Tuple[int, int]] = []

        if self.mode in ("test", "val"):
            for img_id, ip in enumerate(self.images):
                # Header-only read: PIL gives the size without decoding pixels.
                with Image.open(ip) as im:
                    W, H = im.size
                self._sizes.append((H, W))

                n_ty = math.ceil(H / self.HR)
                n_tx = math.ceil(W / self.HR)
                for iy in range(n_ty):
                    for ix in range(n_tx):
                        self._test_index.append((img_id, iy * self.HR, ix * self.HR))

        # Per-worker cache of the most recently loaded image/mask pair.
        self._cache_img_id: Optional[int] = None
        self._cache_img: Optional[np.ndarray] = None
        self._cache_msk: Optional[np.ndarray] = None

    def __len__(self) -> int:
        # FIX: "val" also iterates tile-by-tile in __getitem__; previously
        # only "test" returned the tile count, so __len__ disagreed with
        # __getitem__ in val mode.
        if self.mode in ("test", "val"):
            return len(self._test_index)
        return len(self.images)

    def _extract_pair_np(
        self, img: np.ndarray, msk: np.ndarray, y0: int, x0: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Extract an HR/LR pair: images HWC float32 [0,1], masks HW int64.

        Images are resized to HR/2. The HR mask keeps its HRxHR resolution
        and the LR mask is nearest-resized to HRxHR as well.
        """
        # HR branch: HRxHR crop; mask padding uses the ignore label.
        img_hr = crop_with_pad(img, y0, x0, self.HR, self.HR, pad_val=0)
        msk_hr = crop_with_pad(msk, y0, x0, self.HR, self.HR, pad_val=self.ignore_index)
        img_hr = resize_np_img(img_hr, (self.HR // 2, self.HR // 2))

        # LR branch: LR_WIN context window centred on the HR tile.
        cy, cx = y0 + self.HR // 2, x0 + self.HR // 2
        half = self.LR_WIN // 2
        img_lr = crop_with_pad(img, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=0)
        msk_lr = crop_with_pad(msk, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=self.ignore_index)

        img_lr_512 = resize_np_img(img_lr, (self.HR // 2, self.HR // 2))
        msk_lr_512 = resize_np_mask(msk_lr, (self.HR, self.HR))

        msk_hr = msk_hr.astype(np.int64, copy=False)
        msk_lr_512 = msk_lr_512.astype(np.int64, copy=False)

        # Optional subclass hook for label remapping.
        if hasattr(self, "_postprocess_masks"):
            msk_hr, msk_lr_512 = self._postprocess_masks(msk_hr, msk_lr_512)

        # Out-of-range labels become the ignore label.
        msk_hr[msk_hr >= self.num_classes] = self.ignore_index
        msk_lr_512[msk_lr_512 >= self.num_classes] = self.ignore_index

        return img_hr, msk_hr, img_lr_512, msk_lr_512

    def __getitem__(self, idx: int):
        if self.mode in ("test", "val"):
            img_id, y0, x0 = self._test_index[idx]
            ip, mp = self.images[img_id], self.masks[img_id]
            H, W = self._sizes[img_id]

            # Re-read only when the tile belongs to a new image.
            if self._cache_img_id != img_id:
                self._cache_img = load_image(ip)
                self._cache_msk = load_mask(mp)
                self._cache_img_id = img_id

            img = self._cache_img
            msk = self._cache_msk

            img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)

            if self.transform:
                image_hr = self.transform(to_pil_uint8(img_hr_np))
                image_lr = self.transform(to_pil_uint8(img_lr_np))
            else:
                image_hr = to_tensor_img(img_hr_np)
                image_lr = to_tensor_img(img_lr_np)

            mask_hr = torch.as_tensor(msk_hr_np, dtype=torch.long)
            mask_lr = torch.as_tensor(msk_lr_np, dtype=torch.long)

            meta: Dict[str, object] = {
                "img_path": str(ip),
                "mask_path": str(mp),
                "tile": (int(y0), int(x0), self.HR, self.HR),
                "img_hw": (int(H), int(W)),
                "tile_index": int(idx),
            }
            return image_hr, mask_hr, image_lr, mask_lr, meta

        # ---- train: one random crop per image ----
        ip, mp = self.images[idx], self.masks[idx]
        img = load_image(ip)
        msk = load_mask(mp)
        H, W = img.shape[:2]

        y0 = 0 if H <= self.HR else np.random.randint(0, H - self.HR + 1)
        x0 = 0 if W <= self.HR else np.random.randint(0, W - self.HR + 1)

        img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)

        img_hr_pil = to_pil_uint8(img_hr_np)
        img_lr_pil = to_pil_uint8(img_lr_np)
        m_hr_pil = Image.fromarray(msk_hr_np.astype(np.uint8), mode="L")
        m_lr_pil = Image.fromarray(msk_lr_np.astype(np.uint8), mode="L")

        # Joint augmentation keeps both branches geometrically consistent.
        if self.augment is not None:
            img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil = self.augment(img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil)

        if self.transform:
            image_hr = self.transform(img_hr_pil)
            image_lr = self.transform(img_lr_pil)
        elif self.augment is not None:
            # FIX: the augmented images were previously discarded on this
            # path (tensors were built from the pre-augmentation arrays
            # while the masks WERE augmented — a geometric mismatch).
            image_hr = to_tensor_img(np.asarray(img_hr_pil, dtype=np.float32) / 255.0)
            image_lr = to_tensor_img(np.asarray(img_lr_pil, dtype=np.float32) / 255.0)
        else:
            image_hr = to_tensor_img(img_hr_np)
            image_lr = to_tensor_img(img_lr_np)

        mask_hr = torch.as_tensor(np.array(m_hr_pil, dtype=np.uint8), dtype=torch.long)
        mask_lr = torch.as_tensor(np.array(m_lr_pil, dtype=np.uint8), dtype=torch.long)

        meta: Dict[str, object] = {"img_path": str(ip), "mask_path": str(mp)}
        return image_hr, mask_hr, image_lr, mask_lr, meta
|
|
def build_transforms():
    """Standard transform pipeline: ToTensor then Normalize(mean=std=0.5)."""
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return transforms.Compose([transforms.ToTensor(), normalize])
|
|
|
|
def _read_vrt_rgb_window(vrt_path: str, y0: int, x0: int, h: int, w: int) -> np.ndarray:
    """Open a VRT and read one RGB window (boundless, zero-filled) as HWC float32 [0,1]."""
    with rasterio.open(vrt_path) as src:
        window = Window(col_off=x0, row_off=y0, width=w, height=h)
        bands = src.read(window=window, boundless=True, fill_value=0)
    # Keep the first three bands and switch to channel-last layout.
    img = np.transpose(bands[:3, :, :], (1, 2, 0))
    if img.dtype == np.uint8:
        return img.astype(np.float32) / 255.0
    img = img.astype(np.float32)
    peak = float(img.max()) if img.size else 1.0
    if peak > 1.0:
        img /= peak
    return img
|
|
class VRTDualResInferenceDataset(Dataset):
    """
    Sliding-window dual-resolution inference dataset backed by a rasterio VRT.

    Tiles of ``tile_size`` are taken every ``stride`` pixels. Each item
    yields the HR tile and a larger ``lr_side`` context window, both
    resized to ``tile_size // 2``, plus metadata describing where the
    central (halo-free) part of the prediction should be written back.
    """
    def __init__(self, vrt_path, tile_size=1024, stride=512, lr_side=2048, transform=None):
        self.vrt_path = vrt_path
        self.tile = int(tile_size)
        self.stride = int(stride)
        self.lr_side = int(lr_side)
        self.transform = transform

        # Read raster dimensions once; keep the profile for output writers.
        with rasterio.open(self.vrt_path) as src:
            self.H = src.height
            self.W = src.width
            self.profile = src.profile

        self.index = [(y0, x0) for y0 in range(0, self.H, self.stride)
                      for x0 in range(0, self.W, self.stride)]

        if self.tile < self.stride:
            raise ValueError("tile_size must be >= stride")
        # Overlap margin discarded on each side when writing predictions back.
        self.halo = (self.tile - self.stride) // 2

        # Lazily opened, per-worker rasterio handle (dataset handles are
        # not picklable, so each DataLoader worker opens its own).
        self._src = None

    def _get_src(self):
        """Open the VRT once per worker process and reuse the handle."""
        if self._src is None:
            self._src = rasterio.open(self.vrt_path)
        return self._src

    def _read_rgb_window(self, y0, x0, h, w):
        """Read an RGB window (boundless, zero-filled) as HWC float32 [0,1]."""
        src = self._get_src()
        win = Window(col_off=x0, row_off=y0, width=w, height=h)
        data = src.read(window=win, boundless=True, fill_value=0)
        data = data[:3, :, :]
        img = np.transpose(data, (1, 2, 0))

        if img.dtype == np.uint8:
            img = img.astype(np.float32) / 255.0
        else:
            img = img.astype(np.float32)
            m = float(img.max()) if img.size else 1.0
            if m > 1.0:
                img /= m
        return img

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        y0, x0 = self.index[idx]

        # HR tile at native resolution.
        img_hr = self._read_rgb_window(y0, x0, self.tile, self.tile)

        # LR context window centred on the HR tile.
        cy, cx = y0 + self.tile // 2, x0 + self.tile // 2
        half = self.lr_side // 2
        img_lr = self._read_rgb_window(cy - half, cx - half, self.lr_side, self.lr_side)

        # FIX: removed a redundant per-item 'from dataset.definition_dataset
        # import ...' — resize_np_img / to_pil_uint8 / to_tensor_img are
        # defined in this module and already in scope.
        img_hr = resize_np_img(img_hr, (self.tile // 2, self.tile // 2))
        img_lr = resize_np_img(img_lr, (self.tile // 2, self.tile // 2))

        if self.transform:
            image_hr = self.transform(to_pil_uint8(img_hr))
            image_lr = self.transform(to_pil_uint8(img_lr))
        else:
            image_hr = to_tensor_img(img_hr)
            image_lr = to_tensor_img(img_lr)

        # Central stride x stride region (halo trimmed), clipped to the raster.
        wy0 = y0 + self.halo
        wx0 = x0 + self.halo
        wh = max(0, min(self.stride, self.H - wy0))
        ww = max(0, min(self.stride, self.W - wx0))

        meta = {
            "tile_y0": y0,
            "tile_x0": x0,
            "tile_size": self.tile,
            "img_hw": (self.H, self.W),
            "write_y0": wy0,
            "write_x0": wx0,
            "write_h": wh,
            "write_w": ww,
            "idx": idx,
        }
        return image_hr, image_lr, meta

    def __del__(self):
        # Best effort: close the worker-local dataset handle.
        try:
            if self._src is not None:
                self._src.close()
        except Exception:
            pass
|
|