# segment-monograms / src/data/transforms.py
# Committed by Saranga7 in bcc432f (verified): "Deploy monogram segmentation demo"
# File size: 6.32 kB
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import torch
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import functional as TF
# Lower-case file suffixes (including the dot) that are treated as images;
# callers compare against ``path.suffix.lower()``.
IMAGE_EXTS = set(".jpg .jpeg .png .tif .tiff".split())
def quality_key(value: Any) -> str:
    """Normalise a quality label into a lower-case, ``q``-prefixed key.

    ``None`` and blank or whitespace-only values map to ``""``. Any other value
    is stringified, stripped and lower-cased. It then gets a leading ``"q"``
    unless it already starts with one, so ``3``, ``"3"`` and ``" Q3 "`` all
    become ``"q3"``.
    """
    if value is None:
        return ""
    key = str(value).strip().lower()
    if not key:
        return ""
    if key.startswith("q"):
        return key
    return "q" + key
def find_image_by_stem(directory: Path, stem: str) -> Path:
    """Return the first image file in ``directory`` whose stem equals ``stem``.

    Entries are scanned in sorted path order. If several images share a stem
    (e.g. ``a.jpg`` and ``a.png``), the pick is therefore deterministic.
    Only regular files whose lower-cased suffix is in ``IMAGE_EXTS`` are
    considered.

    Raises:
        FileNotFoundError: if no such image exists in ``directory``.
    """
    hit = next(
        (
            candidate
            for candidate in sorted(directory.iterdir())
            if candidate.is_file()
            and candidate.suffix.lower() in IMAGE_EXTS
            and candidate.stem == stem
        ),
        None,
    )
    if hit is None:
        raise FileNotFoundError(f"No image with stem '{stem}' in {directory}")
    return hit
def list_images(directory: Path) -> list[Path]:
    """Return every image file directly inside ``directory``, in sorted order.

    Only regular files whose lower-cased suffix is in ``IMAGE_EXTS`` are
    included. Subdirectories are not descended into.
    """
    images = [
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTS
    ]
    images.sort()
    return images
def annotation_path(ann_dir: Path, image_path: Path) -> Path:
    """Find the annotation JSON in ``ann_dir`` that belongs to ``image_path``.

    Lookup order:

    1. ``<image name>.json`` (e.g. ``img1.jpg.json``);
    2. ``<image stem>.json`` (e.g. ``img1.json``);
    3. otherwise, the first file in sorted order named ``<stem><sep>...json``,
       where ``<sep>`` is not a letter or digit (e.g. ``img1_v2.json`` or
       ``img1.png.json``).

    Step 3 requires a separator so that ``img1`` cannot silently pick up
    ``img10.json``. Names are matched with plain string operations rather than
    a glob pattern, so stems containing ``[``, ``]``, ``*`` or ``?`` are not
    treated as wildcards.

    Raises:
        FileNotFoundError: if no matching annotation file exists.
    """
    stem = image_path.stem
    # ``image_path.name`` is always ``stem + suffix``, so these two paths cover
    # every exact spelling.
    for candidate in (ann_dir / f"{image_path.name}.json", ann_dir / f"{stem}.json"):
        if candidate.is_file():
            return candidate

    def _loosely_matches(entry: Path) -> bool:
        # True for files named ``<stem><non-alphanumeric>...json``.
        name = entry.name
        if not name.startswith(stem):
            return False
        rest = name[len(stem):]
        # ``rest`` is checked to end in ".json" first, so ``rest[0]`` exists.
        return rest.endswith(".json") and not rest[0].isalnum() and entry.is_file()

    if ann_dir.is_dir():
        matches = sorted(entry for entry in ann_dir.iterdir() if _loosely_matches(entry))
        if matches:
            return matches[0]
    raise FileNotFoundError(f"No annotation JSON found for '{image_path.name}' in {ann_dir}")
