project-halide / data /augmentation.py
Lonelyguyse1's picture
Deploy Project Halide Gradio Space
e994c16 verified
Raw
History Blame Contribute Delete
4.52 kB
"""Synthetic defect augmentation helpers."""
from __future__ import annotations
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
from PIL import Image, ImageEnhance, ImageOps
from data.preprocessing import load_image
# Lookup table from an overlay folder name (lower-cased) to the defect label
# written into generated annotations. Several folder names collapse onto the
# same label, so the distinct labels are: dust, dirt, long_hair, short_hair,
# and scratch. discover_overlays() skips any folder name not listed here.
SYNTHETIC_LABEL_MAP = {
"dust": "dust",
"dirt": "dirt",
"dots": "dust",
"hair": "long_hair",
"hair-short": "short_hair",
"lint": "short_hair",
"scratches": "scratch",
"scratch": "scratch",
"smut": "dirt",
"spots": "dirt",
"sprinkles": "dust",
"stain": "dirt",
}
@dataclass(frozen=True)
class OverlayDefect:
    """An overlay image file paired with the defect label it stands for.

    Instances are immutable (``frozen=True``), which also makes them hashable
    and comparable by value.
    """

    # Location of the overlay image on disk.
    path: Path
    # Defect label attached to annotations produced from this overlay.
    label: str
def discover_overlays(root: str | Path) -> list[OverlayDefect]:
    """Collect labelled ``*.png`` overlays found anywhere below ``root``.

    Each file's label is looked up in ``SYNTHETIC_LABEL_MAP`` using the
    lower-cased name of the folder that directly contains it. Files whose
    folder name has no (truthy) mapping are left out.

    Args:
        root: Directory to search recursively.

    Returns:
        Overlays ordered by sorted file path; an empty list when ``root``
        does not exist.
    """
    base_dir = Path(root)
    if not base_dir.exists():
        return []

    # Pair every PNG with the label implied by its immediate parent folder.
    labelled = (
        (png_path, SYNTHETIC_LABEL_MAP.get(png_path.parent.name.lower()))
        for png_path in sorted(base_dir.rglob("*.png"))
    )
    return [
        OverlayDefect(path=png_path, label=mapped)
        for png_path, mapped in labelled
        if mapped
    ]
