photo-classification / src /eval /dataset_prep.py
esandorfi's picture
Eval updates
ab318c0
#!/usr/bin/env python3
from __future__ import annotations
import io
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional
import click
import requests
from PIL import Image, ImageOps
from tqdm import tqdm
@dataclass(frozen=True)
class Config:
    """Immutable bundle of the command-line options consumed by ``run``."""

    # Root directory that receives the output files.
    out_dir: Path
    # Dataset selector: "photos" picks Open Images; any other value picks X-Dance.
    target: str
    # Number of images to save during a download run.
    n: int
    # Forwarded to the download functions, which do not currently read it.
    seed: int
    # When true, a normalized JPEG is written alongside each raw download.
    normalize: bool
    # Upper bound, in pixels, for the longer image side after normalization.
    max_side: int
    # Quality setting handed to the JPEG encoder.
    jpeg_quality: int
    # When true, skip downloading and only normalize the files under in_dir.
    normalize_only: bool
    # Source directory for normalize-only mode; required in that mode.
    in_dir: Optional[Path]
    # When true, previously downloaded files for the target are deleted first.
    reset: bool
def ensure_dir(path: Path) -> None:
    """Create ``path`` as a directory, together with any missing ancestors.

    Calling this on a directory that already exists is a no-op.
    """
    path.mkdir(exist_ok=True, parents=True)
def normalize_image_bytes(img_bytes: bytes, max_side: int, jpeg_quality: int) -> bytes:
    """Re-encode an image as an upright RGB JPEG bounded by ``max_side``.

    Args:
        img_bytes: Encoded image data in any format Pillow can open.
        max_side: Largest allowed length of the longer side; bigger images
            are scaled down proportionally, smaller ones are left as-is.
        jpeg_quality: Quality value passed to the JPEG encoder.

    Returns:
        The bytes of the resulting progressive, optimized JPEG.
    """
    with Image.open(io.BytesIO(img_bytes)) as source:
        # Apply the EXIF orientation first so the size check sees the real layout.
        picture = ImageOps.exif_transpose(source).convert("RGB")
        width, height = picture.size
        ratio = max_side / float(max(width, height))
        if ratio < 1.0:
            # Only ever shrink; clamp to 1px so extreme aspect ratios stay valid.
            target_size = (
                max(1, int(round(width * ratio))),
                max(1, int(round(height * ratio))),
            )
            picture = picture.resize(target_size, Image.Resampling.LANCZOS)
        buffer = io.BytesIO()
        picture.save(
            buffer,
            format="JPEG",
            quality=jpeg_quality,
            optimize=True,
            progressive=True,
        )
        return buffer.getvalue()
def download_url(url: str, timeout_s: float = 20.0) -> Optional[bytes]:
    """Fetch ``url`` and return the response body, or ``None`` on failure.

    Args:
        url: Address to request with HTTP GET.
        timeout_s: Timeout in seconds handed to ``requests.get``.

    Returns:
        The raw response body, or ``None`` when the request raised or the
        server answered with an error status.
    """
    try:
        response = requests.get(url, timeout=timeout_s)
        response.raise_for_status()
        body = response.content
    except Exception:
        # Deliberately best-effort: callers treat None as "skip this item".
        return None
    return body
