Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms as T | |
| from torchvision.transforms import functional as TF | |
# File suffixes (lower-case, including the leading dot) that the directory
# helpers in this module accept as images; suffixes are lower-cased before the
# membership test, so matching is case-insensitive.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
def quality_key(value: Any) -> str:
    """Normalise a quality label to a lower-case string starting with ``q``.

    ``None`` and blank/whitespace-only values map to the empty string. Any
    other value is stringified, stripped and lower-cased; a ``q`` prefix is
    added unless the text already starts with one.
    """
    if value is None:
        return ""

    stripped = str(value).strip()
    if not stripped:
        return ""

    label = stripped.lower()
    if label.startswith("q"):
        return label
    return "q" + label
def find_image_by_stem(directory: Path, stem: str) -> Path:
    """Return the first image file in *directory* whose stem equals *stem*.

    Entries are visited in sorted order, so when several extensions share the
    same stem the lexicographically smallest path wins.

    Raises:
        FileNotFoundError: if no regular file with an accepted image suffix
            and the requested stem exists in *directory*.
    """
    hits = (
        entry
        for entry in sorted(directory.iterdir())
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTS and entry.stem == stem
    )
    found = next(hits, None)
    if found is None:
        raise FileNotFoundError(f"No image with stem '{stem}' in {directory}")
    return found
def list_images(directory: Path) -> list[Path]:
    """Return the image files directly inside *directory*, sorted by path.

    Only regular files whose lower-cased suffix is in ``IMAGE_EXTS`` are kept;
    sub-directories are not descended into.
    """
    images: list[Path] = []
    for entry in directory.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix.lower() in IMAGE_EXTS:
            images.append(entry)
    images.sort()
    return images
def annotation_path(ann_dir: Path, image_path: Path) -> Path:
    """Locate the annotation JSON that belongs to *image_path* inside *ann_dir*.

    Lookup order:

    1. ``<image name>.json`` (e.g. ``photo.png.json``)
    2. ``<image stem>.json`` (e.g. ``photo.json``)
    3. the first file, in sorted order, matching ``<image stem>*.json``

    Args:
        ann_dir: directory holding the annotation files.
        image_path: path of the image; only its name/stem are used.

    Returns:
        Path of the first annotation file found.

    Raises:
        FileNotFoundError: if none of the lookups yields a file.
    """
    # Local import keeps this block self-contained (no new module-level import).
    from glob import escape as glob_escape

    # `image_path.name` already equals stem + suffix, so these two names cover
    # every exact-name candidate without probing the same file twice.
    for filename in (f"{image_path.name}.json", f"{image_path.stem}.json"):
        candidate = ann_dir / filename
        if candidate.exists():
            return candidate

    # Escape the stem so characters such as "[", "]", "*" or "?" in a file name
    # are matched literally instead of being interpreted as glob wildcards.
    matches = sorted(ann_dir.glob(f"{glob_escape(image_path.stem)}*.json"))
    if matches:
        return matches[0]
    raise FileNotFoundError(f"No annotation JSON found for '{image_path.name}' in {ann_dir}")
def bbox_from_mask(mask: Image.Image) -> tuple[int, int, int, int] | None:
    """Return the bounding box of all non-zero pixels in *mask*.

    The mask is converted to 8-bit grayscale and binarised (any value above 0
    becomes 255) before the bounding box is taken. The result is whatever
    ``Image.getbbox`` reports for that binary image, which is ``None`` when
    nothing is set.
    """

    def _to_binary(level: int) -> int:
        # Any non-zero gray level counts as foreground.
        return 255 if level > 0 else 0

    grayscale = mask.convert("L")
    binary = grayscale.point(_to_binary)
    return binary.getbbox()
