File size: 4,521 Bytes
e994c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Synthetic defect augmentation helpers."""

from __future__ import annotations

import random
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

from PIL import Image, ImageEnhance, ImageOps

from data.preprocessing import load_image

# Overlay folder name (compared lower-cased by discover_overlays) -> canonical
# defect label recorded in generated annotations. Several source folder names
# collapse onto the same canonical class.
SYNTHETIC_LABEL_MAP = dict(
    [
        ("dust", "dust"),
        ("dirt", "dirt"),
        ("dots", "dust"),
        ("hair", "long_hair"),
        ("hair-short", "short_hair"),
        ("lint", "short_hair"),
        ("scratches", "scratch"),
        ("scratch", "scratch"),
        ("smut", "dirt"),
        ("spots", "dirt"),
        ("sprinkles", "dust"),
        ("stain", "dirt"),
    ]
)


@dataclass(frozen=True)
class OverlayDefect:
    """An overlay image file paired with the canonical defect label it depicts.

    Frozen, so instances are immutable, compare by value and are hashable.
    """

    # Filesystem location of the overlay image.
    path: Path
    # Canonical defect label (discover_overlays fills this from SYNTHETIC_LABEL_MAP).
    label: str


def discover_overlays(root: str | Path) -> list[OverlayDefect]:
    """Find PNG overlays under ``root`` and infer each label from its folder name.

    Every regular file whose name ends in ``.png`` is considered. The suffix is
    matched case-insensitively, so ``.PNG`` exports are not silently skipped on
    case-sensitive filesystems. Each file's immediate parent directory name is
    lower-cased and looked up in ``SYNTHETIC_LABEL_MAP``; files in unmapped
    folders are ignored.

    Args:
        root: Directory to search recursively.

    Returns:
        Overlays in sorted path order. Empty if ``root`` is not an existing
        directory.
    """
    root = Path(root)
    if not root.is_dir():
        return []

    # rglob("*.png") would be case-sensitive on POSIX and would also yield
    # directories named "*.png"; filter explicitly instead.
    png_files = sorted(
        path
        for path in root.rglob("*")
        if path.name.lower().endswith(".png") and path.is_file()
    )

    overlays: list[OverlayDefect] = []
    for path in png_files:
        label = SYNTHETIC_LABEL_MAP.get(path.parent.name.lower())
        if label:
            overlays.append(OverlayDefect(path=path, label=label))
    return overlays


def _visible_bbox(alpha: Image.Image) -> tuple[int, int, int, int] | None:
    """Return ``alpha.getbbox()`` if it describes a non-empty box, else ``None``.

    ``None`` means there are no visible pixels (``getbbox`` found nothing) or
    the reported box has zero width or height.
    """
    box = alpha.getbbox()
    if box is not None:
        left, top, right, bottom = box
        if right > left and bottom > top:
            return box
    return None


def _normalized_bbox(
    paste_x: int,
    paste_y: int,
    visible_bbox: tuple[int, int, int, int],
    width: int,
    height: int,
) -> list[float]:
    """Map a layer-local pixel box to base-image-relative coordinates.

    ``visible_bbox`` is ``(x_min, y_min, x_max, y_max)`` in the pasted layer's
    own pixel space. It is shifted by the paste offset, divided by the base
    image's ``width``/``height``, and each value is rounded to 6 decimals.
    """
    left, top, right, bottom = visible_bbox
    # Each corner coordinate paired with the extent it is normalized by.
    corners = (
        (paste_x + left, width),
        (paste_y + top, height),
        (paste_x + right, width),
        (paste_y + bottom, height),
    )
    return [round(coord / extent, 6) for coord, extent in corners]


def apply_overlay(
    base: Image.Image,
    overlay: OverlayDefect,
    *,
    rng: random.Random,
    scale_range: tuple[float, float] = (0.35, 1.4),
    opacity_range: tuple[float, float] = (0.55, 0.95),
) -> tuple[Image.Image, dict] | None:
    """Paste one defect overlay onto a copy of ``base`` and return its annotation.

    All randomness comes from ``rng``, drawn in a fixed order: mirror,
    rotate(, angle), scale, opacity(, x)(, y). A seeded generator therefore
    reproduces the same result. Processing steps:

    * The overlay is mirrored with probability 0.5.
    * It is rotated by a uniform angle in [-22, 22] degrees with probability
      0.35.
    * It is rescaled by a factor drawn from ``scale_range`` (never below 2x2 px).
    * If it would not fit, it is shrunk to at most half the base size.
    * Its alpha is multiplied by a factor drawn from ``opacity_range``.
    * It is composited at a random position that keeps the whole layer inside
      the image.

    Args:
        base: Image to draw on. It is converted to RGBA (``convert`` returns a
            new image), so compositing does not modify the caller's image.
        overlay: Overlay file to paste and the label to record for it.
        rng: Source of all randomness for this call.
        scale_range: Bounds passed to ``rng.uniform`` for the resize factor.
        opacity_range: Bounds passed to ``rng.uniform`` for the alpha multiplier.

    Returns:
        ``(image, annotation)``. ``image`` is RGB. ``annotation`` is
        ``{"label": overlay.label, "bbox": [x_min, y_min, x_max, y_max]}``,
        with coordinates normalized by the base width/height and rounded to
        6 decimals. Returns ``None`` if no overlay pixel remains visible after
        opacity scaling.
    """
    out = load_image(base).convert("RGBA")
    width, height = out.size

    # Close the overlay file deterministically instead of leaving the handle's
    # lifetime to garbage collection; convert() yields an independent image.
    with Image.open(overlay.path) as source:
        layer = source.convert("RGBA")
    if rng.random() < 0.5:
        layer = ImageOps.mirror(layer)
    if rng.random() < 0.35:
        # expand=True enlarges the canvas so the rotated overlay is not clipped.
        layer = layer.rotate(rng.uniform(-22, 22), expand=True, resample=Image.Resampling.BICUBIC)

    scale = rng.uniform(*scale_range)
    new_size = (
        max(2, int(layer.width * scale)),
        max(2, int(layer.height * scale)),
    )
    layer = layer.resize(new_size, Image.Resampling.LANCZOS)
    if layer.width >= width or layer.height >= height:
        # Too large to place: shrink in place (aspect ratio preserved) to at
        # most half the base size. Clamp to 1 so 1-px bases never request a
        # zero-sized thumbnail.
        layer.thumbnail(
            (max(1, width // 2), max(1, height // 2)),
            Image.Resampling.LANCZOS,
        )

    # Brightness on the single-band alpha channel scales every alpha value by
    # the factor, i.e. lowers the overlay's opacity.
    alpha = layer.getchannel("A")
    alpha = ImageEnhance.Brightness(alpha).enhance(rng.uniform(*opacity_range))
    layer.putalpha(alpha)
    # The annotation box is taken from the final alpha, so it reflects the
    # resized/rotated layer and ignores pixels that faded to zero.
    visible = _visible_bbox(alpha)
    if visible is None:
        return None

    max_x = max(0, width - layer.width)
    max_y = max(0, height - layer.height)
    # No RNG draw when there is no slack on an axis; keep this as-is so seeded
    # outputs stay reproducible.
    paste_x = rng.randint(0, max_x) if max_x else 0
    paste_y = rng.randint(0, max_y) if max_y else 0

    out.alpha_composite(layer, (paste_x, paste_y))
    annotation = {
        "label": overlay.label,
        "bbox": _normalized_bbox(paste_x, paste_y, visible, width, height),
    }
    return out.convert("RGB"), annotation


def augment_image(
    base: Image.Image,
    overlays: Iterable[OverlayDefect],
    *,
    seed: int,
    defects_per_image: tuple[int, int] = (3, 9),
) -> tuple[Image.Image, list[dict]]:
    """Produce one synthetic-defect image together with its annotations.

    A private ``random.Random(seed)`` drives every random decision, so the same
    inputs and ``seed`` give the same output. The number of paste attempts is
    ``randint(*defects_per_image)``. Each attempt picks a label uniformly from
    the sorted set of labels, then an overlay within that label, so a label's
    chance does not depend on how many files it has. Attempts for which
    ``apply_overlay`` returns ``None`` add nothing, so there may be fewer
    annotations than attempts.

    Returns:
        ``(image, annotations)``. When ``overlays`` is empty this is
        ``(load_image(base), [])``.
    """
    generator = random.Random(seed)
    pool = list(overlays)
    if not pool:
        return load_image(base), []

    # Group overlays per label; sorting makes label selection independent of
    # the order in which labels first appear in the input.
    buckets: dict[str, list[OverlayDefect]] = {}
    for candidate in pool:
        buckets.setdefault(candidate.label, []).append(candidate)
    label_choices = sorted(buckets)

    canvas = load_image(base)
    collected: list[dict] = []
    attempts = generator.randint(*defects_per_image)
    for _ in range(attempts):
        picked_label = generator.choice(label_choices)
        picked = generator.choice(buckets[picked_label])
        outcome = apply_overlay(canvas, picked, rng=generator)
        if outcome is not None:
            canvas, note = outcome
            collected.append(note)
    return canvas, collected


# Public API exported by `from <module> import *`; the private helpers
# (_visible_bbox, _normalized_bbox) are intentionally omitted.
__all__ = [
    "OverlayDefect",
    "SYNTHETIC_LABEL_MAP",
    "apply_overlay",
    "augment_image",
    "discover_overlays",
]