def save_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path``, creating the parent directories if needed.

    An existing file at ``path`` is overwritten.
    """
    parent = path.parent
    parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(data)
def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield the image files found in ``paths``.

    A directory is searched recursively and its matches are yielded in
    sorted order; a plain file is yielded when its suffix is recognised.
    Suffix matching is case-insensitive. Entries that are neither an
    existing directory nor a recognised image file are ignored.
    """
    image_suffixes = frozenset(
        (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tif", ".tiff")
    )

    def looks_like_image(candidate: Path) -> bool:
        return candidate.is_file() and candidate.suffix.lower() in image_suffixes

    for entry in paths:
        if entry.is_dir():
            yield from filter(looks_like_image, sorted(entry.rglob("*")))
        elif looks_like_image(entry):
            yield entry
def normalize_existing(in_dir: Path, out_dir: Path, max_side: int, jpeg_quality: int) -> None:
    """Normalize every image under ``in_dir`` into a mirrored tree in ``out_dir``.

    Each input keeps its path relative to ``in_dir`` but is written with a
    ``.jpg`` suffix. Processing is best-effort: a file that cannot be read,
    decoded or written is reported and skipped rather than aborting the run.

    Args:
        in_dir: Directory that is searched recursively for images.
        out_dir: Root directory for the normalized JPEGs; created if missing.
        max_side: Maximum length of the longer image side after normalization.
        jpeg_quality: Quality value passed to the JPEG encoder.
    """
    files = list(iter_images([in_dir]))
    ensure_dir(out_dir)
    skipped = 0
    for p in tqdm(files, desc=f"Normalizing {in_dir.name}"):
        try:
            raw = p.read_bytes()
            norm = normalize_image_bytes(raw, max_side=max_side, jpeg_quality=jpeg_quality)
            out_path = (out_dir / p.relative_to(in_dir)).with_suffix(".jpg")
            save_bytes(out_path, norm)
        except Exception as exc:
            # Keep going, but make the failure visible instead of silently
            # dropping the file; tqdm.write avoids corrupting the progress bar.
            skipped += 1
            tqdm.write(f"Skipped {p}: {exc}")
    print(f"Normalized images written to: {out_dir}")
    if skipped:
        print(f"Skipped {skipped} of {len(files)} files that could not be normalized")
def _next_index(dir_path: Path, prefix: str) -> int:
    """Return the first unused numeric index for ``<prefix>_<index>.bin`` files.

    The result is one more than the highest index found in ``dir_path``, or 0
    when the directory is missing or holds no matching file. Files whose
    trailing ``_`` segment is not an integer are ignored.
    """
    if not dir_path.exists():
        return 0

    # Seed with -1 so an empty (or all-negative) set of indices still yields 0.
    seen = [-1]
    for candidate in dir_path.glob(f"{prefix}_*.bin"):
        tail = candidate.stem.rsplit("_", 1)[-1]
        try:
            seen.append(int(tail))
        except ValueError:
            pass
    return max(seen) + 1
def download_photos_open_images(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream Open Images V7 and save up to ``n`` downloadable photos.

    Raw bytes go to ``<out_dir>/photos/raw/openimages_<idx>.bin``. When
    ``normalize`` is set, a JPEG with the same index is also written to
    ``<out_dir>/photos/normalized``. Numbering continues after the highest
    existing raw index unless ``reset`` removed the earlier files.

    Rows without a URL, failed downloads and (when normalizing) images that
    cannot be decoded are skipped and do not count towards ``n``.

    Args:
        out_dir: Root output directory.
        n: Number of images to save.
        seed: Accepted for interface compatibility; not used by this function.
        normalize: Whether to also write a normalized JPEG per image.
        max_side: Maximum length of the longer side of normalized images.
        jpeg_quality: Quality value for normalized JPEGs.
        reset: Delete existing ``openimages_*`` files before downloading.
    """
    # Imported here so that only download runs load the `datasets` package.
    from datasets import load_dataset

    ds = load_dataset("bitmind/open-images-v7", split="train", streaming=True)
    raw_dir = out_dir / "photos" / "raw"
    norm_dir = out_dir / "photos" / "normalized"
    if reset:
        for p in raw_dir.glob("openimages_*.bin"):
            p.unlink()
        for p in norm_dir.glob("openimages_*.jpg"):
            p.unlink()
    start_idx = _next_index(raw_dir, "openimages")
    saved = 0
    for row in tqdm(ds, desc="Streaming Open Images V7"):
        if saved >= n:
            break
        url = row.get("image_url") or row.get("url") or row.get("imageUrl") or row.get("ImageURL")
        if not url:
            continue
        img = download_url(url)
        if not img:
            continue
        norm: Optional[bytes] = None
        if normalize:
            # Decode before anything is written: one bad payload must not
            # abort the run or leave a raw file without a normalized twin.
            try:
                norm = normalize_image_bytes(img, max_side=max_side, jpeg_quality=jpeg_quality)
            except Exception as exc:
                tqdm.write(f"Skipping undecodable image {url}: {exc}")
                continue
        idx = start_idx + saved
        save_bytes(raw_dir / f"openimages_{idx:06d}.bin", img)
        if norm is not None:
            save_bytes(norm_dir / f"openimages_{idx:06d}.jpg", norm)
        saved += 1
    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")
def download_dance_x_dance(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream the X-Dance dataset and save up to ``n`` images.

    Raw bytes go to ``<out_dir>/dance/raw/xdance_<idx>.bin``. When
    ``normalize`` is set, a JPEG with the same index is also written to
    ``<out_dir>/dance/normalized``. Numbering continues after the highest
    existing raw index unless ``reset`` removed the earlier files.

    A row's ``image`` value is used when it exposes ``convert`` (it is then
    re-encoded as PNG); otherwise the row's URL, if any, is downloaded. Rows
    that yield no bytes, or whose bytes cannot be decoded when normalizing,
    are skipped and do not count towards ``n``.

    Args:
        out_dir: Root output directory.
        n: Number of images to save.
        seed: Accepted for interface compatibility; not used by this function.
        normalize: Whether to also write a normalized JPEG per image.
        max_side: Maximum length of the longer side of normalized images.
        jpeg_quality: Quality value for normalized JPEGs.
        reset: Delete existing ``xdance_*`` files before downloading.
    """
    # Imported here so that only download runs load the `datasets` package.
    from datasets import load_dataset

    ds = load_dataset("MCG-NJU/X-Dance", split="train", streaming=True)
    raw_dir = out_dir / "dance" / "raw"
    norm_dir = out_dir / "dance" / "normalized"
    if reset:
        for p in raw_dir.glob("xdance_*.bin"):
            p.unlink()
        for p in norm_dir.glob("xdance_*.jpg"):
            p.unlink()
    start_idx = _next_index(raw_dir, "xdance")
    saved = 0
    for row in tqdm(ds, desc="Streaming X-Dance"):
        if saved >= n:
            break
        img_obj = row.get("image")
        img_bytes: Optional[bytes] = None
        if img_obj is not None and hasattr(img_obj, "convert"):
            # In-memory image object: serialize it so it can be stored as bytes.
            out = io.BytesIO()
            img_obj.convert("RGB").save(out, format="PNG")
            img_bytes = out.getvalue()
        else:
            url = row.get("image_url") or row.get("url")
            if url:
                img_bytes = download_url(url)
        if not img_bytes:
            continue
        norm: Optional[bytes] = None
        if normalize:
            # Decode before anything is written: one bad payload must not
            # abort the run or leave a raw file without a normalized twin.
            try:
                norm = normalize_image_bytes(img_bytes, max_side=max_side, jpeg_quality=jpeg_quality)
            except Exception as exc:
                tqdm.write(f"Skipping undecodable X-Dance image: {exc}")
                continue
        idx = start_idx + saved
        save_bytes(raw_dir / f"xdance_{idx:06d}.bin", img_bytes)
        if norm is not None:
            save_bytes(norm_dir / f"xdance_{idx:06d}.jpg", norm)
        saved += 1
    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")
def run(cfg: Config) -> None:
    """Execute the action described by ``cfg``.

    In normalize-only mode the images under ``cfg.in_dir`` are normalized
    into ``cfg.out_dir`` and nothing is downloaded. Otherwise the output
    directory is created and the downloader matching ``cfg.target`` runs.

    Raises:
        SystemExit: If normalize-only mode is requested without ``in_dir``.
    """
    if cfg.normalize_only:
        if not cfg.in_dir:
            raise SystemExit("--in-dir is required with --normalize-only")
        normalize_existing(cfg.in_dir, cfg.out_dir, cfg.max_side, cfg.jpeg_quality)
        return

    ensure_dir(cfg.out_dir)

    # Both downloaders take the same keyword arguments, so pick one and call once.
    downloader = (
        download_photos_open_images
        if cfg.target == "photos"
        else download_dance_x_dance
    )
    downloader(
        out_dir=cfg.out_dir,
        n=cfg.n,
        seed=cfg.seed,
        normalize=cfg.normalize,
        max_side=cfg.max_side,
        jpeg_quality=cfg.jpeg_quality,
        reset=cfg.reset,
    )
@click.command()
@click.option("--out", "out_dir", required=True, type=click.Path(path_type=Path))
@click.option("--target", type=click.Choice(["photos", "dance"], case_sensitive=False), required=True)
@click.option("--n", default=500, show_default=True, type=int)
@click.option("--seed", default=0, show_default=True, type=int)
@click.option("--normalize", is_flag=True, default=False)
@click.option("--max-side", default=512, show_default=True, type=int)
@click.option("--jpeg-quality", default=92, show_default=True, type=int)
@click.option("--normalize-only", is_flag=True, default=False)
@click.option("--in-dir", type=click.Path(path_type=Path))
@click.option("--reset", is_flag=True, default=False, help="Delete existing raw/normalized files before download.")
def cli(
    out_dir: Path,
    target: str,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    normalize_only: bool,
    in_dir: Optional[Path],
    reset: bool,
) -> None:
    # Command-line entry point: pack the parsed options into a Config and
    # hand it to run(). Documented with comments rather than a docstring
    # because click would show a docstring as the command's --help text.
    options = {
        "out_dir": out_dir,
        "target": target,
        "n": n,
        "seed": seed,
        "normalize": normalize,
        "max_side": max_side,
        "jpeg_quality": jpeg_quality,
        "normalize_only": normalize_only,
        "in_dir": in_dir,
        "reset": reset,
    }
    run(Config(**options))
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    cli()