def bbox_from_annotation(path: str | Path) -> tuple[int, int, int, int] | None:
    """Compute the bounding box of all ``Monogram`` polygons in an annotation file.

    The JSON file is expected to contain an ``"objects"`` list whose entries
    carry a ``"classTitle"`` and ``"points": {"exterior": [[x, y], ...]}``.
    Only objects whose class title is exactly ``"Monogram"`` contribute.

    Args:
        path: location of the annotation JSON file.

    Returns:
        ``(left, top, right, bottom)`` with *right*/*bottom* one past the
        largest coordinate (truncated to int), or ``None`` when no Monogram
        exterior point is present.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding, which differs between operating systems.
    with Path(path).open(encoding="utf-8") as f:
        payload = json.load(f)

    xs: list[float] = []
    ys: list[float] = []
    # `or` fallbacks tolerate keys that are present but null in the file.
    for obj in payload.get("objects") or []:
        if obj.get("classTitle") != "Monogram":
            continue
        exterior = (obj.get("points") or {}).get("exterior") or []
        for point in exterior:
            xs.append(point[0])
            ys.append(point[1])

    if not xs:
        return None
    # +1 makes the right/bottom edges exclusive, matching PIL crop boxes.
    return int(min(xs)), int(min(ys)), int(max(xs)) + 1, int(max(ys)) + 1
def pad_bbox(
    bbox: tuple[int, int, int, int],
    width: int,
    height: int,
    padding_frac: float,
) -> tuple[int, int, int, int]:
    """Grow *bbox* by a margin on every side, clipped to the image bounds.

    The margin is ``padding_frac`` times the longer side of the box, truncated
    to an integer, and is applied equally to all four edges.

    Args:
        bbox: ``(left, top, right, bottom)`` box to enlarge.
        width: image width used as the upper bound for the right edge.
        height: image height used as the upper bound for the bottom edge.
        padding_frac: margin as a fraction of the box's longer side.

    Returns:
        The padded ``(left, top, right, bottom)`` box.
    """
    x0, y0, x1, y1 = bbox
    longest_side = max(x1 - x0, y1 - y0)
    margin = int(longest_side * padding_frac)

    padded_left = max(0, x0 - margin)
    padded_top = max(0, y0 - margin)
    padded_right = min(width, x1 + margin)
    padded_bottom = min(height, y1 + margin)
    return padded_left, padded_top, padded_right, padded_bottom
def _jitter_bbox(
    bbox: tuple[int, int, int, int],
    image_width: int,
    image_height: int,
    translate_frac: float,
    scale_frac: float,
) -> tuple[int, int, int, int]:
    """Randomly shift and/or rescale *bbox*, keeping it inside the image.

    Args:
        bbox: ``(left, top, right, bottom)`` box to perturb.
        image_width: image width; upper bound for the right edge.
        image_height: image height; upper bound for the bottom edge.
        translate_frac: when > 0, the box centre moves by a uniform draw from
            ``[-translate_frac, translate_frac]`` times the box width (x) and
            box height (y), drawn independently.
        scale_frac: when > 0, both sides are multiplied by one uniform draw
            from ``[1 - scale_frac, 1 + scale_frac]``.

    Returns:
        The jittered box, or *bbox* itself when both fractions are <= 0.
    """
    if translate_frac <= 0 and scale_frac <= 0:
        return bbox

    def _uniform(low: float, high: float) -> float:
        # Single draw from torch's global RNG so seeding torch controls the jitter.
        return float(torch.empty(1).uniform_(low, high).item())

    def _slide(low: int, high: int, limit: int) -> tuple[int, int]:
        # Push an interval back inside [0, limit] while preserving its length
        # where possible; the caller applies the final clamp.
        if low < 0:
            high -= low
            low = 0
        if high > limit:
            low -= high - limit
            high = limit
        return low, high

    x0, y0, x1, y1 = bbox
    span_w = max(x1 - x0, 1)
    span_h = max(y1 - y0, 1)
    mid_x = (x0 + x1) / 2.0
    mid_y = (y0 + y1) / 2.0

    # Draw order (x shift, y shift, scale) is kept fixed for reproducibility.
    if translate_frac > 0:
        mid_x += _uniform(-translate_frac, translate_frac) * span_w
        mid_y += _uniform(-translate_frac, translate_frac) * span_h
    if scale_frac > 0:
        factor = _uniform(1.0 - scale_frac, 1.0 + scale_frac)
        span_w *= factor
        span_h *= factor

    x0 = int(round(mid_x - span_w / 2.0))
    x1 = int(round(mid_x + span_w / 2.0))
    y0 = int(round(mid_y - span_h / 2.0))
    y1 = int(round(mid_y + span_h / 2.0))

    x0, x1 = _slide(x0, x1, image_width)
    y0, y1 = _slide(y0, y1, image_height)

    # Final clamp; right/bottom are kept at least one past left/top before clipping.
    return (
        max(0, x0),
        max(0, y0),
        min(image_width, max(x0 + 1, x1)),
        min(image_height, max(y0 + 1, y1)),
    )
def crop_pair(
    image: Image.Image,
    mask: Image.Image,
    ann_path: str | Path | None,
    image_size: int,
    crop_padding: float,
    crop_jitter: float = 0.0,
    crop_scale_jitter: float = 0.0,
) -> tuple[Image.Image, Image.Image]:
    """Crop *image* and *mask* to the same region and resize both to a square.

    The region comes from the mask's non-zero pixels; if the mask is empty and
    *ann_path* is given, from the annotation file; otherwise the whole image is
    used. The region is then padded and optionally jittered.

    Args:
        image: source image.
        mask: mask for the image; resized (nearest) to the image size if it differs.
        ann_path: optional annotation JSON used when the mask yields no box.
        image_size: side length of the square outputs.
        crop_padding: padding fraction passed to ``pad_bbox``.
        crop_jitter: translation fraction passed to ``_jitter_bbox``.
        crop_scale_jitter: scale fraction passed to ``_jitter_bbox``.

    Returns:
        ``(image_crop, mask_crop)``; the image is resampled bicubically, the
        mask with nearest-neighbour so it keeps its original values.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.Resampling.NEAREST)

    region = bbox_from_mask(mask)
    if region is None:
        if ann_path is not None:
            region = bbox_from_annotation(ann_path)
        if region is None:
            # Nothing to localise on: fall back to the full frame.
            region = (0, 0, image.width, image.height)

    region = pad_bbox(region, image.width, image.height, crop_padding)
    region = _jitter_bbox(region, image.width, image.height, crop_jitter, crop_scale_jitter)

    target_size = (image_size, image_size)
    image_crop = image.crop(region).resize(target_size, Image.Resampling.BICUBIC)
    mask_crop = mask.crop(region).resize(target_size, Image.Resampling.NEAREST)
    return image_crop, mask_crop
