# Source: deepdetection/scripts/prepare_cocoxgen.py
# Committed by akagtag — "Initial commit" (4e75170)
"""
scripts/prepare_cocoxgen.py
Prepares COCO-XGen dataset for fingerprint engine training.
Maps COCO real images against XL-generated equivalents for generator attribution training.
Kaggle usage:
!python scripts/prepare_cocoxgen.py \
--real_source /kaggle/input/coco-2017-dataset/coco2017/val2017 \
--fake_source /kaggle/input/coco-xgen-synthetic \
--output /kaggle/working/processed/fingerprint \
--max 15000
"""
from __future__ import annotations
import argparse
import logging
import random
import shutil
from pathlib import Path
# Configure root logging at INFO with a timestamp prefix so progress messages are visible.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
# Module logger used for the per-split progress messages.
log = logging.getLogger(__name__)
# Lower-case file suffixes (including the dot) treated as images; suffixes are lower-cased before comparison.
IMG_EXTS = {".jpg", ".jpeg", ".png"}
def copy_subset(
    src_dir: Path,
    dst_dir: Path,
    max_n: int,
    prefix: str,
    rng: random.Random,
    exts: set[str] | None = None,
) -> int:
    """Copy a reproducible random sample of images from *src_dir* into *dst_dir*.

    Images are regular files found recursively under *src_dir* whose
    lower-cased suffix is in *exts*. They are sorted first, so the sample
    depends only on *rng*'s state and not on filesystem listing order. They are
    then shuffled with *rng* and truncated to *max_n*.

    Each selected file is copied as ``{prefix}_{name}``. If two selected files
    share a name (e.g. they come from different sub-directories), the later one
    is named after its path relative to *src_dir* instead of being silently
    dropped. Destination files that already exist are left untouched, so a
    re-run is cheap.

    Args:
        src_dir: Directory searched recursively for images.
        dst_dir: Destination directory; created if missing.
        max_n: Maximum number of images to copy; values <= 0 copy nothing.
        prefix: String prepended to every destination file name.
        rng: Random generator used for the shuffle.
        exts: Accepted lower-case suffixes (with the dot). Defaults to ``IMG_EXTS``.

    Returns:
        Number of images selected. Each one maps to a distinct file in *dst_dir*.
    """
    exts = IMG_EXTS if exts is None else exts
    dst_dir.mkdir(parents=True, exist_ok=True)
    # Sort before shuffling: rglob order is filesystem-dependent, so shuffling it
    # directly would make the seeded sample differ between machines. is_file()
    # excludes directories whose names happen to end in an image suffix.
    imgs = sorted(
        p for p in src_dir.rglob("*") if p.is_file() and p.suffix.lower() in exts
    )
    rng.shuffle(imgs)
    imgs = imgs[: max(max_n, 0)]  # a negative slice bound would keep almost everything

    used: set[str] = set()
    for img in imgs:
        name = f"{prefix}_{img.name}"
        if name in used:
            # Same file name in another sub-directory: disambiguate with the relative path.
            name = f"{prefix}_{'_'.join(img.relative_to(src_dir).parts)}"
        counter = 1
        while name in used:
            # Last resort for pathological clashes (e.g. "a_b.jpg" next to "a/b.jpg").
            name = f"{prefix}_{counter}_{img.name}"
            counter += 1
        used.add(name)
        dst = dst_dir / name
        if not dst.exists():
            shutil.copy2(img, dst)
    return len(imgs)
def _list_images(src_dir: Path, rng: random.Random) -> list[Path]:
    """Return every image file under *src_dir* (recursively) in a seeded shuffled order.

    Paths are sorted before shuffling, so the order depends only on *rng*
    and not on the filesystem's listing order. Directories are excluded even
    when their names end in an image suffix.
    """
    imgs = sorted(
        p for p in src_dir.rglob("*") if p.is_file() and p.suffix.lower() in IMG_EXTS
    )
    rng.shuffle(imgs)
    return imgs


def _copy_images(imgs: list[Path], src_dir: Path, dst_dir: Path, prefix: str) -> int:
    """Copy *imgs* into *dst_dir* as ``{prefix}_{name}`` and return how many were handled.

    A name clash between files from different sub-directories is resolved with
    the path relative to *src_dir*. Existing destination files are kept as-is.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)
    used: set[str] = set()
    for img in imgs:
        name = f"{prefix}_{img.name}"
        if name in used:
            name = f"{prefix}_{'_'.join(img.relative_to(src_dir).parts)}"
        counter = 1
        while name in used:
            name = f"{prefix}_{counter}_{img.name}"
            counter += 1
        used.add(name)
        dst = dst_dir / name
        if not dst.exists():
            shutil.copy2(img, dst)
    return len(imgs)


def main(args: argparse.Namespace) -> None:
    """Build disjoint train/val ``real``/``fake`` image folders under ``args.output``.

    Each available source is listed and shuffled exactly once. Its first
    ``int(args.max * 0.9)`` images go to ``train`` and the following
    ``int(args.max * 0.1)`` go to ``val``, so no image appears in both splits.
    Real images land in ``<output>/<split>/real`` with prefix ``coco``; fake
    images land in ``<output>/<split>/fake`` with prefix ``xgen``.

    Args:
        args: Parsed CLI namespace providing ``real_source``, ``fake_source``,
            ``output``, ``max`` and ``seed``.
    """
    rng = random.Random(args.seed)
    n_train = max(int(args.max * 0.9), 0)
    n_val = max(int(args.max * 0.1), 0)
    # (start, end) slice into each source's shuffled list. Previously every split
    # re-shuffled the full pool, so val images leaked into train.
    splits = {"train": (0, n_train), "val": (n_train, n_train + n_val)}
    output = Path(args.output)

    # label -> (source dir, file prefix, log suffix, shuffled image list)
    sources = {}
    for label, source, prefix, note in (
        ("real", args.real_source, "coco", ""),
        ("fake", args.fake_source, "xgen", " (generator: stable_diffusion/xl)"),
    ):
        if not source:
            continue
        src_dir = Path(source)
        if not src_dir.exists():
            log.warning(f"{label} source not found, skipping: {src_dir}")
            continue
        sources[label] = (src_dir, prefix, note, _list_images(src_dir, rng))

    for split, (start, end) in splits.items():
        for label, (src_dir, prefix, note, imgs) in sources.items():
            n = _copy_images(imgs[start:end], src_dir, output / split / label, prefix)
            log.info(f" {label}/{split}: {n} images{note}")
    log.info("COCO-XGen preparation complete.")
def _non_negative_int(text: str) -> int:
    """argparse ``type=`` converter that rejects negative integers."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the COCO-XGen preparation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``real_source``, ``fake_source``, ``output``, ``max`` and ``seed``.
    """
    p = argparse.ArgumentParser(
        description="Prepare COCO-XGen real/fake images for fingerprint engine training."
    )
    p.add_argument(
        "--real_source",
        default=None,
        help="Directory of real COCO images (searched recursively); skipped if omitted.",
    )
    p.add_argument(
        "--fake_source",
        default=None,
        help="Directory of generated images (searched recursively); skipped if omitted.",
    )
    p.add_argument(
        "--output",
        default="/kaggle/working/processed/fingerprint",
        help="Output root; <split>/real and <split>/fake are created under it.",
    )
    p.add_argument(
        "--max",
        type=_non_negative_int,
        default=15000,
        help="Image budget per class: 90%% go to train, 10%% to val.",
    )
    p.add_argument("--seed", type=int, default=42, help="Seed for the random sampling.")
    return p.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the preparation.
    cli_args = parse_args()
    main(cli_args)