File size: 5,831 Bytes
3b9f466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b822ad
 
 
 
 
 
3b9f466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b822ad
3b9f466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b822ad
 
 
3b9f466
 
 
 
6b822ad
 
 
3b9f466
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image

from pipeline import DEFAULT_NEGATIVE_PROMPT, DEFAULT_PROMPT, load_pipeline


# Prompt text for the "soft" preset.
_SOFT_PROMPT = (
    "a stuffed animal toy transformed into a cute 32-bit pixel art character sprite, "
    "modern Korean mobile RPG style, soft rounded chibi proportions, "
    "large rounded head, plush body, tiny limbs, warm pastel colors, "
    "gentle 2-tone shading, thin dark brown outline, shiny dot eyes, "
    "tiny nose, soft smile, pink blush cheeks, white background, full body, front-facing"
)

# Prompt text for the "sprite" preset.
_SPRITE_PROMPT = (
    "front-facing full body stuffed animal character sprite, "
    "high quality 32-bit pixel art, cozy mobile RPG game asset, "
    "clean pixel edges, readable silhouette, compact chibi body, "
    "soft fur shading, warm colors, subtle highlights, thin brown outline, "
    "cute face, dot eyes with highlights, pink blush, white background"
)

# Prompt text for the "reference" preset.
_REFERENCE_PROMPT = (
    "cute plush toy, high resolution pixel art character, large visible pixels, "
    "rounded chibi body, front-facing full body, warm soft shading, "
    "brown pixel outline, glossy dot eyes, tiny nose, small smile, "
    "pink cheek blush, clean white background"
)

# Maps a preset name (as given to --prompt-presets) to its prompt text.
PROMPT_PRESETS = {
    "default": DEFAULT_PROMPT,
    "soft": _SOFT_PROMPT,
    "sprite": _SPRITE_PROMPT,
    "reference": _REFERENCE_PROMPT,
}


def parse_csv_floats(value: str) -> list[float]:
    """Convert a comma-separated string into a list of floats.

    Each piece is stripped of surrounding whitespace, and pieces that are
    empty after stripping are skipped. A piece that ``float`` cannot parse
    raises ``ValueError``.
    """
    numbers = []
    for piece in value.split(","):
        token = piece.strip()
        if token:
            numbers.append(float(token))
    return numbers


def collect_images(input_path: str) -> list[Path]:
    """Return the image files designated by *input_path*.

    * If *input_path* is an existing file, it is returned as the only entry
      (its extension is not checked).
    * If it is a directory, the regular files directly inside it whose
      suffix is .jpg, .jpeg, .png or .webp (case-insensitive) are returned
      in sorted order. Sub-directories are not searched.
    * Otherwise (the path does not exist or is neither file nor directory)
      an empty list is returned, so the caller can report "no images found"
      instead of crashing inside ``iterdir()``.
    """
    path = Path(input_path)
    if path.is_file():
        return [path]
    if not path.is_dir():
        return []

    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    # is_file() filters out directories that merely carry an image-like suffix.
    return sorted(
        p for p in path.iterdir() if p.is_file() and p.suffix.lower() in extensions
    )


def save_grid(items: list[tuple[str, Image.Image]], output_path: Path, columns: int = 4):
    """Render titled images into a grid figure and save it to *output_path*.

    Args:
        items: ``(title, image)`` pairs, placed left-to-right, top-to-bottom.
            Nothing is written when the list is empty.
        output_path: Destination passed to ``Figure.savefig``.
        columns: Number of grid columns; the row count is derived from
            ``len(items)``.

    Raises:
        ValueError: If *columns* is smaller than 1.
    """
    if not items:
        return
    if columns < 1:
        raise ValueError(f"columns must be >= 1, got {columns}")

    rows = (len(items) + columns - 1) // columns
    # squeeze=False always yields a 2-D array of axes, so flattening works for
    # every rows/columns combination (including a single column or one cell).
    fig, axes = plt.subplots(
        rows, columns, figsize=(columns * 4, rows * 4), squeeze=False
    )
    try:
        cells = axes.ravel()

        for ax, (title, image) in zip(cells, items):
            ax.imshow(image)
            ax.set_title(title, fontsize=8)
            ax.axis("off")

        # Hide the unused cells of the last row.
        for ax in cells[len(items):]:
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(output_path, dpi=130)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def _parse_sweep(option: str, raw: str) -> list[float]:
    """Parse one comma-separated sweep option, exiting cleanly on bad input."""
    try:
        values = parse_csv_floats(raw)
    except ValueError as err:
        raise SystemExit(f"Invalid value for {option}: {raw!r} ({err})") from None
    if not values:
        raise SystemExit(f"No values given for {option}")
    return values


