Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import io | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| import click | |
| import requests | |
| from PIL import Image, ImageOps | |
| from tqdm import tqdm | |
@dataclass
class Config:
    """Settings for a single run, built by ``cli`` and consumed by ``run``.

    The class is a dataclass so that it can be constructed with keyword
    arguments; without the decorator the annotated names below would not
    produce an ``__init__`` and ``Config(out_dir=..., ...)`` would raise
    ``TypeError``.
    """

    # Root directory that receives the downloaded and/or normalized images.
    out_dir: Path
    # Dataset selector: "photos" picks Open Images; any other value is
    # routed to the X-Dance downloader by ``run``.
    target: str
    # Maximum number of images to save in a download run.
    n: int
    # Forwarded to the download functions (they do not currently read it).
    seed: int
    # When true, a normalized JPEG is written next to each raw download.
    normalize: bool
    # Upper bound, in pixels, for the longer image side after normalization.
    max_side: int
    # JPEG quality passed to the encoder when normalizing.
    jpeg_quality: int
    # When true, skip downloading and only normalize the files under in_dir.
    normalize_only: bool
    # Source directory for normalize-only mode; required in that mode.
    in_dir: Optional[Path]
    # When true, previously saved files for the target are deleted first.
    reset: bool
def ensure_dir(path: Path) -> None:
    """Create ``path`` as a directory, together with any missing parents.

    A directory that already exists is accepted silently.
    """
    path.mkdir(exist_ok=True, parents=True)
def normalize_image_bytes(img_bytes: bytes, max_side: int, jpeg_quality: int) -> bytes:
    """Re-encode an image as an RGB JPEG whose longer side is at most ``max_side``.

    Args:
        img_bytes: Encoded image data in any format Pillow can open.
        max_side: Maximum length, in pixels, of the longer side. Images that
            already fit are not resized (the image is never enlarged).
        jpeg_quality: Quality setting handed to the JPEG encoder.

    Returns:
        The bytes of the resulting progressive, optimized JPEG.
    """
    source = io.BytesIO(img_bytes)
    with Image.open(source) as opened:
        # Apply the EXIF orientation first so the size check below sees the
        # image the way it will actually be displayed.
        upright = ImageOps.exif_transpose(opened)
        rgb = upright.convert("RGB")

        width, height = rgb.size
        longest = float(max(width, height))
        ratio = max_side / longest
        if ratio < 1.0:
            # Shrink only; each side is clamped to at least one pixel.
            target_size = (
                max(1, int(round(width * ratio))),
                max(1, int(round(height * ratio))),
            )
            rgb = rgb.resize(target_size, Image.Resampling.LANCZOS)

        encoded = io.BytesIO()
        rgb.save(
            encoded,
            format="JPEG",
            quality=jpeg_quality,
            optimize=True,
            progressive=True,
        )
        return encoded.getvalue()
def download_url(url: str, timeout_s: float = 20.0) -> Optional[bytes]:
    """Fetch ``url`` over HTTP and return the response body.

    Args:
        url: Address to request.
        timeout_s: Timeout, in seconds, passed to ``requests.get``.

    Returns:
        The response content, or ``None`` if anything at all goes wrong
        (including a non-success HTTP status).
    """
    payload: Optional[bytes] = None
    try:
        response = requests.get(url, timeout=timeout_s)
        response.raise_for_status()
        payload = response.content
    except Exception:
        # Best-effort fetch: callers treat None as "skip this item".
        payload = None
    return payload
