from __future__ import annotations import json from pathlib import Path from typing import Any import torch from PIL import Image from torchvision import transforms as T from torchvision.transforms import functional as TF IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"} def quality_key(value: Any) -> str: if value is None or str(value).strip() == "": return "" text = str(value).strip().lower() return text if text.startswith("q") else f"q{text}" def find_image_by_stem(directory: Path, stem: str) -> Path: for path in sorted(directory.iterdir()): if path.is_file() and path.suffix.lower() in IMAGE_EXTS and path.stem == stem: return path raise FileNotFoundError(f"No image with stem '{stem}' in {directory}") def list_images(directory: Path) -> list[Path]: return sorted(path for path in directory.iterdir() if path.is_file() and path.suffix.lower() in IMAGE_EXTS) def annotation_path(ann_dir: Path, image_path: Path) -> Path: candidates = [ ann_dir / f"{image_path.name}.json", ann_dir / f"{image_path.stem}.json", ann_dir / f"{image_path.stem}{image_path.suffix}.json", ] for candidate in candidates: if candidate.exists(): return candidate matches = sorted(ann_dir.glob(f"{image_path.stem}*.json")) if matches: return matches[0] raise FileNotFoundError(f"No annotation JSON found for '{image_path.name}' in {ann_dir}") def bbox_from_mask(mask: Image.Image) -> tuple[int, int, int, int] | None: return mask.convert("L").point(lambda x: 255 if x > 0 else 0).getbbox() def bbox_from_annotation(path: str | Path) -> tuple[int, int, int, int] | None: with Path(path).open() as f: payload = json.load(f) points: list[list[float]] = [] for obj in payload.get("objects", []): if obj.get("classTitle") != "Monogram": continue points.extend(obj.get("points", {}).get("exterior", [])) if not points: return None xs = [point[0] for point in points] ys = [point[1] for point in points] return int(min(xs)), int(min(ys)), int(max(xs)) + 1, int(max(ys)) + 1 def pad_bbox( bbox: tuple[int, int, int, int], width: int, height: int, padding_frac: float, ) -> tuple[int, int, int, int]: left, top, right, bottom = bbox pad = int(max(right - left, bottom - top) * padding_frac) return max(0, left - pad), max(0, top - pad), min(width, right + pad), min(height, bottom + pad) def _jitter_bbox( bbox: tuple[int, int, int, int], image_width: int, image_height: int, translate_frac: float, scale_frac: float, ) -> tuple[int, int, int, int]: if translate_frac <= 0 and scale_frac <= 0: return bbox left, top, right, bottom = bbox box_w = max(right - left, 1) box_h = max(bottom - top, 1) center_x = (left + right) / 2.0 center_y = (top + bottom) / 2.0 if translate_frac > 0: center_x += float(torch.empty(1).uniform_(-translate_frac, translate_frac).item()) * box_w center_y += float(torch.empty(1).uniform_(-translate_frac, translate_frac).item()) * box_h if scale_frac > 0: scale = float(torch.empty(1).uniform_(1.0 - scale_frac, 1.0 + scale_frac).item()) box_w *= scale box_h *= scale left = int(round(center_x - box_w / 2.0)) right = int(round(center_x + box_w / 2.0)) top = int(round(center_y - box_h / 2.0)) bottom = int(round(center_y + box_h / 2.0)) if left < 0: right -= left left = 0 if top < 0: bottom -= top top = 0 if right > image_width: left -= right - image_width right = image_width if bottom > image_height: top -= bottom - image_height bottom = image_height return max(0, left), max(0, top), min(image_width, max(left + 1, right)), min(image_height, max(top + 1, bottom)) def crop_pair( image: Image.Image, mask: Image.Image, ann_path: str | Path | None, image_size: int, crop_padding: float, crop_jitter: float = 0.0, crop_scale_jitter: float = 0.0, ) -> tuple[Image.Image, Image.Image]: if mask.size != image.size: mask = mask.resize(image.size, Image.Resampling.NEAREST) bbox = bbox_from_mask(mask) if bbox is None and ann_path is not None: bbox = bbox_from_annotation(ann_path) if bbox is None: bbox = (0, 0, image.width, image.height) bbox = pad_bbox(bbox, image.width, image.height, crop_padding) bbox = _jitter_bbox(bbox, image.width, image.height, crop_jitter, crop_scale_jitter) image_crop = image.crop(bbox).resize((image_size, image_size), Image.Resampling.BICUBIC) mask_crop = mask.crop(bbox).resize((image_size, image_size), Image.Resampling.NEAREST) return image_crop, mask_crop def image_transform() -> T.Compose: return T.Compose( [ T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ] ) def mask_transform() -> T.Compose: return T.Compose( [ T.Grayscale(1), T.ToTensor(), T.Lambda(lambda x: (x > 0.5).float()), ] ) def apply_pair_augment( image: Image.Image, mask: Image.Image, rotation: float = 15.0, hflip_p: float = 0.5, vflip_p: float = 0.2, brightness: float = 0.15, contrast: float = 0.15, ) -> tuple[Image.Image, Image.Image]: if hflip_p > 0 and torch.rand(()) < hflip_p: image = TF.hflip(image) mask = TF.hflip(mask) if vflip_p > 0 and torch.rand(()) < vflip_p: image = TF.vflip(image) mask = TF.vflip(mask) if rotation > 0: angle = float(torch.empty(1).uniform_(-rotation, rotation).item()) image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BICUBIC, fill=0) mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST, fill=0) if brightness > 0: factor = float(torch.empty(1).uniform_(max(0.0, 1.0 - brightness), 1.0 + brightness).item()) image = TF.adjust_brightness(image, factor) if contrast > 0: factor = float(torch.empty(1).uniform_(max(0.0, 1.0 - contrast), 1.0 + contrast).item()) image = TF.adjust_contrast(image, factor) return image, mask