rupkotha / finetune /preprocess.py
Deb
rupkotha_1st_commit
f655146
Raw
History Blame Contribute Delete
4.89 kB
# finetune/preprocess.py
"""Stage 0.5 — clean the collected image pool into a training-ready set.
Sweeps finetune/data/images/**, and for each image:
1. optional fixed fractional crop (per source) — strips page furniture like the
"Kye Drawing for student" header / signature footer on the village scans;
2. auto-trim near-uniform borders (white margins on doodles & scans);
3. resize so the longest side ≤ MAX_SIZE;
4. RGB-normalise and re-encode;
5. drop near-duplicates via a 64-bit average hash (Hamming distance).
Outputs finetune/data/processed/<source>/<name>.jpg + a manifest.json. The
labeler (gen_labels.py) can then point at finetune/data/processed.
Run:
uv run python finetune/preprocess.py # process everything
uv run python finetune/preprocess.py --max-size 1024 --dedupe-distance 4
"""
import argparse
import json
from pathlib import Path
from PIL import Image, ImageChops
ROOT = Path(__file__).resolve().parent
# Input pool (SRC) and cleaned output (DST) are sibling folders under data/.
SRC, DST = (ROOT / "data" / leaf for leaf in ("images", "processed"))

# Fractional crop applied before auto-trim, keyed by source name, given as
# (top, bottom, left, right) fractions of the image height/width.
# The key "root" covers images sitting directly in data/images/ (the village
# drawing-book scans); there it strips the header text and the signature/date
# footer.
CROP_FRACTIONS: dict[str, tuple[float, float, float, float]] = dict(
    root=(0.07, 0.06, 0.02, 0.02),
)
def _source_of(path: Path) -> str:
    """Return the source bucket name for an image path under ``SRC``.

    A file at ``SRC/<name>/...`` belongs to source ``<name>``. A file placed
    directly in ``SRC`` (or ``SRC`` itself) is reported as ``"root"``.
    Raises ValueError (from ``Path.relative_to``) if *path* is not under SRC.
    """
    parts = path.relative_to(SRC).parts
    if len(parts) <= 1:
        return "root"
    return parts[0]
def _frac_crop(img: Image.Image, fracs: tuple[float, float, float, float]) -> Image.Image:
    """Cut fixed fractions off each edge of *img*.

    *fracs* is (top, bottom, left, right), each a fraction of the image
    height (top/bottom) or width (left/right). Pixel offsets are truncated
    to ints. If the fractions would leave an empty or inverted box, the
    input image is returned unchanged.
    """
    top, bottom, left, right = fracs
    width, height = img.size
    x0, y0 = int(width * left), int(height * top)
    x1, y1 = int(width * (1 - right)), int(height * (1 - bottom))
    if x1 <= x0 or y1 <= y0:
        # Degenerate crop box: the fractions would consume the whole image.
        return img
    return img.crop((x0, y0, x1, y1))
def _autotrim(img: Image.Image, tol: int = 18) -> Image.Image:
    """Trim a near-uniform border using the top-left pixel as the background.

    A pixel counts as content when ANY channel differs from the background
    colour by more than *tol*. Thresholding the luminance of the difference
    instead would under-weight blue: pale-yellow strokes on white differ only
    in the blue channel (e.g. a diff of (0, 0, 105) has luminance ~12). Such
    strokes would be treated as border and could be cropped away. For grey
    content on a grey/white background all channel differences are equal, so
    the per-channel max equals the luminance of the difference there.

    Expects an RGB image, since the background canvas is built in RGB.
    Returns the input unchanged when nothing differs from the background.
    """
    bg = Image.new("RGB", img.size, img.getpixel((0, 0)))
    bands = ImageChops.difference(img, bg).split()
    # Per-pixel maximum over channels = largest single-channel deviation.
    spread = bands[0]
    for band in bands[1:]:
        spread = ImageChops.lighter(spread, band)
    mask = spread.point(lambda p: 255 if p > tol else 0)
    bbox = mask.getbbox()
    return img.crop(bbox) if bbox else img