def save_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path``, creating the parent directory if needed.

    An existing file at ``path`` is overwritten.
    """
    parent = path.parent
    ensure_dir(parent)
    with path.open("wb") as handle:
        handle.write(data)
def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield the image files found among ``paths``.

    Directories are searched recursively and their matches are yielded in
    sorted path order; plain files are yielded as-is when they match.
    A file counts as an image when its suffix, compared case-insensitively,
    is one of the known image extensions. Entries that are neither a
    directory nor a matching file are ignored.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tif", ".tiff"}

    def _is_image(candidate: Path) -> bool:
        return candidate.is_file() and candidate.suffix.lower() in image_suffixes

    for entry in paths:
        if entry.is_dir():
            yield from filter(_is_image, sorted(entry.rglob("*")))
        elif _is_image(entry):
            yield entry
def normalize_existing(in_dir: Path, out_dir: Path, max_side: int, jpeg_quality: int) -> None:
    """Normalize every image under ``in_dir`` into a mirrored tree in ``out_dir``.

    Each image found by ``iter_images`` is re-encoded with
    ``normalize_image_bytes`` and written to the same relative location
    under ``out_dir`` with a ``.jpg`` suffix.

    Processing is best-effort: a file that cannot be read, decoded or
    written does not stop the batch. Instead of being dropped silently,
    such files are collected and listed after the run so the caller can
    see that the output is incomplete.

    Args:
        in_dir: Directory that is searched recursively for images.
        out_dir: Directory that receives the normalized JPEG files.
        max_side: Maximum length, in pixels, of the longer image side.
        jpeg_quality: Quality setting for the JPEG encoder.
    """
    files = list(iter_images([in_dir]))
    ensure_dir(out_dir)

    failed = []
    for p in tqdm(files, desc=f"Normalizing {in_dir.name}"):
        try:
            raw = p.read_bytes()
            norm = normalize_image_bytes(raw, max_side=max_side, jpeg_quality=jpeg_quality)
            out_path = (out_dir / p.relative_to(in_dir)).with_suffix(".jpg")
            save_bytes(out_path, norm)
        except Exception as exc:
            # Keep going, but remember the failure for the summary below.
            failed.append((p, exc))

    print(f"Normalized images written to: {out_dir}")
    if failed:
        print(f"Skipped {len(failed)} of {len(files)} file(s) that could not be normalized:")
        for p, exc in failed:
            print(f"  {p}: {exc}")
def _next_index(dir_path: Path, prefix: str) -> int:
    """Return the first unused numeric index for ``<prefix>_<index>.bin`` files.

    The result is one more than the largest integer found after the last
    underscore of any matching file name in ``dir_path``. It is ``0`` when
    the directory does not exist or holds no file with a usable index;
    the result is never negative.
    """
    if not dir_path.exists():
        return 0

    indices = []
    for candidate in dir_path.glob(f"{prefix}_*.bin"):
        tail = candidate.stem.rsplit("_", 1)[-1]
        try:
            indices.append(int(tail))
        except ValueError:
            # Not a numbered file; ignore it.
            pass

    # The -1 floor keeps the answer at 0 when nothing usable was found.
    highest = max([-1] + indices)
    return highest + 1
def download_photos_open_images(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream Open Images V7 rows and save up to ``n`` downloaded images.

    Raw downloads go to ``<out_dir>/photos/raw/openimages_<idx>.bin``; when
    ``normalize`` is true a normalized JPEG with the same index is written
    to ``<out_dir>/photos/normalized``. Numbering continues after the
    highest index already present in the raw directory.

    When normalization is requested it is attempted *before* anything is
    written. A payload that cannot be decoded (a successful HTTP response
    is not necessarily an image) is skipped and counted, rather than
    aborting the whole run and leaving a raw file without its normalized
    counterpart.

    Args:
        out_dir: Root output directory.
        n: Maximum number of images to save.
        seed: Accepted for interface compatibility; not used here.
        normalize: Whether to also write normalized JPEGs.
        max_side: Maximum length of the longer side of normalized images.
        jpeg_quality: JPEG quality for normalized images.
        reset: Delete existing ``openimages_*`` files in both output
            directories before downloading.
    """
    from datasets import load_dataset

    ds = load_dataset("bitmind/open-images-v7", split="train", streaming=True)
    raw_dir = out_dir / "photos" / "raw"
    norm_dir = out_dir / "photos" / "normalized"

    if reset:
        for p in raw_dir.glob("openimages_*.bin"):
            p.unlink()
        for p in norm_dir.glob("openimages_*.jpg"):
            p.unlink()

    start_idx = _next_index(raw_dir, "openimages")
    saved = 0
    skipped = 0
    for row in tqdm(ds, desc="Streaming Open Images V7"):
        if saved >= n:
            break

        # The URL column name is probed under several spellings.
        url = row.get("image_url") or row.get("url") or row.get("imageUrl") or row.get("ImageURL")
        if not url:
            continue
        img = download_url(url)
        if not img:
            continue

        norm = None
        if normalize:
            try:
                norm = normalize_image_bytes(img, max_side=max_side, jpeg_quality=jpeg_quality)
            except Exception:
                # Undecodable payload: skip it so raw and normalized stay paired.
                skipped += 1
                continue

        idx = start_idx + saved
        save_bytes(raw_dir / f"openimages_{idx:06d}.bin", img)
        if norm is not None:
            save_bytes(norm_dir / f"openimages_{idx:06d}.jpg", norm)
        saved += 1

    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")
    if skipped:
        print(f"Skipped {skipped} download(s) that could not be decoded as images")
