File size: 4,241 Bytes
7c72c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efbbd1d
 
 
 
 
 
7c72c12
 
 
efbbd1d
 
7c72c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efbbd1d
7c72c12
 
 
 
efbbd1d
 
 
7c72c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import json
import shutil
import sys
from pathlib import Path

from gradio_client import Client, handle_file


# File extensions (lower-case, with leading dot) treated as source images.
VALID_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def iter_images(root: Path):
    """Yield the image files found anywhere beneath *root*.

    Every entry under *root* is collected and sorted by path first, so results
    come out in a stable order. An entry is yielded only if it is a regular
    file and its extension, compared case-insensitively, is in
    ``VALID_SUFFIXES``.
    """
    ordered_entries = sorted(root.rglob("*"))
    yield from (
        entry
        for entry in ordered_entries
        if entry.is_file() and entry.suffix.lower() in VALID_SUFFIXES
    )


def _build_parser():
    """Build the command-line parser for the batch client."""
    parser = argparse.ArgumentParser(description="Batch client for the Catalog Creator Space.")
    parser.add_argument("--space", required=True, help="Space id, for example username/CatalogCreator")
    parser.add_argument("--input-dir", required=True, help="Directory of source images")
    parser.add_argument("--output-dir", required=True, help="Directory for processed files")
    parser.add_argument("--mode", choices=["extract_outfit", "cutout"], default="extract_outfit")
    parser.add_argument(
        "--extract-prompt",
        default="Extract the clothing and create a flat mockup.",
        help="Prompt used only for extract_outfit mode",
    )
    parser.add_argument("--output-size", type=int, default=1400)
    parser.add_argument("--padding-percent", type=int, default=8)
    parser.add_argument("--alpha-threshold", type=int, default=8)
    parser.add_argument("--num-inference-steps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip images whose four output files already exist",
    )
    return parser


def _output_targets(target_dir: Path, stem: str) -> dict:
    """Map each manifest key to the output path for the source image named *stem*."""
    return {
        "transparent_png": target_dir / f"{stem}_transparent.png",
        "square_png": target_dir / f"{stem}_catalog.png",
        "preview_png": target_dir / f"{stem}_preview.png",
        "metadata_json": target_dir / f"{stem}.json",
    }


def _append_manifest(manifest_path: Path, record: dict) -> None:
    """Append *record* as one JSON line to the manifest file."""
    with manifest_path.open("a", encoding="utf-8") as manifest:
        manifest.write(json.dumps(record) + "\n")


def _process_image(client, args, image_path: Path, targets: dict) -> dict:
    """Send one image to /catalog_prep, write its outputs, and return its manifest record.

    Args:
        client: Connected gradio ``Client`` for the Space.
        args: Parsed command-line namespace supplying the endpoint parameters.
        image_path: Source image to upload.
        targets: Output paths as returned by ``_output_targets``.

    Raises:
        Whatever ``client.predict``, ``shutil.copy2`` or ``Path.write_text``
        raise; the caller records the failure.
    """
    # The endpoint returns six values; the first two are not used by this client.
    # The three file outputs are handed to shutil.copy2, so they must be local paths.
    _, _, transparent_file, square_file, preview_file, metadata = client.predict(
        image=handle_file(str(image_path)),
        mode=args.mode,
        output_size=args.output_size,
        padding_percent=args.padding_percent,
        alpha_threshold=args.alpha_threshold,
        crop_to_subject=True,
        extract_prompt=args.extract_prompt,
        num_inference_steps=args.num_inference_steps,
        seed=args.seed,
        api_name="/catalog_prep",
    )

    shutil.copy2(transparent_file, targets["transparent_png"])
    shutil.copy2(square_file, targets["square_png"])
    shutil.copy2(preview_file, targets["preview_png"])
    targets["metadata_json"].write_text(json.dumps(metadata, indent=2), encoding="utf-8")

    record = {"source": str(image_path)}
    record.update({key: str(path) for key, path in targets.items()})
    record["status"] = "ok"
    return record


def main():
    """Run every image under --input-dir through the Space's /catalog_prep endpoint.

    Outputs mirror the input directory layout under --output-dir. Each source
    image produces ``<stem>_transparent.png``, ``<stem>_catalog.png``,
    ``<stem>_preview.png`` and ``<stem>.json``. One JSON line per attempted
    image (status ``ok`` or ``error``) is appended to
    ``<output-dir>/manifest.jsonl``. Failures are also reported on stderr.

    Returns:
        0 if no image failed, 1 if at least one image failed.
    """
    parser = _build_parser()
    args = parser.parse_args()

    input_dir = Path(args.input_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    if not input_dir.is_dir():
        # rglob over a missing directory yields nothing, which would otherwise
        # make the whole run a silent no-op.
        parser.error(f"--input-dir is not a directory: {input_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # If the output tree sits inside the input tree, files written by earlier
    # runs would be picked up as new sources (and nest deeper every run).
    # NOTE: when --output-dir equals --input-dir nothing can be filtered this
    # way; outputs then sit beside the sources.
    output_nested_in_input = input_dir in output_dir.parents

    client = Client(args.space)
    manifest_path = output_dir / "manifest.jsonl"
    failures = 0

    for image_path in iter_images(input_dir):
        if output_nested_in_input and output_dir in image_path.parents:
            continue

        target_dir = output_dir / image_path.parent.relative_to(input_dir)
        targets = _output_targets(target_dir, image_path.stem)

        # Only skip when every output exists, so partially written results
        # (e.g. a missing preview) are regenerated.
        if args.skip_existing and all(path.exists() for path in targets.values()):
            continue

        target_dir.mkdir(parents=True, exist_ok=True)

        try:
            record = _process_image(client, args, image_path, targets)
        except Exception as exc:  # one bad image must not abort the batch
            failures += 1
            print(f"error: {image_path}: {type(exc).__name__}: {exc}", file=sys.stderr)
            record = {
                "source": str(image_path),
                "status": "error",
                "error": str(exc),
                # str(exc) can be empty; the type name keeps the record useful.
                "error_type": type(exc).__name__,
            }

        _append_manifest(manifest_path, record)

    return 1 if failures else 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # exits with status 0.
    raise SystemExit(main())