#!/usr/bin/env python3
"""Download and/or normalize small image corpora.

Two streamed Hugging Face sources are supported via ``--target``:

* ``photos`` -- ``bitmind/open-images-v7``; each row's image URL is fetched.
* ``dance``  -- ``MCG-NJU/X-Dance``; rows carrying a PIL image are re-encoded
  as PNG, otherwise the row's URL is fetched.

Raw bytes are written to ``<out>/<target>/raw/<prefix>_NNNNNN.bin``. With
``--normalize`` an RGB JPEG whose longer side is at most ``--max-side`` is also
written to ``<out>/<target>/normalized/<prefix>_NNNNNN.jpg``.
``--normalize-only`` skips downloading and normalizes ``--in-dir`` into ``--out``.
"""
from __future__ import annotations

import io
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional

import click
import requests
from PIL import Image, ImageOps
from tqdm import tqdm

# Lower-case file suffixes that iter_images() treats as images.
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tif", ".tiff"})


@dataclass(frozen=True)
class Config:
    """Options for one run, as collected by ``cli``."""

    out_dir: Path  # root output directory
    target: str  # "photos" or "dance"
    n: int  # number of images to save
    seed: int  # NOTE(review): accepted but not used by either downloader
    normalize: bool  # also write normalized JPEGs
    max_side: int  # cap on the longer side (pixels) of normalized output
    jpeg_quality: int  # JPEG quality passed to Pillow
    normalize_only: bool  # skip downloading; normalize in_dir instead
    in_dir: Optional[Path]  # input file/directory for normalize_only mode
    reset: bool  # delete this target's previously saved files first


def ensure_dir(path: Path) -> None:
    """Create ``path`` (and parents) if it does not exist."""
    path.mkdir(parents=True, exist_ok=True)


def normalize_image_bytes(img_bytes: bytes, max_side: int, jpeg_quality: int) -> bytes:
    """Decode an encoded image and re-encode it as an RGB JPEG.

    EXIF orientation is applied, the image is converted to RGB, and it is
    downscaled with Lanczos resampling so its longer side is at most
    ``max_side``. Images already within the bound are not upscaled.

    Raises whatever Pillow raises (e.g. ``PIL.UnidentifiedImageError``) when
    ``img_bytes`` cannot be decoded.
    """
    with Image.open(io.BytesIO(img_bytes)) as src:
        im = ImageOps.exif_transpose(src).convert("RGB")
        w, h = im.size
        scale = max_side / float(max(w, h))
        if scale < 1.0:
            # max(1, ...) keeps very thin images from collapsing to 0 px.
            new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
            im = im.resize(new_size, Image.Resampling.LANCZOS)
        out = io.BytesIO()
        im.save(out, format="JPEG", quality=jpeg_quality, optimize=True, progressive=True)
    return out.getvalue()


def download_url(url: str, timeout_s: float = 20.0) -> Optional[bytes]:
    """GET ``url`` and return the body, or ``None`` on any HTTP/network error."""
    try:
        resp = requests.get(url, timeout=timeout_s)
        resp.raise_for_status()
        return resp.content
    except requests.RequestException:
        return None


def save_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path``, creating parent directories as needed."""
    ensure_dir(path.parent)
    path.write_bytes(data)


def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield image files found in ``paths``.

    A directory is searched recursively, and its matches are yielded in sorted
    order. A file is yielded if its suffix is in ``_IMAGE_EXTS``.
    """
    for path in paths:
        if path.is_dir():
            for p in sorted(path.rglob("*")):
                if p.is_file() and p.suffix.lower() in _IMAGE_EXTS:
                    yield p
        elif path.is_file() and path.suffix.lower() in _IMAGE_EXTS:
            yield path