def download_dance_x_dance(
    out_dir: Path,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    reset: bool,
) -> None:
    """Stream the X-Dance dataset and save up to ``n`` images.

    For each row, an ``image`` value that exposes ``convert`` is re-encoded
    as RGB PNG bytes; otherwise the row's ``image_url``/``url`` is
    downloaded. Raw bytes go to ``<out_dir>/dance/raw/xdance_<idx>.bin``;
    when ``normalize`` is true a normalized JPEG with the same index is
    written to ``<out_dir>/dance/normalized``. Numbering continues after
    the highest index already present in the raw directory.

    When normalization is requested it is attempted *before* anything is
    written. A payload that cannot be decoded is skipped and counted,
    rather than aborting the whole run and leaving a raw file without its
    normalized counterpart.

    Args:
        out_dir: Root output directory.
        n: Maximum number of images to save.
        seed: Accepted for interface compatibility; not used here.
        normalize: Whether to also write normalized JPEGs.
        max_side: Maximum length of the longer side of normalized images.
        jpeg_quality: JPEG quality for normalized images.
        reset: Delete existing ``xdance_*`` files in both output
            directories before downloading.
    """
    from datasets import load_dataset

    ds = load_dataset("MCG-NJU/X-Dance", split="train", streaming=True)
    raw_dir = out_dir / "dance" / "raw"
    norm_dir = out_dir / "dance" / "normalized"

    if reset:
        for p in raw_dir.glob("xdance_*.bin"):
            p.unlink()
        for p in norm_dir.glob("xdance_*.jpg"):
            p.unlink()

    start_idx = _next_index(raw_dir, "xdance")
    saved = 0
    skipped = 0
    for row in tqdm(ds, desc="Streaming X-Dance"):
        if saved >= n:
            break

        img_obj = row.get("image")
        img_bytes: Optional[bytes] = None
        if img_obj is not None and hasattr(img_obj, "convert"):
            # Already-decoded image object: serialize it losslessly.
            out = io.BytesIO()
            img_obj.convert("RGB").save(out, format="PNG")
            img_bytes = out.getvalue()
        else:
            url = row.get("image_url") or row.get("url")
            if url:
                img_bytes = download_url(url)
        if not img_bytes:
            continue

        norm = None
        if normalize:
            try:
                norm = normalize_image_bytes(img_bytes, max_side=max_side, jpeg_quality=jpeg_quality)
            except Exception:
                # Undecodable payload: skip it so raw and normalized stay paired.
                skipped += 1
                continue

        idx = start_idx + saved
        save_bytes(raw_dir / f"xdance_{idx:06d}.bin", img_bytes)
        if norm is not None:
            save_bytes(norm_dir / f"xdance_{idx:06d}.jpg", norm)
        saved += 1

    print(f"Saved {saved} images to: {raw_dir}")
    if normalize:
        print(f"Normalized images written to: {norm_dir}")
    if skipped:
        print(f"Skipped {skipped} image(s) that could not be decoded")