def bbox_from_mask(mask: Image.Image) -> tuple[int, int, int, int] | None:
    """Return the foreground bounding box of ``mask``, or ``None`` if it is empty.

    The mask is converted to 8-bit greyscale (``"L"``), and every non-zero pixel
    is treated as foreground. The result uses PIL's ``(left, upper, right,
    lower)`` convention, the same one ``Image.crop`` accepts.
    """
    grey = mask.convert("L")
    foreground = grey.point(lambda level: 0 if level <= 0 else 255)
    return foreground.getbbox()
def bbox_from_annotation(
    path: str | Path,
    class_title: str = "Monogram",
) -> tuple[int, int, int, int] | None:
    """Return one bounding box enclosing every ``class_title`` polygon in a JSON file.

    The file is read as a top-level ``"objects"`` list. Each entry carries a
    ``"classTitle"`` and ``"points": {"exterior": [[x, y], ...]}``. This looks
    like the Supervisely export layout; confirm against the dataset. Points from
    all matching objects are merged. Missing or ``null`` ``objects``/``points``/
    ``exterior`` entries are treated as empty.

    Args:
        path: Annotation JSON file.
        class_title: Object class whose points are collected. Defaults to
            ``"Monogram"``.

    Returns:
        ``(left, top, right, bottom)`` as ints. ``left``/``top`` are the
        ``int()``-truncated minima, and ``right``/``bottom`` are the truncated
        maxima plus one (an exclusive edge). Returns ``None`` when no matching
        object has any exterior points.
    """
    # JSON is UTF-8 by specification. Read it explicitly instead of relying on
    # the platform's locale encoding; "utf-8-sig" also strips an optional BOM,
    # which json.load would otherwise reject.
    with Path(path).open(encoding="utf-8-sig") as f:
        payload = json.load(f)

    points: list[list[float]] = []
    for obj in payload.get("objects") or []:
        if obj.get("classTitle") != class_title:
            continue
        points.extend((obj.get("points") or {}).get("exterior") or [])
    if not points:
        return None

    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    return int(min(xs)), int(min(ys)), int(max(xs)) + 1, int(max(ys)) + 1
def pad_bbox(
    bbox: tuple[int, int, int, int],
    width: int,
    height: int,
    padding_frac: float,
) -> tuple[int, int, int, int]:
    """Grow ``bbox`` by a uniform margin on every side, clamped to the image.

    The margin is ``int(padding_frac * longest_side)``, so square-ish context
    is added even for elongated boxes. The result is clamped to
    ``[0, width] x [0, height]``.
    """
    x0, y0, x1, y1 = bbox
    margin = int(padding_frac * max(x1 - x0, y1 - y0))
    return (
        max(x0 - margin, 0),
        max(y0 - margin, 0),
        min(x1 + margin, width),
        min(y1 + margin, height),
    )
def _jitter_bbox(
bbox: tuple[int, int, int, int],
image_width: int,
image_height: int,
translate_frac: float,
scale_frac: float,
) -> tuple[int, int, int, int]:
if translate_frac <= 0 and scale_frac <= 0:
return bbox
left, top, right, bottom = bbox
box_w = max(right - left, 1)
box_h = max(bottom - top, 1)
center_x = (left + right) / 2.0
center_y = (top + bottom) / 2.0
if translate_frac > 0:
center_x += float(torch.empty(1).uniform_(-translate_frac, translate_frac).item()) * box_w
center_y += float(torch.empty(1).uniform_(-translate_frac, translate_frac).item()) * box_h
if scale_frac > 0:
scale = float(torch.empty(1).uniform_(1.0 - scale_frac, 1.0 + scale_frac).item())
box_w *= scale
box_h *= scale
left = int(round(center_x - box_w / 2.0))
right = int(round(center_x + box_w / 2.0))
top = int(round(center_y - box_h / 2.0))
bottom = int(round(center_y + box_h / 2.0))
if left < 0:
right -= left
left = 0
if top < 0:
bottom -= top
top = 0
if right > image_width:
left -= right - image_width
right = image_width
if bottom > image_height:
top -= bottom - image_height
bottom = image_height
return max(0, left), max(0, top), min(image_width, max(left + 1, right)), min(image_height, max(top + 1, bottom))
def crop_pair(
    image: Image.Image,
    mask: Image.Image,
    ann_path: str | Path | None,
    image_size: int,
    crop_padding: float,
    crop_jitter: float = 0.0,
    crop_scale_jitter: float = 0.0,
) -> tuple[Image.Image, Image.Image]:
    """Crop ``image`` and ``mask`` to the same region and resize both to a square.

    The region is picked in this order:

    1. the foreground box of ``mask``;
    2. failing that, the annotation box from ``ann_path``, if given;
    3. failing that, the whole image.

    It is then padded by ``crop_padding`` and jittered by
    ``crop_jitter``/``crop_scale_jitter`` (no jitter at the defaults).

    Both crops are resized to ``(image_size, image_size)`` without preserving
    aspect ratio. The image uses BICUBIC resampling; the mask uses NEAREST so
    its labels stay discrete.
    """
    if mask.size != image.size:
        # Put the mask on the image's pixel grid before computing any box.
        mask = mask.resize(image.size, Image.Resampling.NEAREST)
    width, height = image.size

    region = bbox_from_mask(mask)
    if region is None and ann_path is not None:
        region = bbox_from_annotation(ann_path)
    if region is None:
        region = (0, 0, width, height)
    region = pad_bbox(region, width, height, crop_padding)
    region = _jitter_bbox(region, width, height, crop_jitter, crop_scale_jitter)

    target = (image_size, image_size)
    return (
        image.crop(region).resize(target, Image.Resampling.BICUBIC),
        mask.crop(region).resize(target, Image.Resampling.NEAREST),
    )