def normalize_existing(in_dir: Path, out_dir: Path, max_side: int, jpeg_quality: int) -> None:
    """Normalize every image under ``in_dir`` into ``out_dir`` as ``.jpg``.

    The relative directory layout is mirrored, and each output uses the
    ``.jpg`` suffix. ``in_dir`` may also be a single image file. Unreadable
    files are reported and skipped rather than aborting the batch.

    NOTE(review): ``a.png`` and ``a.jpg`` in the same folder both map to
    ``a.jpg``; the later one overwrites the earlier.
    """
    files = list(iter_images([in_dir]))
    ensure_dir(out_dir)
    # For a single-file in_dir, relative_to() yields Path("."), whose
    # with_suffix() raises, so use the file name instead.
    in_is_dir = in_dir.is_dir()
    failed = 0
    for p in tqdm(files, desc=f"Normalizing {in_dir.name}"):
        rel = p.relative_to(in_dir) if in_is_dir else Path(p.name)
        out_path = (out_dir / rel).with_suffix(".jpg")
        try:
            norm = normalize_image_bytes(p.read_bytes(), max_side=max_side, jpeg_quality=jpeg_quality)
            save_bytes(out_path, norm)
        except Exception as exc:  # best-effort batch: report, keep going
            failed += 1
            tqdm.write(f"Skipping {p}: {exc}")
    if failed:
        print(f"Skipped {failed} unreadable image(s).")
    print(f"Normalized images written to: {out_dir}")


def _next_index(dir_path: Path, prefix: str) -> int:
    """Return one past the highest NNN in ``dir_path/<prefix>_NNN.bin`` (0 if none)."""
    if not dir_path.exists():
        return 0
    max_idx = -1
    for p in dir_path.glob(f"{prefix}_*.bin"):
        try:
            idx = int(p.stem.split("_")[-1])
        except ValueError:
            continue
        max_idx = max(max_idx, idx)
    return max_idx + 1


def _reset_outputs(raw_dir: Path, norm_dir: Path, prefix: str) -> None:
    """Delete ``<prefix>_*.bin`` in ``raw_dir`` and ``<prefix>_*.jpg`` in ``norm_dir``."""
    for p in raw_dir.glob(f"{prefix}_*.bin"):
        p.unlink()
    for p in norm_dir.glob(f"{prefix}_*.jpg"):
        p.unlink()


def _store_sample(
    img_bytes: bytes,
    idx: int,
    *,
    prefix: str,
    raw_dir: Path,
    norm_dir: Path,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
) -> bool:
    """Save one sample's raw (and optionally normalized) file; return True if kept.

    Normalization runs before anything is written. An undecodable payload is
    therefore skipped without leaving an orphaned raw file, and without
    aborting the whole stream.
    """
    norm: Optional[bytes] = None
    if normalize:
        try:
            norm = normalize_image_bytes(img_bytes, max_side=max_side, jpeg_quality=jpeg_quality)
        except Exception as exc:  # corrupt/non-image payload: skip this row
            tqdm.write(f"Skipping {prefix} sample {idx}: cannot normalize ({exc})")
            return False
    stem = f"{prefix}_{idx:06d}"
    save_bytes(raw_dir / f"{stem}.bin", img_bytes)
    if norm is not None:
        save_bytes(norm_dir / f"{stem}.jpg", norm)
    return True


