|
|
|
|
|
"""WireSeg dataset indexing and loading. |
|
|
|
|
|
Pairs images in `images_dir` with masks in `masks_dir` by matching filename stems. |
|
|
Mask is loaded as single-channel 0/1. |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
|
|
|
class WireSegDataset:
    """Index and load image/mask pairs from two parallel directories.

    Images are the ``*.jpg`` / ``*.jpeg`` files in ``images_dir``; every image
    stem must be numeric. Each image is paired with a ``.png`` mask in
    ``masks_dir``. Samples are ordered by the integer value of their stem.

    At construction time every mask is read once to record its ``(H, W)``
    size, and sample indices are grouped by size (see :attr:`size_bins`).

    Validation failures raise ``AssertionError`` explicitly (rather than via
    ``assert`` statements) so the checks still run under ``python -O``.
    """

    def __init__(self, images_dir: str, masks_dir: str, split: str = "train"):
        """Index the dataset and scan mask sizes.

        Args:
            images_dir: Directory containing the ``.jpg`` / ``.jpeg`` images.
            masks_dir: Directory containing the ``.png`` masks.
            split: Split label; stored on the instance as ``self.split`` and
                not otherwise used by this class.

        Raises:
            AssertionError: If a directory is missing, no images are found,
                an image name is not numeric, a mask is missing, or a mask
                cannot be read during the size scan.
        """
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.split = split

        self._require(
            self.images_dir.exists(), f"Missing images_dir: {self.images_dir}"
        )
        self._require(
            self.masks_dir.exists(), f"Missing masks_dir: {self.masks_dir}"
        )

        # (image_path, mask_path) per sample, ordered by numeric id.
        self._items: List[Tuple[Path, Path]] = self._index_pairs()

        # Filled by _compute_size_bins(): per-sample mask size and the
        # size -> sample-indices grouping.
        self._sizes: List[Tuple[int, int]] = []
        self._size_bins: Dict[Tuple[int, int], List[int]] = {}
        self._compute_size_bins()

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` is true.

        Stands in for ``assert``: same exception type and message, but it is
        not removed when Python runs with optimizations enabled.
        """
        if not condition:
            raise AssertionError(message)

    def __len__(self) -> int:
        """Return the number of indexed image/mask pairs."""
        return len(self._items)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load the sample at position ``idx``.

        Returns:
            A dict with keys ``"image"`` (the image converted from OpenCV's
            BGR order to RGB), ``"mask"`` (``uint8`` array that is 1 where the
            grayscale mask is non-zero and 0 elsewhere), ``"image_path"`` and
            ``"mask_path"`` (both ``str``).

        Raises:
            AssertionError: If the image or the mask cannot be read.
            IndexError: If ``idx`` is out of range.
        """
        img_path, mask_path = self._items[idx]

        # cv2.imread signals failure by returning None instead of raising.
        img_bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        self._require(img_bgr is not None, f"Failed to read image: {img_path}")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        self._require(mask is not None, f"Failed to read mask: {mask_path}")
        # Any non-zero gray level counts as foreground.
        mask_bin = (mask > 0).astype(np.uint8)

        return {
            "image": img,
            "mask": mask_bin,
            "image_path": str(img_path),
            "mask_path": str(mask_path),
        }

    def _index_pairs(self) -> List[Tuple[Path, Path]]:
        """Build the list of ``(image_path, mask_path)`` pairs.

        Files whose stems denote the same integer (``1.jpg``, ``1.jpeg``,
        ``01.jpg``) are treated as a single sample. The image is chosen in
        this order: ``<id>.jpg``, then ``<id>.jpeg``, then the first file
        discovered for that id (``.jpg`` files are scanned before ``.jpeg``).
        The mask is ``<id>.png``, falling back to ``<image stem>.png`` so that
        zero-padded names such as ``007.jpg`` / ``007.png`` pair up.

        Returns:
            Pairs sorted by ascending numeric id, one per id.

        Raises:
            AssertionError: If no images exist, an image stem is not numeric,
                or the image or mask for an id is missing.
        """
        img_files = sorted(p for p in self.images_dir.glob("*.jpg") if p.exists())
        img_files += sorted(
            p for p in self.images_dir.glob("*.jpeg") if p.exists()
        )
        self._require(
            len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
        )

        # One entry per numeric id; setdefault keeps the first file seen, so
        # an id present under several names is indexed only once.
        discovered: Dict[int, Path] = {}
        for path in img_files:
            stem = path.stem
            self._require(
                stem.isdigit(), f"Non-numeric filename encountered: {path.name}"
            )
            discovered.setdefault(int(stem), path)

        pairs: List[Tuple[Path, Path]] = []
        for sample_id in sorted(discovered):
            ip_jpg = self.images_dir / f"{sample_id}.jpg"
            ip_jpeg = self.images_dir / f"{sample_id}.jpeg"
            # Prefer the canonical (unpadded) names; otherwise use the file
            # that was actually found, e.g. a zero-padded "007.jpg".
            ip = next(
                (cand for cand in (ip_jpg, ip_jpeg) if cand.exists()),
                discovered[sample_id],
            )
            self._require(
                ip.exists(),
                f"Missing image for {sample_id}: {ip_jpg} or {ip_jpeg}",
            )

            mp = self.masks_dir / f"{sample_id}.png"
            if not mp.exists():
                same_stem = self.masks_dir / f"{ip.stem}.png"
                if same_stem.exists():
                    mp = same_stem
            self._require(mp.exists(), f"Missing mask for {sample_id}: {mp}")

            pairs.append((ip, mp))

        return pairs

    def _compute_size_bins(self) -> None:
        """Read every mask once and group sample indices by mask size.

        Populates ``self._sizes`` (``(H, W)`` per sample, in index order) and
        ``self._size_bins`` (``(H, W)`` -> list of sample indices in ascending
        order). Only masks are read here; the images are not opened.

        Raises:
            AssertionError: If a mask cannot be read.
        """
        sizes: List[Tuple[int, int]] = []
        bins: Dict[Tuple[int, int], List[int]] = {}

        for idx, (_ip, mp) in enumerate(self._items):
            m = cv2.imread(str(mp), cv2.IMREAD_GRAYSCALE)
            self._require(
                m is not None, f"Failed to read mask for size scan: {mp}"
            )
            size = (int(m.shape[0]), int(m.shape[1]))
            sizes.append(size)
            bins.setdefault(size, []).append(idx)

        self._sizes = sizes
        self._size_bins = bins

    @property
    def size_bins(self) -> Dict[Tuple[int, int], List[int]]:
        """Mapping from mask size ``(H, W)`` to the sample indices of that size.

        The internal dict is returned directly, not a copy.
        """
        return self._size_bins
|
|
|