def _resize(img: Image.Image, max_size: int) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_size* pixels.

    The aspect ratio is preserved and images already within the limit are
    returned unchanged (never upscaled). Target sides are rounded rather than
    truncated, because ``max_size / w * w`` is not always exactly
    ``max_size`` in floating point (e.g. ``1 / 49 * 49 < 1``). Truncation
    would then leave the long side one pixel short. Each side is clamped to
    at least 1 px.
    """
    w, h = img.size
    scale = max_size / max(w, h)
    if scale < 1:
        new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
        img = img.resize(new_size, Image.LANCZOS)
    return img
def _ahash(img: Image.Image) -> int:
    """Return a 64-bit average hash of *img* for near-duplicate detection.

    The image is reduced to an 8x8 greyscale thumbnail. Bit ``i`` of the
    result is set when thumbnail pixel ``i`` (row-major) is at or above the
    thumbnail's mean brightness. Uses only Pillow, so no extra dependencies.
    """
    thumb = img.convert("L").resize((8, 8), Image.LANCZOS)
    pixels = list(thumb.getdata())
    mean = sum(pixels) / len(pixels)
    # Each index contributes a distinct power of two, so summing == OR-ing.
    return sum(1 << idx for idx, value in enumerate(pixels) if value >= mean)
def _hamming(a: int, b: int) -> int:
return bin(a ^ b).count("1")
def _load_rgb(path: Path) -> Image.Image:
    """Decode *path* into a standalone RGB image and close the file handle.

    Images carrying transparency (RGBA/LA/PA modes, or a ``transparency``
    entry in ``info`` as in palette PNGs) are alpha-composited onto white
    first. A plain ``convert("RGB")`` would drop alpha and expose whatever
    colour the transparent pixels happen to store (often black).
    """
    with Image.open(path) as im:
        if im.mode in ("RGBA", "LA", "PA") or "transparency" in im.info:
            rgba = im.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            return Image.alpha_composite(white, rgba).convert("RGB")
        return im.convert("RGB")


def _unique_out_path(out_dir: Path, stem: str, taken: set[str]) -> Path:
    """Return ``out_dir/<stem>.jpg``, adding ``_2``, ``_3``… if already used.

    Distinct inputs can map to the same output name: ``a.png`` + ``a.jpg``,
    or same-named files in different sub-folders of one source. Without this,
    the later save would overwrite the earlier one while the manifest listed
    both. *taken* holds case-folded paths already issued in this run, so
    case-insensitive filesystems are covered too; it is updated in place.
    """
    candidate = out_dir / f"{stem}.jpg"
    n = 2
    while str(candidate).casefold() in taken:
        candidate = out_dir / f"{stem}_{n}.jpg"
        n += 1
    taken.add(str(candidate).casefold())
    return candidate


def process(max_size: int, dedupe_distance: int) -> None:
    """Clean every image under SRC into DST and write DST/manifest.json.

    Per image:
    1. load as RGB, compositing transparency onto white;
    2. apply the per-source fractional crop;
    3. auto-trim the uniform border;
    4. downscale to *max_size*;
    5. drop it if its aHash is within *dedupe_distance* (Hamming) of any
       image already kept.

    Kept images are saved as JPEG (quality 90) at
    ``DST/<source>/<stem>.jpg``; a ``_2``, ``_3``… suffix is added when that
    name was already written in this run. Files that fail to decode or
    process are reported and skipped.

    Args:
        max_size: longest-side cap in pixels.
        dedupe_distance: max aHash Hamming distance treated as a duplicate
            (0 = identical hash only).

    Raises:
        SystemExit: if SRC contains no files with a recognised image extension.
    """
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    paths = sorted(
        p for p in SRC.rglob("*") if p.suffix.lower() in exts and p.is_file()
    )
    if not paths:
        raise SystemExit(f"No images under {SRC}")
    DST.mkdir(parents=True, exist_ok=True)

    hashes: list[int] = []
    taken: set[str] = set()
    manifest, kept, dups, errors = [], 0, 0, 0
    for p in paths:
        source = _source_of(p)
        # Best-effort: one unreadable/odd file must not abort the whole run.
        try:
            img = _load_rgb(p)
            img = _frac_crop(img, CROP_FRACTIONS.get(source, (0, 0, 0, 0)))
            img = _autotrim(img)
            img = _resize(img, max_size)
        except Exception as e:  # noqa: BLE001
            print(f" skip {p.name}: {e}")
            errors += 1
            continue

        h = _ahash(img)
        # First seen wins; paths are sorted, so the surviving copy is deterministic.
        if any(_hamming(h, prev) <= dedupe_distance for prev in hashes):
            dups += 1
            continue
        hashes.append(h)

        out_dir = DST / source
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = _unique_out_path(out_dir, p.stem, taken)
        img.save(out_path, "JPEG", quality=90)
        # POSIX separators keep the manifest identical across operating systems.
        manifest.append({"processed": out_path.relative_to(ROOT).as_posix(),
                         "source_image": p.relative_to(ROOT).as_posix(),
                         "source": source})
        kept += 1

    (DST / "manifest.json").write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\nProcessed {len(paths)} → kept {kept}, dropped {dups} dups, {errors} errors")
    by_src: dict[str, int] = {}
    for m in manifest:
        by_src[m["source"]] = by_src.get(m["source"], 0) + 1
    for s, c in sorted(by_src.items()):
        print(f" {s}: {c}")
    print(f"Output: {DST} (+ manifest.json)")
def main() -> None:
    """Command-line entry point: parse flags and run ``process``.

    Exits with an argparse usage error (status 2) if ``--max-size`` is below
    1. Such a value would otherwise be accepted and shrink every image to 1x1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--max-size", type=int, default=1024)
    ap.add_argument("--dedupe-distance", type=int, default=4,
                    help="max aHash Hamming distance to treat as duplicate (0=exact)")
    args = ap.parse_args()
    if args.max_size < 1:
        ap.error("--max-size must be a positive integer")
    process(args.max_size, args.dedupe_distance)


if __name__ == "__main__":
    main()