def download_photos_open_images(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream Open Images V7 rows and save up to ``n`` downloaded images.

    Numbering continues after any existing ``openimages_*.bin`` files unless
    ``reset`` is set. ``seed`` is currently unused.
    """
    # Imported here, so the module (and --normalize-only) works without `datasets`.
    from datasets import load_dataset

    ds = load_dataset("bitmind/open-images-v7", split="train", streaming=True)
    raw_dir = out_dir / "photos" / "raw"
    norm_dir = out_dir / "photos" / "normalized"
    if reset:
        _reset_outputs(raw_dir, norm_dir, "openimages")
    start_idx = _next_index(raw_dir, "openimages")

    saved = 0
    for row in tqdm(ds, desc="Streaming Open Images V7"):
        if saved >= n:
            break
        # The URL column name is not fixed here; try the known spellings.
        url = row.get("image_url") or row.get("url") or row.get("imageUrl") or row.get("ImageURL")
        img = download_url(url) if url else None
        if not img:
            continue
        if _store_sample(
            img,
            start_idx + saved,
            prefix="openimages",
            raw_dir=raw_dir,
            norm_dir=norm_dir,
            normalize=normalize,
            max_side=max_side,
            jpeg_quality=jpeg_quality,
        ):
            saved += 1

    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")


def _xdance_row_bytes(row: dict) -> Optional[bytes]:
    """Return encoded image bytes for one X-Dance row, or ``None`` if unavailable.

    A PIL-like ``image`` field is re-encoded as PNG. Otherwise the row's
    ``image_url``/``url`` is downloaded.
    """
    img_obj = row.get("image")
    if img_obj is not None and hasattr(img_obj, "convert"):
        buf = io.BytesIO()
        img_obj.convert("RGB").save(buf, format="PNG")
        return buf.getvalue()
    url = row.get("image_url") or row.get("url")
    return download_url(url) if url else None


def download_dance_x_dance(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream X-Dance rows and save up to ``n`` images.

    Numbering continues after any existing ``xdance_*.bin`` files unless
    ``reset`` is set. ``seed`` is currently unused.
    """
    # Imported here, so the module (and --normalize-only) works without `datasets`.
    from datasets import load_dataset

    ds = load_dataset("MCG-NJU/X-Dance", split="train", streaming=True)
    raw_dir = out_dir / "dance" / "raw"
    norm_dir = out_dir / "dance" / "normalized"
    if reset:
        _reset_outputs(raw_dir, norm_dir, "xdance")
    start_idx = _next_index(raw_dir, "xdance")

    saved = 0
    for row in tqdm(ds, desc="Streaming X-Dance"):
        if saved >= n:
            break
        img_bytes = _xdance_row_bytes(row)
        if not img_bytes:
            continue
        if _store_sample(
            img_bytes,
            start_idx + saved,
            prefix="xdance",
            raw_dir=raw_dir,
            norm_dir=norm_dir,
            normalize=normalize,
            max_side=max_side,
            jpeg_quality=jpeg_quality,
        ):
            saved += 1

    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")


def run(cfg: Config) -> None:
    """Dispatch to normalize-only mode or to the downloader selected by ``cfg.target``.

    Raises:
        SystemExit: if normalize-only mode lacks a usable ``in_dir``.
    """
    if cfg.normalize_only:
        if not cfg.in_dir:
            raise SystemExit("--in-dir is required with --normalize-only")
        if not cfg.in_dir.exists():
            raise SystemExit(f"--in-dir does not exist: {cfg.in_dir}")
        normalize_existing(cfg.in_dir, cfg.out_dir, cfg.max_side, cfg.jpeg_quality)
        return

    ensure_dir(cfg.out_dir)
    # The CLI Choice is case-insensitive; compare case-insensitively here too.
    downloader = (
        download_photos_open_images if cfg.target.lower() == "photos" else download_dance_x_dance
    )
    downloader(
        out_dir=cfg.out_dir,
        n=cfg.n,
        seed=cfg.seed,
        normalize=cfg.normalize,
        max_side=cfg.max_side,
        jpeg_quality=cfg.jpeg_quality,
        reset=cfg.reset,
    )


@click.command()
@click.option("--out", "out_dir", required=True, type=click.Path(path_type=Path))
@click.option("--target", type=click.Choice(["photos", "dance"], case_sensitive=False), required=True)
@click.option("--n", default=500, show_default=True, type=click.IntRange(min=0), help="Number of images to save.")
@click.option("--seed", default=0, show_default=True, type=int, help="Reserved; not used by either downloader.")
@click.option("--normalize", is_flag=True, default=False, help="Also write normalized JPEGs.")
@click.option("--max-side", default=512, show_default=True, type=click.IntRange(min=1), help="Longest side of normalized images.")
@click.option("--jpeg-quality", default=92, show_default=True, type=click.IntRange(0, 100), help="JPEG quality for normalized images.")
@click.option("--normalize-only", is_flag=True, default=False, help="Skip downloading; normalize --in-dir into --out.")
@click.option("--in-dir", type=click.Path(path_type=Path), help="Input file or directory for --normalize-only.")
@click.option("--reset", is_flag=True, default=False, help="Delete existing raw/normalized files before download.")
def cli(
    out_dir: Path,
    target: str,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    normalize_only: bool,
    in_dir: Optional[Path],
    reset: bool,
) -> None:
    """Download Open Images / X-Dance samples, or normalize an existing image folder."""
    cfg = Config(
        out_dir=out_dir,
        target=target,
        n=n,
        seed=seed,
        normalize=normalize,
        max_side=max_side,
        jpeg_quality=jpeg_quality,
        normalize_only=normalize_only,
        in_dir=in_dir,
        reset=reset,
    )
    run(cfg)


if __name__ == "__main__":
    cli()