import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class SemanticSegmentationDataset(Dataset):
    """Dataset for semantic segmentation over zone-structured imagery.

    Expected directory layout:
        root_dir/zone/images/*.JPG
        root_dir/zone/masks/*.png

    Raw mask pixel values are remapped to training class ids through a
    256-entry lookup table built from ``class_map``; any value not listed
    in ``class_map`` maps to 255 (conventionally the ignore index —
    confirm against the loss configuration).
    """

    def __init__(
        self,
        root_dir: Path,
        split_json: Path,
        split: str,
        mode: str,
        resize_size: Tuple[int, int],
        crop_size: Tuple[int, int],
        class_map: Dict[Tuple[int, ...], int],
        transform: Optional[transforms.Compose] = None
    ):
        """
        Args:
            root_dir: Root directory containing one subdirectory per zone.
            split_json: JSON file mapping split names to lists of zone names.
            split: Key into ``split_json`` selecting which zones to load.
            mode: One of 'train', 'val', 'test', 'test3d', 'val3d'.
                Test-like modes load images only (no masks).
            resize_size: Size passed to ``PIL.Image.resize`` — PIL expects
                (width, height).
            crop_size: (height, width) of the random crop applied in
                'train' mode.
            class_map: Maps tuples of raw mask pixel values to class ids.
            transform: Optional transform applied to the PIL image; when
                None the image becomes a float CHW tensor scaled to [0, 1].
        """
        self.root_dir = Path(root_dir)
        self.mode = mode
        self.resize_size = resize_size
        self.crop_size = crop_size
        self.transform = transform
        self.pairs = self._gather_pairs(split_json, split)
        # Build the raw-value -> class-id LUT; unmapped values stay 255.
        self.lut = np.full(256, fill_value=255, dtype=np.uint8)
        for raw_values, class_id in class_map.items():
            self.lut[list(raw_values)] = class_id

    def _gather_pairs(self, split_json: Path, split: str) -> List[Tuple[Path, Optional[Path]]]:
        """Collect (image_path, mask_path) pairs for the requested split.

        In test-like modes the mask entry is None. Zones missing an
        images/ directory — or, for train/val, a masks/ directory — are
        skipped silently, as are images whose mask file does not exist.
        """
        with open(split_json, 'r') as f:
            split_data = json.load(f)
        dirs = split_data.get(split, [])
        pairs: List[Tuple[Path, Optional[Path]]] = []
        for zone in sorted(dirs):
            images_dir = self.root_dir / zone / "images"
            masks_dir = self.root_dir / zone / "masks"
            if not images_dir.is_dir():
                continue
            if self.mode in ("train", "val") and not masks_dir.is_dir():
                continue
            # NOTE(review): only upper-case '.JPG' is matched — confirm no
            # zone contains lower-case '.jpg' images.
            for img_path in images_dir.glob("*.JPG"):
                if self.mode in ["test", "test3d", "val3d"]:
                    pairs.append((img_path, None))
                else:
                    mask_path = masks_dir / img_path.name.replace(".JPG", ".png")
                    if mask_path.exists():
                        pairs.append((img_path, mask_path))
        return pairs

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Union[torch.Tensor, str]]:
        """Load one sample.

        Returns:
            In train/val modes: (image_tensor, mask_tensor), the mask
            holding class ids as int64 (255 = unmapped/ignore).
            In test-like modes: (image_tensor, "zone/filename.png") — the
            relative path where the prediction should be written.

        Raises:
            RuntimeError: in 'train' mode when the resized image is
                smaller than ``crop_size``.
        """
        img_path, mask_path = self.pairs[idx]
        image = Image.open(img_path).convert("RGB").resize(self.resize_size, resample=Image.BILINEAR)
        image_np = np.array(image)
        if mask_path is not None:
            # Nearest-neighbour resize keeps mask values categorical.
            mask = Image.open(mask_path).resize(self.resize_size, resample=Image.NEAREST)
            mask_np = self.lut[np.array(mask)]
        else:
            mask_np = None
        if self.mode == 'train':
            # Random crop applied identically to image and mask.
            h, w = image_np.shape[:2]
            ch, cw = self.crop_size
            if h < ch or w < cw:
                raise RuntimeError(f"Image {img_path} size ({h},{w}) < crop {self.crop_size}")
            top = random.randint(0, h - ch)
            left = random.randint(0, w - cw)
            image_np = image_np[top:top+ch, left:left+cw]
            mask_np = mask_np[top:top+ch, left:left+cw]
        image_tensor = self.transform(Image.fromarray(image_np)) if self.transform else torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0
        if mask_np is None:
            # Bug fix: `filename` was computed but unused and the path was
            # hard-coded to "(unknown)"; return the real prediction path.
            zone = img_path.parent.parent.name
            filename = img_path.name.replace('.JPG', '.png')
            path_pred = f"{zone}/{filename}"
            return image_tensor, path_pred
        else:
            mask_tensor = torch.from_numpy(mask_np).long()
            return image_tensor, mask_tensor