File size: 2,430 Bytes
4e75170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
scripts/prepare_cocoxgen.py

Prepares the COCO-XGen dataset for fingerprint-engine training.
Copies a seeded random sample of COCO real images and XL-generated images into
<output>/{train,val}/{real,fake}/ folders for generator-attribution training.

Kaggle usage:
  !python scripts/prepare_cocoxgen.py \
    --real_source  /kaggle/input/coco-2017-dataset/coco2017/val2017 \
    --fake_source  /kaggle/input/coco-xgen-synthetic \
    --output       /kaggle/working/processed/fingerprint \
    --max          15000
"""
from __future__ import annotations

import argparse
import logging
import random
import shutil
from pathlib import Path

# Timestamped, INFO-level progress logging for the whole script.
_LOG_FORMAT = "%(asctime)s  %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
log = logging.getLogger(__name__)

# Recognised image suffixes, compared against a lower-cased Path.suffix.
IMG_EXTS = {".png", ".jpeg", ".jpg"}


def copy_subset(src_dir: Path, dst_dir: Path, max_n: int, prefix: str, rng: random.Random) -> int:
    """Copy a random sample of up to *max_n* images from *src_dir* into *dst_dir*.

    Image files are found recursively under *src_dir*. A file qualifies when
    its lower-cased suffix is in ``IMG_EXTS``. The candidate list is sorted
    before shuffling, so the sample depends only on *rng*'s state and not on
    filesystem enumeration order.

    Each copy is named ``{prefix}_{relative path joined with '_'}``. This stops
    same-named files in different subdirectories from colliding. For a flat
    source the name is simply ``{prefix}_{filename}``. Destination files that
    already exist are left untouched, which makes re-runs cheap.

    Args:
        src_dir: Directory to search (recursively) for images.
        dst_dir: Destination directory; created if missing.
        max_n: Maximum number of images to copy; values <= 0 copy nothing.
        prefix: String prepended to every destination filename.
        rng: Random generator used for the shuffle; its state is advanced.

    Returns:
        Number of sampled images. Images whose destination file already
        existed are counted too.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)
    # sorted(): rglob order is filesystem-dependent, which would make the
    # seeded shuffle non-reproducible across machines. is_file() drops
    # directories that merely happen to end in an image suffix.
    imgs = sorted(
        p for p in src_dir.rglob("*")
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    )
    rng.shuffle(imgs)
    # Clamp so a negative budget does not slice as "all but the last N".
    imgs = imgs[:max(max_n, 0)]
    for img in imgs:
        flat_name = "_".join(img.relative_to(src_dir).parts)
        dst = dst_dir / f"{prefix}_{flat_name}"
        if not dst.exists():
            shutil.copy2(img, dst)
    return len(imgs)


def _list_images(src_dir: Path) -> list[Path]:
    """Return every image file under *src_dir* (recursive), in sorted order.

    Sorting makes the later seeded shuffle independent of the filesystem's
    enumeration order, so ``--seed`` gives the same sample on any machine.
    """
    return sorted(
        p for p in src_dir.rglob("*")
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    )


def _copy_images(imgs: list[Path], src_dir: Path, dst_dir: Path, prefix: str) -> int:
    """Copy *imgs* into *dst_dir* and return ``len(imgs)``.

    Each file is named ``{prefix}_{path relative to src_dir joined by '_'}``,
    so same-named files in different subdirectories do not collide. For a
    flat source the name is just ``{prefix}_{filename}``. Existing destination
    files are skipped so re-runs are cheap.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)
    for img in imgs:
        flat_name = "_".join(img.relative_to(src_dir).parts)
        dst = dst_dir / f"{prefix}_{flat_name}"
        if not dst.exists():
            shutil.copy2(img, dst)
    return len(imgs)


def main(args: argparse.Namespace) -> None:
    """Populate ``{output}/{train,val}/{real,fake}`` from the given sources.

    Sources and destinations:
    ``--real_source`` goes to ``real/`` with filename prefix ``coco``.
    ``--fake_source`` goes to ``fake/`` with filename prefix ``xgen``.

    Each source's images are listed once and shuffled once with an RNG
    seeded from ``--seed``. The shuffled list is cut into disjoint
    train/val slices covering ``min(--max, available images)``, split
    90 % / 10 % using integer arithmetic. Taking both splits from a single
    shuffled list guarantees no image lands in both train and val, even when
    the source holds fewer images than ``--max``.

    A source that is not given is skipped. A source that is given but does
    not exist is skipped with a warning.
    """
    rng = random.Random(args.seed)
    output = Path(args.output)
    # (destination label, CLI value, filename prefix, log-line suffix)
    sources = (
        ("real", args.real_source, "coco", ""),
        ("fake", args.fake_source, "xgen", " (generator: stable_diffusion/xl)"),
    )

    # Partition every available source up front, then copy split by split.
    # This keeps the log order: real/train, fake/train, real/val, fake/val.
    partitions = {}
    for label, source, _prefix, _note in sources:
        if not source:
            continue
        src_dir = Path(source)
        if not src_dir.exists():
            log.warning(f"  {label} source not found, skipping: {src_dir}")
            continue
        imgs = _list_images(src_dir)
        rng.shuffle(imgs)
        total = min(max(args.max, 0), len(imgs))
        n_train = total * 9 // 10  # integer math avoids float truncation surprises
        partitions[label] = (src_dir, {"train": imgs[:n_train], "val": imgs[n_train:total]})

    for split in ("train", "val"):
        for label, _source, prefix, note in sources:
            if label not in partitions:
                continue
            src_dir, subsets = partitions[label]
            n = _copy_images(subsets[split], src_dir, output / split / label, prefix)
            log.info(f"  {label}/{split}: {n} images{note}")

    log.info("COCO-XGen preparation complete.")


def parse_args() -> argparse.Namespace:
    """Read the script's command-line options.

    Options:
        ``--real_source``: real-image directory (optional).
        ``--fake_source``: generated-image directory (optional).
        ``--output``: destination root.
        ``--max``: integer image budget.
        ``--seed``: integer RNG seed.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--real_source", dict(default=None)),
        ("--fake_source", dict(default=None)),
        ("--output", dict(default="/kaggle/working/processed/fingerprint")),
        ("--max", dict(type=int, default=15000)),
        ("--seed", dict(type=int, default=42)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI, then run the dataset preparation.
    cli_args = parse_args()
    main(cli_args)