""" scripts/prepare_cocoxgen.py Prepares COCO-XGen dataset for fingerprint engine training. Maps COCO real images against XL-generated equivalents for generator attribution training. Kaggle usage: !python scripts/prepare_cocoxgen.py \ --real_source /kaggle/input/coco-2017-dataset/coco2017/val2017 \ --fake_source /kaggle/input/coco-xgen-synthetic \ --output /kaggle/working/processed/fingerprint \ --max 15000 """ from __future__ import annotations import argparse import logging import random import shutil from pathlib import Path logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") log = logging.getLogger(__name__) IMG_EXTS = {".jpg", ".jpeg", ".png"} def copy_subset(src_dir: Path, dst_dir: Path, max_n: int, prefix: str, rng: random.Random) -> int: dst_dir.mkdir(parents=True, exist_ok=True) imgs = [p for p in src_dir.rglob("*") if p.suffix.lower() in IMG_EXTS] rng.shuffle(imgs) imgs = imgs[:max_n] for img in imgs: dst = dst_dir / f"{prefix}_{img.name}" if not dst.exists(): shutil.copy2(img, dst) return len(imgs) def main(args: argparse.Namespace) -> None: rng = random.Random(args.seed) splits = {"train": int(args.max * 0.9), "val": int(args.max * 0.1)} for split, max_n in splits.items(): if args.real_source and Path(args.real_source).exists(): n = copy_subset( Path(args.real_source), Path(args.output) / split / "real", max_n, "coco", rng ) log.info(f" real/{split}: {n} images") if args.fake_source and Path(args.fake_source).exists(): n = copy_subset( Path(args.fake_source), Path(args.output) / split / "fake", max_n, "xgen", rng ) log.info(f" fake/{split}: {n} images (generator: stable_diffusion/xl)") log.info("COCO-XGen preparation complete.") def parse_args(): p = argparse.ArgumentParser() p.add_argument("--real_source", default=None) p.add_argument("--fake_source", default=None) p.add_argument("--output", default="/kaggle/working/processed/fingerprint") p.add_argument("--max", type=int, default=15000) p.add_argument("--seed", type=int, default=42) return p.parse_args() if __name__ == "__main__": main(parse_args())