def main():
    """Run a parameter sweep over the input images and save comparison grids.

    Parses the command line, collects the input images, and for every image
    calls the loaded pipeline once per combination of prompt preset, strength,
    ControlNet scale and guidance scale. Each image gets its own sub-folder of
    the output folder containing ``<stem>_grid.png`` (the resized input
    followed by every result) and, with ``--save-individual``, one PNG per
    combination.

    Raises:
        SystemExit: If no input images are found, a sweep option is empty or
            not numeric, or an unknown prompt preset is requested. All of
            these are checked before the pipeline is loaded.
    """
    parser = argparse.ArgumentParser(description="Run a parameter grid for Mongle 32-bit LoRA.")
    parser.add_argument("--input", default="image", help="Input image file or folder.")
    parser.add_argument("--output", default="outputs/grid_test", help="Output folder.")
    parser.add_argument("--lora-path", default=os.getenv("LORA_PATH", "."), help="LoRA repo/path.")
    parser.add_argument("--strengths", default="0.45,0.55,0.65,0.75")
    parser.add_argument("--controlnet-scales", default="0.6,0.8,1.0")
    parser.add_argument("--guidance-scales", default="7.5")
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt-presets", default="default,soft,sprite")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int, default=0, help="Limit number of input images. 0 means all.")
    parser.add_argument("--save-individual", action="store_true", help="Also save every generated image.")
    args = parser.parse_args()

    images = collect_images(args.input)
    if args.limit > 0:
        images = images[: args.limit]
    if not images:
        raise SystemExit(f"No input images found: {args.input}")

    # Validate the sweep before loading the pipeline, so a typo or an empty
    # list fails immediately instead of after the model has been loaded.
    strengths = _parse_sweep("--strengths", args.strengths)
    controlnet_scales = _parse_sweep("--controlnet-scales", args.controlnet_scales)
    guidance_scales = _parse_sweep("--guidance-scales", args.guidance_scales)
    prompt_keys = [key.strip() for key in args.prompt_presets.split(",") if key.strip()]
    if not prompt_keys:
        raise SystemExit("No values given for --prompt-presets")

    unknown = [key for key in prompt_keys if key not in PROMPT_PRESETS]
    if unknown:
        raise SystemExit(f"Unknown prompt preset(s): {', '.join(unknown)}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    pipe = load_pipeline(args.lora_path)

    for image_path in images:
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            original = source.convert("RGB")
        stem = image_path.stem
        image_output_dir = output_dir / stem
        image_output_dir.mkdir(parents=True, exist_ok=True)

        # The first grid cell shows the input for visual comparison.
        grid_items = [("input", original.resize((512, 512)))]

        for prompt_key, strength, controlnet_scale, guidance_scale in product(
            prompt_keys,
            strengths,
            controlnet_scales,
            guidance_scales,
        ):
            result = pipe(
                original,
                prompt=PROMPT_PRESETS[prompt_key],
                negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                num_inference_steps=args.steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_scale,
                strength=strength,
                seed=args.seed,
            )

            if args.save_individual:
                filename = (
                    f"{stem}_{prompt_key}"
                    f"_st{strength:.2f}_cn{controlnet_scale:.2f}_g{guidance_scale:.1f}.png"
                )
                output_path = image_output_dir / filename
                result["image"].save(output_path)
                print(f"saved {output_path}")

            title = f"{prompt_key}\nst={strength:.2f} cn={controlnet_scale:.2f} g={guidance_scale:.1f}"
            grid_items.append((title, result["image"]))

        grid_path = image_output_dir / f"{stem}_grid.png"
        save_grid(grid_items, grid_path, columns=4)
        print(f"saved grid {grid_path}")

    print(f"done: {output_dir}")


# Run the parameter sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()