def image_transform() -> T.Compose:
    """Build the image tensor pipeline.

    ``ToTensor`` turns the PIL image into a float CHW tensor (scaled to
    ``[0, 1]`` for 8-bit inputs). ``Normalize`` then applies ``(x - 0.5) / 0.5``
    per channel, which maps ``[0, 1]`` to ``[-1, 1]``. Three channels are
    assumed.
    """
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    steps = [T.ToTensor(), T.Normalize(mean, std)]
    return T.Compose(steps)
def mask_transform() -> T.Compose:
return T.Compose(
[
T.Grayscale(1),
T.ToTensor(),
T.Lambda(lambda x: (x > 0.5).float()),
]
)
def apply_pair_augment(
    image: Image.Image,
    mask: Image.Image,
    rotation: float = 15.0,
    hflip_p: float = 0.5,
    vflip_p: float = 0.2,
    brightness: float = 0.15,
    contrast: float = 0.15,
) -> tuple[Image.Image, Image.Image]:
    """Apply the same random geometric augmentation to ``image`` and ``mask``.

    Geometric steps are applied to both inputs with shared random parameters:

    - horizontal flip with probability ``hflip_p``;
    - vertical flip with probability ``vflip_p``;
    - rotation by a uniform angle in ``[-rotation, rotation]`` degrees. The
      image uses BICUBIC, the mask uses NEAREST, and the uncovered corners are
      filled with 0.

    Photometric steps touch the image only: a brightness factor and then a
    contrast factor, each uniform in ``[max(0, 1 - v), 1 + v]``.

    Every step is skipped when its parameter is ``<= 0``, and a skipped step
    consumes no random numbers. Draws come from torch's global RNG.
    """

    def _coin(p: float) -> bool:
        return p > 0 and bool(torch.rand(()) < p)

    def _uniform(low: float, high: float) -> float:
        return float(torch.empty(1).uniform_(low, high).item())

    if _coin(hflip_p):
        image, mask = TF.hflip(image), TF.hflip(mask)
    if _coin(vflip_p):
        image, mask = TF.vflip(image), TF.vflip(mask)
    if rotation > 0:
        angle = _uniform(-rotation, rotation)
        image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BICUBIC, fill=0)
        mask = TF.rotate(mask, angle, interpolation=T.InterpolationMode.NEAREST, fill=0)
    if brightness > 0:
        image = TF.adjust_brightness(image, _uniform(max(0.0, 1.0 - brightness), 1.0 + brightness))
    if contrast > 0:
        image = TF.adjust_contrast(image, _uniform(max(0.0, 1.0 - contrast), 1.0 + contrast))
    return image, mask