Spaces:
Running
Running
"""
scripts/prepare_cocoxgen.py

Prepares the COCO-XGen dataset for fingerprint engine training.
Copies real COCO images and their XL-generated counterparts into
<output>/<split>/real and <output>/<split>/fake for generator attribution
training.

Kaggle usage:
    !python scripts/prepare_cocoxgen.py \
        --real_source /kaggle/input/coco-2017-dataset/coco2017/val2017 \
        --fake_source /kaggle/input/coco-xgen-synthetic \
        --output /kaggle/working/processed/fingerprint \
        --max 15000
"""
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import random | |
| import shutil | |
| from pathlib import Path | |
# Timestamped INFO-level records on the root logger; every progress line the
# script prints goes through ``log``.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Suffixes (lower-case, leading dot included) that mark a file as an image.
IMG_EXTS = {".png", ".jpeg", ".jpg"}
def _list_images(src_dir: Path) -> list[Path]:
    """Return every image file below *src_dir* in a stable, sorted order.

    ``Path.rglob`` yields entries in an OS-dependent order, so the listing is
    sorted; otherwise a seeded shuffle would pick different files on different
    machines. Directories whose name merely ends in an image suffix are
    dropped because ``shutil.copy2`` cannot copy them.
    """
    return sorted(
        p for p in src_dir.rglob("*")
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    )


def _split_sizes(available: int, want_train: int, want_val: int) -> tuple[int, int]:
    """Return ``(n_train, n_val)`` that fit into *available* images.

    When the source holds fewer images than requested, both splits are scaled
    down in the requested proportion so the validation split never has to
    reuse training images.
    """
    wanted = want_train + want_val
    if available >= wanted:
        return want_train, want_val
    n_val = available * want_val // wanted
    return available - n_val, n_val


def _copy_images(imgs: list[Path], src_dir: Path, dst_dir: Path, prefix: str) -> int:
    """Copy *imgs* into *dst_dir* as ``<prefix>_<name>`` and return ``len(imgs)``.

    Destinations that already exist are left alone so a re-run is cheap.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)
    claimed: set[str] = set()
    for img in imgs:
        name = f"{prefix}_{img.name}"
        if name in claimed:
            # Same basename in another sub-folder: fall back to the flattened
            # relative path so the second file is not silently skipped.
            flat = "_".join(img.relative_to(src_dir).parts)
            name = f"{prefix}_{flat}"
        claimed.add(name)
        dst = dst_dir / name
        if not dst.exists():
            shutil.copy2(img, dst)
    return len(imgs)


def copy_subset(src_dir: Path, dst_dir: Path, max_n: int, prefix: str, rng: random.Random) -> int:
    """Copy a random subset of at most *max_n* images from *src_dir* to *dst_dir*.

    Args:
        src_dir: Directory searched recursively for files whose suffix is in
            ``IMG_EXTS`` (case-insensitive).
        dst_dir: Destination directory; created if missing.
        max_n: Upper bound on the number of images; values below zero are
            treated as zero.
        prefix: Prepended to each copied file name as ``<prefix>_<name>``.
        rng: Random generator used to shuffle the candidates.

    Returns:
        The number of images selected.
    """
    imgs = _list_images(src_dir)
    rng.shuffle(imgs)
    # A negative max_n would otherwise slice from the end of the list.
    return _copy_images(imgs[: max(0, max_n)], src_dir, dst_dir, prefix)


def main(args: argparse.Namespace) -> None:
    """Build disjoint train/val folders of real and fake images.

    For each available source the image list is shuffled once with a generator
    seeded from ``args.seed`` and cut into a train part (90 % of ``args.max``)
    and a val part (10 %), so no image lands in both splits. Files are written
    to ``<args.output>/<split>/real`` and ``<args.output>/<split>/fake``.

    Args:
        args: Namespace with ``real_source``, ``fake_source``, ``output``,
            ``max`` and ``seed`` attributes.
    """
    rng = random.Random(args.seed)
    want_train = max(0, int(args.max * 0.9))
    want_val = max(0, int(args.max * 0.1))
    out_root = Path(args.output)

    # (label, source directory, file-name prefix, suffix of the log line)
    sources = (
        ("real", args.real_source, "coco", ""),
        ("fake", args.fake_source, "xgen", " (generator: stable_diffusion/xl)"),
    )

    plans = {}
    for label, source, prefix, note in sources:
        if not source:
            continue
        src_dir = Path(source)
        if not src_dir.exists():
            log.warning(f"{label} source not found, skipping: {src_dir}")
            continue
        imgs = _list_images(src_dir)
        rng.shuffle(imgs)
        n_train, n_val = _split_sizes(len(imgs), want_train, want_val)
        parts = {"train": imgs[:n_train], "val": imgs[n_train:n_train + n_val]}
        plans[label] = (src_dir, prefix, note, parts)

    for split in ("train", "val"):
        for label, (src_dir, prefix, note, parts) in plans.items():
            n = _copy_images(parts[split], src_dir, out_root / split / label, prefix)
            log.info(f" {label}/{split}: {n} images{note}")
    log.info("COCO-XGen preparation complete.")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the preparation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        Namespace with ``real_source``, ``fake_source``, ``output``, ``max``
        and ``seed``.

    Raises:
        SystemExit: On malformed arguments or a negative ``--max``
            (raised by argparse after printing the usage message).
    """
    p = argparse.ArgumentParser(
        description="Prepare real/fake image folders for fingerprint training."
    )
    p.add_argument("--real_source", default=None,
                   help="directory of real images, searched recursively")
    p.add_argument("--fake_source", default=None,
                   help="directory of generated images, searched recursively")
    p.add_argument("--output", default="/kaggle/working/processed/fingerprint",
                   help="root directory for the <split>/real and <split>/fake folders")
    p.add_argument("--max", type=int, default=15000,
                   help="image budget per class: 90%% goes to train, 10%% to val")
    p.add_argument("--seed", type=int, default=42,
                   help="seed for the random image selection")
    ns = p.parse_args(argv)
    # A negative budget has no sensible meaning; fail early with a usage error.
    if ns.max < 0:
        p.error("--max must be non-negative")
    return ns
# Script entry point: parse the command line, then run the preparation.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)