import json
import random
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
|
|
|
|
class SemanticSegmentationDataset(Dataset):
    """Dataset of (image, mask) samples for semantic segmentation.

    Expected on-disk layout, one sub-directory per zone::

        root_dir/<zone>/images/*.jpg   (extension matched case-insensitively)
        root_dir/<zone>/masks/*.png    (same file stem as the image)

    The zones belonging to a split are read from a JSON file that maps split
    names to lists of zone directory names.

    Behaviour depends on ``mode``:

    * ``"train"``: image and mask are resized, then the same random crop of
      ``crop_size`` is taken from both.
    * ``"test"``, ``"test3d"``, ``"val3d"``: no mask is loaded; the second
      element of each sample is the relative path ``"<zone>/<stem>.png"``.
    * any other value (e.g. ``"val"``): image and mask are resized only.
    """

    # Modes for which samples carry no ground-truth mask.
    _UNLABELED_MODES = ("test", "test3d", "val3d")
    # Class id assigned to raw mask values that do not appear in ``class_map``.
    _IGNORE_INDEX = 255
    # Image extension, compared against the lower-cased file name.
    _IMAGE_SUFFIX = ".jpg"

    def __init__(
        self,
        root_dir: Path,
        split_json: Path,
        split: str,
        mode: str,
        resize_size: Tuple[int, int],
        crop_size: Tuple[int, int],
        class_map: Dict[Tuple[int, ...], int],
        transform: Optional[Callable] = None,
    ):
        """Index the samples of ``split`` and build the mask lookup table.

        Args:
            root_dir: Directory containing one sub-directory per zone.
            split_json: JSON file mapping split names to lists of zone names.
            split: Key to look up in ``split_json``; an unknown key yields an
                empty dataset.
            mode: Selects the loading behaviour (see the class docstring).
            resize_size: Size handed unchanged to ``PIL.Image.resize``, which
                interprets it as ``(width, height)``.
            crop_size: ``(height, width)`` of the random crop taken in
                ``"train"`` mode. Note the order differs from ``resize_size``.
            class_map: Maps tuples of raw mask values (each in ``0..255``) to
                the class id stored in the returned mask. Raw values not
                listed map to 255.
            transform: Optional callable applied to the PIL image. When
                omitted the image is converted to a float ``CHW`` tensor
                scaled to ``[0, 1]``.
        """
        super().__init__()
        self.root_dir = Path(root_dir)
        self.mode = mode
        self.resize_size = resize_size
        self.crop_size = crop_size
        self.transform = transform
        self.pairs = self._gather_pairs(split_json, split)

        # 256-entry table so a whole uint8 mask is remapped with one fancy
        # indexing operation in __getitem__.
        self.lut = np.full(256, fill_value=self._IGNORE_INDEX, dtype=np.uint8)
        for raw_values, class_id in class_map.items():
            self.lut[list(raw_values)] = class_id

    def _gather_pairs(
        self, split_json: Path, split: str
    ) -> List[Tuple[Path, Optional[Path]]]:
        """Return the ``(image_path, mask_path)`` pairs of ``split``.

        Zones and the files inside each zone are visited in sorted order so
        that an index always refers to the same sample, regardless of the
        order in which the filesystem enumerates directory entries.

        In unlabeled modes ``mask_path`` is ``None``. Otherwise images without
        a matching mask file are skipped. Zones whose ``images`` directory is
        missing are skipped, as are zones without a ``masks`` directory when
        the mode is ``"train"`` or ``"val"``.
        """
        with open(split_json, "r", encoding="utf-8") as f:
            split_data = json.load(f)

        unlabeled = self.mode in self._UNLABELED_MODES
        pairs: List[Tuple[Path, Optional[Path]]] = []
        for zone in sorted(split_data.get(split, [])):
            images_dir = self.root_dir / zone / "images"
            masks_dir = self.root_dir / zone / "masks"
            if not images_dir.is_dir():
                continue
            if self.mode in ("train", "val") and not masks_dir.is_dir():
                continue

            for img_path in sorted(images_dir.iterdir()):
                # Case-insensitive match: behaves the same on every platform
                # and accepts both ``.JPG`` and ``.jpg``.
                if not img_path.name.lower().endswith(self._IMAGE_SUFFIX):
                    continue
                if unlabeled:
                    pairs.append((img_path, None))
                    continue
                # with_suffix only swaps the final extension, unlike a string
                # replace which would also hit ".JPG" inside the stem.
                mask_path = masks_dir / img_path.with_suffix(".png").name
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
        return pairs

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Union[torch.Tensor, str]]:
        """Load sample ``idx``.

        Returns:
            ``(image, mask)`` when the sample has a mask, where ``mask`` is an
            int64 tensor of class ids; otherwise ``(image, rel_path)`` where
            ``rel_path`` is the string ``"<zone>/<stem>.png"``.

        Raises:
            RuntimeError: In ``"train"`` mode, if the resized image is smaller
                than ``crop_size``.
        """
        img_path, mask_path = self.pairs[idx]

        # The context managers release the file handles right away instead of
        # leaving that to garbage collection; convert()/resize() return new
        # in-memory images that stay valid after the file is closed.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB").resize(
                self.resize_size, resample=Image.BILINEAR
            )
        image_np = np.array(image)

        mask_np = None
        if mask_path is not None:
            # NEAREST keeps the mask values discrete (no blended labels).
            with Image.open(mask_path) as mask_file:
                mask = mask_file.resize(self.resize_size, resample=Image.NEAREST)
            # NOTE(review): assumes the mask decodes to values in 0..255;
            # larger values would index past the end of the table - confirm.
            mask_np = self.lut[np.array(mask)]

        if self.mode == 'train':
            h, w = image_np.shape[:2]
            ch, cw = self.crop_size
            if h < ch or w < cw:
                raise RuntimeError(f"Image {img_path} size ({h},{w}) < crop {self.crop_size}")
            # Draw top before left: keeps the random stream identical for a
            # given seed.
            top = random.randint(0, h - ch)
            left = random.randint(0, w - cw)
            image_np = image_np[top:top + ch, left:left + cw]
            # "train" is a labeled mode, so _gather_pairs only stored pairs
            # that have a mask and mask_np is set here.
            mask_np = mask_np[top:top + ch, left:left + cw]

        if self.transform:
            image_tensor = self.transform(Image.fromarray(image_np))
        else:
            # HWC uint8 -> CHW float in [0, 1].
            image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0

        if mask_np is None:
            zone = img_path.parent.parent.name
            filename = img_path.with_suffix('.png').name
            path_pred = f"{zone}/{filename}"
            return image_tensor, path_pred

        mask_tensor = torch.from_numpy(mask_np).long()
        return image_tensor, mask_tensor
|
|
|
|