"""
Dataset definitions for CASWiT training and evaluation.
This module provides dataset classes for semantic segmentation with
HR/LR dual-branch processing.
"""
import os
import math
from pathlib import Path
from typing import Optional, Union, Tuple, List, Dict
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from PIL import Image
try:
    from tifffile import imread as tiff_imread
except ImportError:  # tifffile is optional; loaders fall back to PIL below
    tiff_imread = None
from torchvision import transforms
import rasterio
from rasterio.windows import Window
class SemanticSegmentationDatasetFusion(Dataset):
"""
Dataset for HR/LR fusion training on FLAIRHub.
Returns (image_hr, mask_hr, image_lr, mask_lr):
- image_hr: 512x512 crop starting at (256, 256)
- image_lr: full image downsampled by factor 2
- mask >=15 replaced by 255 (ignore)
- transforms applied to images (ToTensor + Normalize) and mask -> LongTensor
- optional JOINT augmentations applied consistently to HR/LR + masks
"""
    def __init__(self, image_dir: Path, mask_dir: Path, transform: Optional[transforms.Compose] = None, augment=None):
self.image_dir = Path(image_dir)
self.mask_dir = Path(mask_dir)
self.image_filenames = sorted(os.listdir(self.image_dir))
self.mask_filenames = sorted(os.listdir(self.mask_dir))
assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
self.transform = transform
self.augment = augment
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
image_path = self.image_dir / self.image_filenames[idx]
mask_path = self.mask_dir / self.mask_filenames[idx]
image = load_image(image_path)
mask = load_mask(mask_path)
mask[mask >= 15] = 255
        # Crop HR at 512x512, origin at (row=256, col=256)
        hr_crop_size = 512
        crop_y, crop_x = 256, 256  # row, column
        image_hr = image[crop_y:crop_y + hr_crop_size, crop_x:crop_x + hr_crop_size]
        mask_hr = mask[crop_y:crop_y + hr_crop_size, crop_x:crop_x + hr_crop_size]
# Downsample LR
image_lr = image[::2,::2,:]
mask_lr = mask[::2,::2]
# ---- Convert to PIL for augmentations
img_hr_pil = to_pil_uint8(image_hr)
img_lr_pil = to_pil_uint8(image_lr)
m_hr_pil = Image.fromarray(mask_hr.astype(np.uint8), mode="L")
m_lr_pil = Image.fromarray(mask_lr.astype(np.uint8), mode="L")
# ---- Joint HR/LR + mask augmentations (TRAIN ONLY)
if self.augment is not None:
img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil = self.augment(
img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil
)
# ---- Image transform (ToTensor + Normalize)
if self.transform:
image_hr = self.transform(img_hr_pil)
image_lr = self.transform(img_lr_pil)
        else:
            # Convert the (possibly augmented) PIL images back to float [0,1]
            image_hr = to_tensor_img(np.asarray(img_hr_pil, dtype=np.float32) / 255.0)
            image_lr = to_tensor_img(np.asarray(img_lr_pil, dtype=np.float32) / 255.0)
# ---- Masks to torch
mask_hr = torch.as_tensor(np.array(m_hr_pil, dtype=np.uint8), dtype=torch.long)
mask_lr = torch.as_tensor(np.array(m_lr_pil, dtype=np.uint8), dtype=torch.long)
return image_hr, mask_hr, image_lr, mask_lr
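# Minimal usage sketch for the fusion dataset (added example, not original
# code). The directory paths are hypothetical placeholders; any FLAIRHub-style
# layout with name-aligned, equal-length image/mask folders works.
def _example_fusion_usage():
    dataset = SemanticSegmentationDatasetFusion(
        image_dir=Path("data/flairhub/train/images"),  # hypothetical path
        mask_dir=Path("data/flairhub/train/masks"),    # hypothetical path
        transform=build_transforms(),
    )
    image_hr, mask_hr, image_lr, mask_lr = dataset[0]
    # image_hr is a (3, 512, 512) tensor in [-1, 1]; masks are LongTensors
    # with 255 marking ignored pixels (original labels >= 15)
    print(image_hr.shape, mask_hr.shape, image_lr.shape, mask_lr.shape)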
class SemanticSegmentationDatasetHR(Dataset):
"""
Dataset for HR-only training (single branch, no LR).
Returns (image_hr, mask_hr):
- image_hr: 512x512 crop starting at (256, 256)
- mask >=15 replaced by 255 (ignore)
"""
def __init__(self, image_dir: Path, mask_dir: Path, transform: Optional[transforms.Compose] = None):
self.image_dir = Path(image_dir)
self.mask_dir = Path(mask_dir)
self.image_filenames = sorted(os.listdir(self.image_dir))
self.mask_filenames = sorted(os.listdir(self.mask_dir))
assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
self.transform = transform
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
image_path = self.image_dir / self.image_filenames[idx]
mask_path = self.mask_dir / self.mask_filenames[idx]
image = load_image(image_path)
mask = load_mask(mask_path)
mask[mask >= 15] = 255
        crop_y, crop_x = 256, 256  # row, column
        image_hr = image[crop_y:crop_y + 512, crop_x:crop_x + 512]
        mask_hr = mask[crop_y:crop_y + 512, crop_x:crop_x + 512]
if self.transform:
image_hr = self.transform(to_pil_uint8(image_hr))
else:
image_hr = to_tensor_img(image_hr)
        mask_hr = torch.as_tensor(mask_hr, dtype=torch.long)
return image_hr, mask_hr
# ----------------------------
# Image loading functions for sliding windows without overlap on URUR, DeepGlobe and INRIA
# ----------------------------
def load_image(path: Union[str, Path]) -> np.ndarray:
"""
Load an image as HxWx3 (RGB), float32 [0,1].
Handles both TIFF and PNG files gracefully.
"""
p = str(path)
arr = None
# 1) Try TIFF first if available
if tiff_imread is not None:
try:
arr = tiff_imread(p)
except Exception:
arr = None
# 2) Fallback to PIL
if arr is None:
with Image.open(p) as im:
arr = np.array(im.convert("RGB")) # HWC uint8
# Ensure HWC format
if arr.ndim == 2:
arr = np.stack((arr, arr, arr), axis=-1) # HWC
elif arr.ndim == 3 and arr.shape[0] in (3, 4) and arr.shape[-1] not in (3, 4):
arr = np.moveaxis(arr, 0, -1) # CHW -> HWC
# Keep 3 channels
c = arr.shape[-1]
if c == 4:
arr = arr[..., :3]
elif c == 1:
arr = np.repeat(arr, 3, axis=-1)
# Normalize -> float32 [0,1]
    if arr.dtype == np.uint8:
arr = arr.astype(np.float32) / 255.0
else:
arr = arr.astype(np.float32, copy=False)
m = arr.max()
if m > 1.0:
arr = arr / m
return arr # float32 HWC in [0,1]
def load_mask(path: Union[str, Path]) -> np.ndarray:
"""Load a mask as HxW int64 (labels). Handles both TIFF and PNG files."""
p = str(path)
m = None
if tiff_imread is not None:
try:
m = tiff_imread(p)
except Exception:
m = None
if m is None:
with Image.open(p) as im:
m = np.array(im)
# Force 2D
if m.ndim == 3:
m = m[..., 0]
return m.astype(np.int64, copy=False)
# ----------------------------
# Helper Functions
# ----------------------------
def crop_with_pad(img: np.ndarray, y0: int, x0: int, h: int, w: int, pad_val=0) -> np.ndarray:
"""Extract a crop HxW with padding if necessary (img HxW[ xC])."""
H, W = img.shape[:2]
y1, x1 = y0 + h, x0 + w
pad_top = max(0, -y0); ys = max(0, y0)
pad_left = max(0, -x0); xs = max(0, x0)
pad_bot = max(0, y1 - H); ye = min(H, y1)
pad_right = max(0, x1 - W); xe = min(W, x1)
sl = img[ys:ye, xs:xe]
pad_cfg = ((pad_top, pad_bot), (pad_left, pad_right)) + (((0, 0),) if img.ndim == 3 else ())
return np.pad(sl, pad_cfg, mode="constant", constant_values=pad_val)
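# Illustrative check of crop_with_pad (added sketch, not original code): crops
# whose origin is negative or that extend past the image border are filled
# with pad_val, so every tile comes back with the exact requested shape.
def _example_crop_with_pad():
    img = np.ones((100, 100, 3), dtype=np.float32)
    tile = crop_with_pad(img, y0=-16, x0=90, h=64, w=64, pad_val=0)
    assert tile.shape == (64, 64, 3)
    # top 16 rows and right-most 54 columns are zero padding
    print(tile[:16].max(), tile[:, -54:].max())  # -> 0.0 0.0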
def resize_np_img(img_hwc_float01: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
"""Resize HWC float32[0,1] -> HWC float32[0,1] using bilinear interpolation."""
Ht, Wt = size_hw
im = Image.fromarray((np.clip(img_hwc_float01, 0.0, 1.0) * 255.0).astype(np.uint8))
im = im.resize((Wt, Ht), resample=Image.BILINEAR)
out = np.asarray(im, dtype=np.uint8).astype(np.float32) / 255.0
if out.ndim == 2:
out = np.stack((out, out, out), axis=-1)
return out
def resize_np_mask(mask_hw_int: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
"""Resize mask HW using nearest neighbor via PIL, output int64."""
Ht, Wt = size_hw
mask = np.ascontiguousarray(mask_hw_int)
if mask.ndim != 2:
raise ValueError(f"resize_np_mask expects 2D mask HW, received shape={mask.shape}")
dt = mask.dtype
if dt in (np.int64, np.int32, np.int16, np.int8, np.uint16):
pil_arr, pil_mode = mask.astype(np.int32, copy=False), "I"
elif dt == np.uint8:
pil_arr, pil_mode = mask, "L"
else:
pil_arr, pil_mode = mask.astype(np.int32, copy=False), "I"
im = Image.fromarray(pil_arr, mode=pil_mode).resize((Wt, Ht), resample=Image.NEAREST)
return np.asarray(im).astype(np.int64, copy=False)
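# Why masks are resized with NEAREST (added sketch): bilinear resampling would
# blend neighboring label ids into values that belong to no class, while
# nearest neighbor keeps every output pixel equal to some input label.
def _example_resize_mask_nearest():
    mask = np.zeros((4, 4), dtype=np.int64)
    mask[:, 2:] = 7  # two regions: labels 0 and 7
    small = resize_np_mask(mask, (2, 2))
    assert set(np.unique(small)) <= {0, 7}  # no invented intermediate labels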
def to_tensor_img(x: np.ndarray) -> Tensor:
"""Convert HWC float32[0,1] -> CHW float32[0,1]."""
return torch.from_numpy(np.transpose(x, (2, 0, 1)).copy())
def to_pil_uint8(img_float01_hwc: np.ndarray) -> Image.Image:
"""Convert HWC float32[0,1] -> PIL RGB uint8."""
arr = (np.clip(img_float01_hwc, 0.0, 1.0) * 255.0).round().astype(np.uint8)
return Image.fromarray(arr, mode="RGB")
# ----------------------------
# Dataset Classes
# ----------------------------
class SWISSIMAGEINFERENCEDataset(Dataset):
"""
dataset with HR/LR dual-branch processing and tiling support.
In test mode, each worker caches the current image+mask in RAM for all tiles
of the same image to avoid re-reading for each tile.
"""
def __init__(
self,
image_dir: Union[str, Path],
num_classes: int,
mode: str = "train",
ignore_index: int = 255,
hr_size: int = 1024,
lr_side: int = 2048,
        transform: Optional[transforms.Compose] = None,
        limit: Optional[int] = None,
        stride: Optional[int] = None,  # sliding-window stride (new parameter)
) -> None:
assert mode in {"train", "val", "test"}
self.image_dir = Path(image_dir)
self.mode = mode
self.num_classes = int(num_classes)
self.ignore_index = int(ignore_index)
self.HR = int(hr_size)
self.LR_WIN = int(lr_side)
self.transform = transform
        # If stride is not given, keep the previous behavior (stride = HR)
self.stride = int(stride) if stride is not None else self.HR
imgs = sorted([p for p in self.image_dir.iterdir() if p.is_file()])
self.images: List[Path] = imgs
# Tile index + quick sizes (without loading full image)
self._test_index: List[Tuple[int, int, int]] = [] # (img_id, y0, x0)
self._sizes: List[Tuple[int, int]] = [] # (H, W) per image
if self.mode == "test" or self.mode == "val":
for img_id, ip in enumerate(self.images):
with Image.open(ip) as im:
W, H = im.size
self._sizes.append((H, W))
                # --- sliding window with stride ---
                # With stride == HR this reproduces the previous non-overlapping tiling
for y0 in range(0, H, self.stride):
for x0 in range(0, W, self.stride):
self._test_index.append((img_id, y0, x0))
# Cache per worker (used only in test mode)
self._cache_img_id: Optional[int] = None
self._cache_img: Optional[np.ndarray] = None # float32 HWC [0,1]
    def __len__(self) -> int:
        # val mode also iterates over tiles, not whole images
        return len(self._test_index) if self.mode in ("test", "val") else len(self.images)
    def _extract_pair_np(
        self, img: np.ndarray, y0: int, x0: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract an HR/LR image pair in numpy (HWC float32 in [0,1]); no masks here."""
# HR
img_hr = crop_with_pad(img, y0, x0, self.HR, self.HR, pad_val=0)
# Resize HR to half size
img_hr = resize_np_img(img_hr, (self.HR//2, self.HR//2))
# LR centered on HR
cy, cx = y0 + self.HR // 2, x0 + self.HR // 2
half = self.LR_WIN // 2
img_lr = crop_with_pad(img, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=0)
# Downsample LR -> HR
img_lr_512 = resize_np_img(img_lr, (self.HR//2, self.HR//2))
return img_hr, img_lr_512
def __getitem__(self, idx: int):
if self.mode == "test" or self.mode == "val":
# Tile index
img_id, y0, x0 = self._test_index[idx]
ip = self.images[img_id]
H, W = self._sizes[img_id]
# Cache per worker: read/convert only once per image
if self._cache_img_id != img_id:
self._cache_img = load_image(ip) # float32 HWC [0,1]
self._cache_img_id = img_id
img = self._cache_img
img_hr_np, img_lr_np = self._extract_pair_np(img, y0, x0)
if self.transform:
image_hr = self.transform(to_pil_uint8(img_hr_np))
image_lr = self.transform(to_pil_uint8(img_lr_np))
else:
image_hr = to_tensor_img(img_hr_np)
image_lr = to_tensor_img(img_lr_np)
meta: Dict[str, object] = {
"img_path": str(ip),
"tile": (int(y0), int(x0), self.HR, self.HR),
"img_hw": (int(H), int(W)),
"tile_index": int(idx),
}
return image_hr, image_lr, meta
        # train mode: unsupported, this inference dataset has no mask directory
        raise NotImplementedError(
            "SWISSIMAGEINFERENCEDataset has no masks; use mode='val' or 'test', "
            "or a mask-backed dataset such as URURHRLRDataset for training."
        )
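# Illustrative tiled-inference loop (added sketch): the image directory, class
# count, and the dual-input model signature are assumptions, not fixed here.
def _example_swissimage_inference(model):
    from torch.utils.data import DataLoader
    ds = SWISSIMAGEINFERENCEDataset(
        image_dir="data/swissimage/test",  # hypothetical path
        num_classes=9,                     # assumed class count
        mode="test",
        hr_size=1024,
        lr_side=2048,
        stride=512,                        # overlapping tiles
        transform=build_transforms(),
    )
    loader = DataLoader(ds, batch_size=4, num_workers=2)
    for image_hr, image_lr, meta in loader:
        with torch.no_grad():
            logits = model(image_hr, image_lr)  # assumed (hr, lr) signature
        # meta["tile"] carries (y0, x0, h, w) for stitching predictions back
        yield logits, meta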
class URURHRLRDataset(Dataset):
"""
URUR dataset with HR/LR dual-branch processing and tiling support.
In test mode, each worker caches the current image+mask in RAM for all tiles
of the same image to avoid re-reading for each tile.
"""
def __init__(
self,
image_dir: Union[str, Path],
mask_dir: Union[str, Path],
num_classes: int,
mode: str = "train",
ignore_index: int = 255,
hr_size: int = 1024,
lr_side: int = 2048,
        transform: Optional[transforms.Compose] = None,
augment=None,
limit: Optional[int] = None
) -> None:
assert mode in {"train", "val", "test"}
self.image_dir = Path(image_dir)
self.mask_dir = Path(mask_dir)
self.mode = mode
self.num_classes = int(num_classes)
self.ignore_index = int(ignore_index)
self.HR = int(hr_size)
self.LR_WIN = int(lr_side)
self.transform = transform
self.augment = augment
imgs = sorted([p for p in self.image_dir.iterdir() if p.is_file()])
msks = sorted([p for p in self.mask_dir.iterdir() if p.is_file()])
if limit is not None:
imgs, msks = imgs[:limit], msks[:limit]
assert len(imgs) == len(msks) and len(imgs) > 0, "Images/Masks missing or misaligned"
self.images: List[Path] = imgs
self.masks: List[Path] = msks
# Tile index + quick sizes (without loading full image)
self._test_index: List[Tuple[int, int, int]] = [] # (img_id, y0, x0)
self._sizes: List[Tuple[int, int]] = [] # (H, W) per image
if self.mode == "test" or self.mode == "val":
for img_id, ip in enumerate(self.images):
with Image.open(ip) as im:
W, H = im.size
self._sizes.append((H, W))
n_ty = math.ceil(H / self.HR)
n_tx = math.ceil(W / self.HR)
for iy in range(n_ty):
for ix in range(n_tx):
self._test_index.append((img_id, iy * self.HR, ix * self.HR))
# Cache per worker (used only in test mode)
self._cache_img_id: Optional[int] = None
self._cache_img: Optional[np.ndarray] = None # float32 HWC [0,1]
self._cache_msk: Optional[np.ndarray] = None # int64 HW
    def __len__(self) -> int:
        # val mode also iterates over tiles, not whole images
        return len(self._test_index) if self.mode in ("test", "val") else len(self.images)
def _extract_pair_np(
self, img: np.ndarray, msk: np.ndarray, y0: int, x0: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Extract HR/LR pair in numpy (images HWC float32[0,1], masks HW int64)."""
# HR
img_hr = crop_with_pad(img, y0, x0, self.HR, self.HR, pad_val=0)
msk_hr = crop_with_pad(msk, y0, x0, self.HR, self.HR, pad_val=self.ignore_index)
# Resize HR to half size
img_hr = resize_np_img(img_hr, (self.HR//2, self.HR//2))
# LR centered on HR
cy, cx = y0 + self.HR // 2, x0 + self.HR // 2
half = self.LR_WIN // 2
img_lr = crop_with_pad(img, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=0)
msk_lr = crop_with_pad(msk, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=self.ignore_index)
        # Downsample: LR image to HR//2 (matching the HR branch); the LR mask is
        # resampled to HR resolution, like the HR mask
        img_lr_512 = resize_np_img(img_lr, (self.HR // 2, self.HR // 2))
        msk_lr_512 = resize_np_mask(msk_lr, (self.HR, self.HR))
# Clamp out of range -> ignore_index
msk_hr = msk_hr.astype(np.int64, copy=False)
msk_lr_512 = msk_lr_512.astype(np.int64, copy=False)
if hasattr(self, "_postprocess_masks"): #for ISIC encoded classes as 0 and 255
msk_hr, msk_lr_512 = self._postprocess_masks(msk_hr, msk_lr_512)
msk_hr[msk_hr >= self.num_classes] = self.ignore_index
msk_lr_512[msk_lr_512 >= self.num_classes] = self.ignore_index
return img_hr, msk_hr, img_lr_512, msk_lr_512
def __getitem__(self, idx: int):
if self.mode == "test" or self.mode == "val":
# Tile index
img_id, y0, x0 = self._test_index[idx]
ip, mp = self.images[img_id], self.masks[img_id]
H, W = self._sizes[img_id]
# Cache per worker: read/convert only once per image
if self._cache_img_id != img_id:
self._cache_img = load_image(ip) # float32 HWC [0,1]
self._cache_msk = load_mask(mp) # int64 HW
self._cache_img_id = img_id
img = self._cache_img
msk = self._cache_msk
img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)
if self.transform:
image_hr = self.transform(to_pil_uint8(img_hr_np))
image_lr = self.transform(to_pil_uint8(img_lr_np))
else:
image_hr = to_tensor_img(img_hr_np)
image_lr = to_tensor_img(img_lr_np)
mask_hr = torch.as_tensor(msk_hr_np, dtype=torch.long)
mask_lr = torch.as_tensor(msk_lr_np, dtype=torch.long)
meta: Dict[str, object] = {
"img_path": str(ip),
"mask_path": str(mp),
"tile": (int(y0), int(x0), self.HR, self.HR),
"img_hw": (int(H), int(W)),
"tile_index": int(idx),
}
return image_hr, mask_hr, image_lr, mask_lr, meta
# train mode
ip, mp = self.images[idx], self.masks[idx]
img = load_image(ip)
msk = load_mask(mp)
H, W = img.shape[:2]
y0 = 0 if H <= self.HR else np.random.randint(0, H - self.HR + 1)
x0 = 0 if W <= self.HR else np.random.randint(0, W - self.HR + 1)
img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)
img_hr_pil = to_pil_uint8(img_hr_np)
img_lr_pil = to_pil_uint8(img_lr_np)
m_hr_pil = Image.fromarray(msk_hr_np.astype(np.uint8), mode="L")
m_lr_pil = Image.fromarray(msk_lr_np.astype(np.uint8), mode="L")
if self.augment is not None:
img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil = self.augment(img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil)
        if self.transform:
            image_hr = self.transform(img_hr_pil)
            image_lr = self.transform(img_lr_pil)
        else:
            # Use the (possibly augmented) PIL images, not the pre-augment arrays
            image_hr = to_tensor_img(np.asarray(img_hr_pil, dtype=np.float32) / 255.0)
            image_lr = to_tensor_img(np.asarray(img_lr_pil, dtype=np.float32) / 255.0)
mask_hr = torch.as_tensor(np.array(m_hr_pil, dtype=np.uint8), dtype=torch.long)
mask_lr = torch.as_tensor(np.array(m_lr_pil, dtype=np.uint8), dtype=torch.long)
meta: Dict[str, object] = {"img_path": str(ip), "mask_path": str(mp)}
return image_hr, mask_hr, image_lr, mask_lr, meta
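# Illustrative URUR training setup (added sketch): paths, class count, and the
# joint augmentation callable are assumptions, not fixed by this module.
def _example_urur_training():
    from torch.utils.data import DataLoader
    ds = URURHRLRDataset(
        image_dir="data/URUR/train/images",  # hypothetical path
        mask_dir="data/URUR/train/masks",    # hypothetical path
        num_classes=8,                       # assumed class count
        mode="train",
        hr_size=1024,
        lr_side=2048,
        transform=build_transforms(),
        augment=None,  # optionally a joint (img_hr, m_hr, img_lr, m_lr) callable
    )
    loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=4)
    image_hr, mask_hr, image_lr, mask_lr, meta = next(iter(loader))
    # images: (B, 3, 512, 512) in [-1, 1]; masks: (B, 1024, 1024) LongTensor
    print(image_hr.shape, mask_hr.shape, image_lr.shape, mask_lr.shape)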
def build_transforms():
"""Build standard transforms with normalization (mean=std=0.5)."""
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def _read_vrt_rgb_window(vrt_path: str, y0: int, x0: int, h: int, w: int) -> np.ndarray:
    """Read an h-by-w window from a VRT; returns HWC float32 in [0,1]."""
with rasterio.open(vrt_path) as src:
win = Window(col_off=x0, row_off=y0, width=w, height=h)
data = src.read(window=win, boundless=True, fill_value=0) # CHW
    # keep the first 3 channels (RGB)
data = data[:3, :, :]
img = np.transpose(data, (1, 2, 0)) # HWC
if img.dtype == np.uint8:
img = img.astype(np.float32) / 255.0
else:
img = img.astype(np.float32)
m = float(img.max()) if img.size else 1.0
if m > 1.0:
img /= m
return img
class VRTDualResInferenceDataset(Dataset):
def __init__(self, vrt_path, tile_size=1024, stride=512, lr_side=2048, transform=None):
self.vrt_path = vrt_path
self.tile = int(tile_size)
self.stride = int(stride)
self.lr_side = int(lr_side)
self.transform = transform
        # read the raster dimensions only, then close the handle
with rasterio.open(self.vrt_path) as src:
self.H = src.height
self.W = src.width
self.profile = src.profile
self.index = [(y0, x0) for y0 in range(0, self.H, self.stride)
for x0 in range(0, self.W, self.stride)]
if self.tile < self.stride:
raise ValueError("tile_size must be >= stride")
self.halo = (self.tile - self.stride) // 2
        # lazy raster handle (opened inside each worker process)
self._src = None
def _get_src(self):
if self._src is None:
self._src = rasterio.open(self.vrt_path)
return self._src
def _read_rgb_window(self, y0, x0, h, w):
src = self._get_src()
win = Window(col_off=x0, row_off=y0, width=w, height=h)
data = src.read(window=win, boundless=True, fill_value=0) # CHW
data = data[:3, :, :] # RGB
img = np.transpose(data, (1, 2, 0)) # HWC
if img.dtype == np.uint8:
img = img.astype(np.float32) / 255.0
else:
img = img.astype(np.float32)
m = float(img.max()) if img.size else 1.0
if m > 1.0:
img /= m
return img
def __len__(self):
return len(self.index)
def __getitem__(self, idx):
y0, x0 = self.index[idx]
img_hr = self._read_rgb_window(y0, x0, self.tile, self.tile)
cy, cx = y0 + self.tile // 2, x0 + self.tile // 2
half = self.lr_side // 2
img_lr = self._read_rgb_window(cy - half, cx - half, self.lr_side, self.lr_side)
        # Match the original setup: resize 1024 -> 512 for both branches
img_hr = resize_np_img(img_hr, (self.tile // 2, self.tile // 2))
img_lr = resize_np_img(img_lr, (self.tile // 2, self.tile // 2))
if self.transform:
image_hr = self.transform(to_pil_uint8(img_hr))
image_lr = self.transform(to_pil_uint8(img_lr))
else:
image_hr = to_tensor_img(img_hr)
image_lr = to_tensor_img(img_lr)
wy0 = y0 + self.halo
wx0 = x0 + self.halo
wh = max(0, min(self.stride, self.H - wy0))
ww = max(0, min(self.stride, self.W - wx0))
meta = {
"tile_y0": y0,
"tile_x0": x0,
"tile_size": self.tile,
"img_hw": (self.H, self.W),
"write_y0": wy0,
"write_x0": wx0,
"write_h": wh,
"write_w": ww,
"idx": idx,
}
return image_hr, image_lr, meta
def __del__(self):
try:
if self._src is not None:
self._src.close()
except Exception:
pass
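# End-to-end stitching sketch for the VRT dataset (added example): the model
# and its (hr, lr) -> (B, C, tile//2, tile//2) output contract are assumptions.
# Per-tile predictions are upsampled back to tile resolution, then only the
# central (non-halo) region of each tile is written into the full-size canvas
# using the write offsets carried in meta.
def _example_vrt_inference(model, vrt_path="ortho.vrt"):  # hypothetical VRT
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    ds = VRTDualResInferenceDataset(vrt_path, tile_size=1024, stride=512,
                                    lr_side=2048, transform=build_transforms())
    loader = DataLoader(ds, batch_size=2, num_workers=2)
    canvas = np.zeros((ds.H, ds.W), dtype=np.uint8)
    for image_hr, image_lr, meta in loader:
        with torch.no_grad():
            logits = model(image_hr, image_lr)
        # bring predictions back to tile resolution before stitching
        logits = F.interpolate(logits, size=(ds.tile, ds.tile),
                               mode="bilinear", align_corners=False)
        pred = logits.argmax(1).cpu().numpy().astype(np.uint8)
        for b in range(pred.shape[0]):
            y0, x0 = int(meta["write_y0"][b]), int(meta["write_x0"][b])
            h, w = int(meta["write_h"][b]), int(meta["write_w"][b])
            # keep only the central stride x stride window of each tile
            canvas[y0:y0 + h, x0:x0 + w] = pred[b, ds.halo:ds.halo + h,
                                                ds.halo:ds.halo + w]
    return canvas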