def image_transform() -> T.Compose:
    """Build the image preprocessing pipeline.

    The pipeline converts the input to a tensor and then normalises each of
    three channels with mean 0.5 and std 0.5, i.e. ``(x - 0.5) / 0.5``.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    steps = [
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)
def _binarize_mask(tensor: torch.Tensor) -> torch.Tensor:
    """Return a float tensor that is 1.0 where *tensor* > 0.5 and 0.0 elsewhere."""
    return (tensor > 0.5).float()


def mask_transform() -> T.Compose:
    """Build the mask preprocessing pipeline.

    The pipeline reduces the mask to a single grayscale channel, converts it
    to a tensor and thresholds it at 0.5 into a 0.0/1.0 float tensor.

    The threshold step is a module-level function rather than an inline
    ``lambda`` so the returned transform can be pickled, which is required
    when it is sent to worker processes started with the ``spawn`` method.
    """
    return T.Compose(
        [
            T.Grayscale(1),
            T.ToTensor(),
            T.Lambda(_binarize_mask),
        ]
    )
def apply_pair_augment(
    image: Image.Image,
    mask: Image.Image,
    rotation: float = 15.0,
    hflip_p: float = 0.5,
    vflip_p: float = 0.2,
    brightness: float = 0.15,
    contrast: float = 0.15,
) -> tuple[Image.Image, Image.Image]:
    """Apply the same random geometric augmentation to an image/mask pair.

    Flips and rotation are applied to both inputs with identical parameters;
    brightness and contrast changes touch the image only. Any step whose
    parameter is <= 0 is skipped and draws no random number.

    Args:
        image: image to augment.
        mask: mask to augment alongside the image.
        rotation: maximum absolute rotation angle; the angle is drawn
            uniformly from ``[-rotation, rotation]``.
        hflip_p: probability of a horizontal flip.
        vflip_p: probability of a vertical flip.
        brightness: brightness factor is drawn uniformly from
            ``[max(0, 1 - brightness), 1 + brightness]``.
        contrast: contrast factor is drawn from the analogous range.

    Returns:
        The augmented ``(image, mask)`` pair.
    """

    def _uniform(low: float, high: float) -> float:
        return float(torch.empty(1).uniform_(low, high).item())

    def _coin(probability: float) -> bool:
        # Short-circuits so a disabled flip consumes no random number.
        return probability > 0 and bool(torch.rand(()) < probability)

    def _gain(amount: float) -> float:
        # Multiplicative factor around 1.0, never below 0.
        return _uniform(max(0.0, 1.0 - amount), 1.0 + amount)

    if _coin(hflip_p):
        image, mask = TF.hflip(image), TF.hflip(mask)

    if _coin(vflip_p):
        image, mask = TF.vflip(image), TF.vflip(mask)

    if rotation > 0:
        angle = _uniform(-rotation, rotation)
        # Nearest-neighbour on the mask avoids introducing in-between values.
        image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BICUBIC, fill=0)
        mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST, fill=0)

    if brightness > 0:
        image = TF.adjust_brightness(image, _gain(brightness))

    if contrast > 0:
        image = TF.adjust_contrast(image, _gain(contrast))

    return image, mask