def _visible_bbox(alpha: Image.Image) -> tuple[int, int, int, int] | None:
bbox = alpha.getbbox()
if bbox is None:
return None
x_min, y_min, x_max, y_max = bbox
if x_max <= x_min or y_max <= y_min:
return None
return bbox
def _normalized_bbox(
paste_x: int,
paste_y: int,
visible_bbox: tuple[int, int, int, int],
width: int,
height: int,
) -> list[float]:
x_min, y_min, x_max, y_max = visible_bbox
return [
round((paste_x + x_min) / width, 6),
round((paste_y + y_min) / height, 6),
round((paste_x + x_max) / width, 6),
round((paste_y + y_max) / height, 6),
]
def apply_overlay(
    base: Image.Image,
    overlay: OverlayDefect,
    *,
    rng: random.Random,
    scale_range: tuple[float, float] = (0.35, 1.4),
    opacity_range: tuple[float, float] = (0.55, 0.95),
) -> tuple[Image.Image, dict] | None:
    """Paste one defect overlay onto a copy of base and return annotation.

    The overlay is randomly mirrored (50% chance), rotated by up to 22
    degrees either way (35% chance), rescaled, faded, and composited at a
    random position. All randomness is drawn from ``rng``.

    Args:
        base: Image to decorate. It is passed through ``load_image`` and then
            converted to RGBA; compositing happens on that converted image.
        overlay: Overlay file to paste and the label to record for it.
        rng: Random source for every random decision in this function.
        scale_range: ``(low, high)`` bounds for the uniform resize factor.
        opacity_range: ``(low, high)`` bounds for the uniform factor applied
            to the overlay's alpha channel.

    Returns:
        ``(image, annotation)`` where ``image`` is the composited result in
        RGB mode and ``annotation`` is ``{"label": ..., "bbox": [...]}`` with
        the bbox given as ``[x_min, y_min, x_max, y_max]`` fractions of the
        base size. ``None`` when the faded overlay has no visible pixels.
    """
    out = load_image(base).convert("RGBA")
    width, height = out.size

    # Image.open() keeps the underlying file open; convert() decodes into a
    # new image, so the handle can be released as soon as that is done.
    with Image.open(overlay.path) as source:
        layer = source.convert("RGBA")

    if rng.random() < 0.5:
        layer = ImageOps.mirror(layer)
    if rng.random() < 0.35:
        layer = layer.rotate(
            rng.uniform(-22, 22),
            expand=True,
            resample=Image.Resampling.BICUBIC,
        )

    scale = rng.uniform(*scale_range)
    new_size = (
        max(2, int(layer.width * scale)),
        max(2, int(layer.height * scale)),
    )
    layer = layer.resize(new_size, Image.Resampling.LANCZOS)

    # An overlay that would not fit is shrunk to at most half the base in
    # each dimension. The target is clamped to 1 so a one-pixel-wide or
    # one-pixel-tall base never requests a zero-sized thumbnail.
    if layer.width >= width or layer.height >= height:
        layer.thumbnail(
            (max(1, width // 2), max(1, height // 2)),
            Image.Resampling.LANCZOS,
        )

    # Fade the overlay by scaling its alpha channel.
    alpha = layer.getchannel("A")
    alpha = ImageEnhance.Brightness(alpha).enhance(rng.uniform(*opacity_range))
    layer.putalpha(alpha)

    # The annotation box covers only the non-transparent part of the layer.
    visible = _visible_bbox(alpha)
    if visible is None:
        return None

    # Pick a top-left corner that keeps the whole layer inside the base.
    max_x = max(0, width - layer.width)
    max_y = max(0, height - layer.height)
    paste_x = rng.randint(0, max_x) if max_x else 0
    paste_y = rng.randint(0, max_y) if max_y else 0
    out.alpha_composite(layer, (paste_x, paste_y))

    annotation = {
        "label": overlay.label,
        "bbox": _normalized_bbox(paste_x, paste_y, visible, width, height),
    }
    return out.convert("RGB"), annotation
def augment_image(
    base: Image.Image,
    overlays: Iterable[OverlayDefect],
    *,
    seed: int,
    defects_per_image: tuple[int, int] = (3, 9),
) -> tuple[Image.Image, list[dict]]:
    """Stack a random number of overlays onto base and collect annotations.

    A label is chosen uniformly first and then an overlay within that label,
    so labels with few overlay files are picked as often as labels with many.
    Every random choice comes from a ``random.Random`` seeded with ``seed``.

    Args:
        base: Starting image; it is passed through ``load_image``.
        overlays: Candidate overlays to draw from.
        seed: Seed for the random generator.
        defects_per_image: Inclusive ``(low, high)`` bounds for the number of
            paste attempts. Attempts for which ``apply_overlay`` returns
            ``None`` add nothing.

    Returns:
        ``(image, annotations)``. With no overlays, the loaded base image
        and an empty list.
    """
    rng = random.Random(seed)
    pool = list(overlays)
    if len(pool) == 0:
        return load_image(base), []

    # Group overlays per label, keyed in sorted label order and keeping the
    # incoming order of overlays inside each group.
    grouped: dict[str, list[OverlayDefect]] = {
        name: [item for item in pool if item.label == name]
        for name in sorted({item.label for item in pool})
    }
    label_names = list(grouped)

    canvas = load_image(base)
    records: list[dict] = []
    for _attempt in range(rng.randint(*defects_per_image)):
        picked_label = rng.choice(label_names)
        picked_overlay = rng.choice(grouped[picked_label])
        pasted = apply_overlay(canvas, picked_overlay, rng=rng)
        if pasted is not None:
            canvas, record = pasted
            records.append(record)
    return canvas, records
# Names exported by a star-import of this module. The underscore-prefixed
# helpers (_visible_bbox, _normalized_bbox) are intentionally not listed.
__all__ = [
"OverlayDefect",
"SYNTHETIC_LABEL_MAP",
"apply_overlay",
"augment_image",
"discover_overlays",
]