def run(cfg: Config) -> None:
    """Execute the action described by ``cfg``.

    In normalize-only mode the images under ``cfg.in_dir`` are normalized
    into ``cfg.out_dir`` and nothing is downloaded. Otherwise the output
    directory is created and the downloader for ``cfg.target`` is run:
    "photos" selects Open Images, any other value selects X-Dance.

    Raises:
        SystemExit: If normalize-only mode is requested without ``in_dir``.
    """
    if cfg.normalize_only:
        if not cfg.in_dir:
            raise SystemExit("--in-dir is required with --normalize-only")
        normalize_existing(cfg.in_dir, cfg.out_dir, cfg.max_side, cfg.jpeg_quality)
        return

    ensure_dir(cfg.out_dir)

    # Both downloaders share one signature, so only the callee differs.
    if cfg.target == "photos":
        downloader = download_photos_open_images
    else:
        downloader = download_dance_x_dance

    downloader(
        out_dir=cfg.out_dir,
        n=cfg.n,
        seed=cfg.seed,
        normalize=cfg.normalize,
        max_side=cfg.max_side,
        jpeg_quality=cfg.jpeg_quality,
        reset=cfg.reset,
    )
# The module invokes ``cli()`` with no arguments, which only works when ``cli``
# is a click command that fills its parameters from the command line.
# NOTE(review): only --in-dir and --normalize-only are named elsewhere in this
# file; the remaining option spellings and ALL default values below are
# reconstructed placeholders - confirm them against the intended CLI.
@click.command()
@click.option(
    "--out-dir",
    type=click.Path(file_okay=False),
    default="data",
    show_default=True,
    help="Root directory for downloaded and normalized images.",
)
@click.option(
    "--target",
    type=click.Choice(["photos", "dance"]),
    default="photos",
    show_default=True,
    help="Dataset to download.",
)
@click.option("--n", type=int, default=100, show_default=True, help="Maximum number of images to save.")
@click.option("--seed", type=int, default=0, show_default=True, help="Seed forwarded to the downloader.")
@click.option(
    "--normalize/--no-normalize",
    default=True,
    show_default=True,
    help="Also write a normalized JPEG for each download.",
)
@click.option("--max-side", type=int, default=1024, show_default=True, help="Longest side of normalized images, in pixels.")
@click.option("--jpeg-quality", type=int, default=90, show_default=True, help="JPEG quality for normalized images.")
@click.option("--normalize-only", is_flag=True, help="Only normalize the images found under --in-dir.")
@click.option(
    "--in-dir",
    type=click.Path(exists=True, file_okay=False),
    default=None,
    help="Input directory for --normalize-only.",
)
@click.option("--reset", is_flag=True, help="Delete previously saved files for the target first.")
def cli(
    out_dir: Path,
    target: str,
    n: int,
    seed: int,
    normalize: bool,
    max_side: int,
    jpeg_quality: int,
    normalize_only: bool,
    in_dir: Optional[Path],
    reset: bool,
) -> None:
    """Download a sample of images and/or normalize images to bounded JPEGs."""
    cfg = Config(
        # click hands paths over as strings; Config expects Path objects.
        out_dir=Path(out_dir),
        target=target,
        n=n,
        seed=seed,
        normalize=normalize,
        max_side=max_side,
        jpeg_quality=jpeg_quality,
        normalize_only=normalize_only,
        in_dir=Path(in_dir) if in_dir is not None else None,
        reset=reset,
    )
    run(cfg)
# Script entry point: call cli() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()