| import argparse |
| import os |
| from itertools import product |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| from PIL import Image |
|
|
| from pipeline import DEFAULT_NEGATIVE_PROMPT, DEFAULT_PROMPT, load_pipeline |
|
|
|
|
# Named positive prompts selectable via --prompt-presets. "default" reuses the
# pipeline module's DEFAULT_PROMPT; the others are alternative wordings of the
# same plush-toy-to-pixel-art sprite request, kept here for side-by-side sweeps.
PROMPT_PRESETS: dict[str, str] = dict(
    default=DEFAULT_PROMPT,
    soft=(
        "a stuffed animal toy transformed into a cute 32-bit pixel art character sprite, "
        "modern Korean mobile RPG style, soft rounded chibi proportions, "
        "large rounded head, plush body, tiny limbs, warm pastel colors, "
        "gentle 2-tone shading, thin dark brown outline, shiny dot eyes, "
        "tiny nose, soft smile, pink blush cheeks, white background, full body, front-facing"
    ),
    sprite=(
        "front-facing full body stuffed animal character sprite, "
        "high quality 32-bit pixel art, cozy mobile RPG game asset, "
        "clean pixel edges, readable silhouette, compact chibi body, "
        "soft fur shading, warm colors, subtle highlights, thin brown outline, "
        "cute face, dot eyes with highlights, pink blush, white background"
    ),
    reference=(
        "cute plush toy, high resolution pixel art character, large visible pixels, "
        "rounded chibi body, front-facing full body, warm soft shading, "
        "brown pixel outline, glossy dot eyes, tiny nose, small smile, "
        "pink cheek blush, clean white background"
    ),
)
|
|
|
|
def parse_csv_floats(value: str) -> list[float]:
    """Split a comma-separated string into floats.

    Surrounding whitespace on each item is ignored and blank items (for
    example from a trailing comma) are skipped.

    Raises:
        ValueError: If a non-blank item is not a valid float literal.
    """
    numbers: list[float] = []
    for piece in value.split(","):
        token = piece.strip()
        if not token:
            continue
        numbers.append(float(token))
    return numbers
|
|
|
|
def collect_images(input_path: str) -> list[Path]:
    """Return the image file(s) to process for ``input_path``.

    * An existing file is returned as-is (its extension is not checked).
    * For a directory, its immediate children that are regular files with a
      .jpg/.jpeg/.png/.webp extension (case-insensitive) are returned, sorted.
    * A path that is neither a file nor a directory (e.g. it does not exist)
      yields an empty list rather than raising ``FileNotFoundError``.

    Args:
        input_path: An image file or a directory containing images.

    Returns:
        The selected paths; empty when nothing matches.
    """
    path = Path(input_path)
    if path.is_file():
        return [path]
    if not path.is_dir():
        return []

    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    # Only regular files can be opened as images; a sub-directory named e.g.
    # "refs.png" must not be picked up just because of its suffix.
    return sorted(
        child
        for child in path.iterdir()
        if child.is_file() and child.suffix.lower() in extensions
    )
|
|
|
|
def save_grid(
    items: list[tuple[str, Image.Image]],
    output_path: Path,
    columns: int = 4,
    dpi: int = 130,
):
    """Draw titled images into one matplotlib grid figure and save it.

    Images fill the grid left-to-right, top-to-bottom; unused trailing cells
    are left blank. Nothing is written when ``items`` is empty.

    Args:
        items: ``(title, image)`` pairs; each image is drawn with ``imshow``.
        output_path: Destination passed to ``Figure.savefig``.
        columns: Number of grid columns; must be at least 1.
        dpi: Resolution passed to ``Figure.savefig``.

    Raises:
        ValueError: If ``columns`` is less than 1.
    """
    if not items:
        return
    if columns < 1:
        raise ValueError(f"columns must be >= 1, got {columns}")

    rows = (len(items) + columns - 1) // columns  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so single-row and
    # single-column layouts flatten exactly like any other shape.
    fig, axes = plt.subplots(
        rows, columns, figsize=(columns * 4, rows * 4), squeeze=False
    )
    try:
        flat_axes = axes.ravel()

        for ax, (title, image) in zip(flat_axes, items):
            ax.imshow(image)
            ax.set_title(title, fontsize=8)
            ax.axis("off")

        # Hide frames/ticks of the padding cells after the last image.
        for ax in flat_axes[len(items):]:
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(output_path, dpi=dpi)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
|
|
|
|
def _parse_float_option(parser: argparse.ArgumentParser, option: str, value: str) -> list[float]:
    """Parse a non-empty comma-separated float option, reporting errors via ``parser``."""
    try:
        values = parse_csv_floats(value)
    except ValueError as err:
        parser.error(f"{option}: {err}")
    if not values:
        parser.error(f"{option} must contain at least one value")
    return values


def _run_image(
    pipe,
    image_path: Path,
    output_dir: Path,
    combos: list[tuple[str, float, float, float]],
    args: argparse.Namespace,
) -> None:
    """Run every parameter combination on one input image and save its grid.

    Args:
        pipe: The object returned by ``load_pipeline``; it is called with the
            input image plus generation keyword arguments, and its result is
            indexed with ``"image"``.
        image_path: Input image file.
        output_dir: Root output folder; results go to ``output_dir/<stem>/``.
        combos: ``(prompt_key, strength, controlnet_scale, guidance_scale)``
            tuples, run in order.
        args: Parsed CLI arguments (``steps``, ``seed``, ``save_individual``).
    """
    # Close the file handle right away; convert() returns a fully loaded copy.
    with Image.open(image_path) as source:
        original = source.convert("RGB")

    stem = image_path.stem
    image_output_dir = output_dir / stem
    image_output_dir.mkdir(parents=True, exist_ok=True)

    # First grid cell shows the input, resized to a fixed 512x512
    # (aspect ratio is not preserved).
    grid_items = [("input", original.resize((512, 512)))]

    for prompt_key, strength, controlnet_scale, guidance_scale in combos:
        result = pipe(
            original,
            prompt=PROMPT_PRESETS[prompt_key],
            negative_prompt=DEFAULT_NEGATIVE_PROMPT,
            num_inference_steps=args.steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            strength=strength,
            # Same seed for every combination, presumably so that cells differ
            # only in the swept parameters.
            seed=args.seed,
        )
        image = result["image"]

        if args.save_individual:
            filename = (
                f"{stem}_{prompt_key}"
                f"_st{strength:.2f}_cn{controlnet_scale:.2f}_g{guidance_scale:.1f}.png"
            )
            output_path = image_output_dir / filename
            image.save(output_path)
            print(f"saved {output_path}")

        title = f"{prompt_key}\nst={strength:.2f} cn={controlnet_scale:.2f} g={guidance_scale:.1f}"
        grid_items.append((title, image))

    grid_path = image_output_dir / f"{stem}_grid.png"
    save_grid(grid_items, grid_path, columns=4)
    print(f"saved grid {grid_path}")


def main():
    """CLI entry point: sweep pipeline parameters over one or more input images.

    For each input image, every combination of prompt preset, strength,
    ControlNet scale and guidance scale is run. A comparison grid is written
    to ``<output>/<stem>/<stem>_grid.png``. With ``--save-individual``, each
    generated image is also saved.

    Invalid arguments (bad float lists, empty sweeps, ``--steps < 1``) exit
    via ``argparse``; a missing input, no matching images, or unknown presets
    exit via ``SystemExit`` with a message.
    """
    parser = argparse.ArgumentParser(description="Run a parameter grid for Mongle 32-bit LoRA.")
    parser.add_argument("--input", default="image", help="Input image file or folder.")
    parser.add_argument("--output", default="outputs/grid_test", help="Output folder.")
    parser.add_argument("--lora-path", default=os.getenv("LORA_PATH", "."), help="LoRA repo/path.")
    parser.add_argument("--strengths", default="0.45,0.55,0.65,0.75")
    parser.add_argument("--controlnet-scales", default="0.6,0.8,1.0")
    parser.add_argument("--guidance-scales", default="7.5")
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt-presets", default="default,soft,sprite")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int, default=0, help="Limit number of input images. 0 means all.")
    parser.add_argument("--save-individual", action="store_true", help="Also save every generated image.")
    args = parser.parse_args()

    if args.steps < 1:
        parser.error("--steps must be >= 1")

    if not Path(args.input).exists():
        raise SystemExit(f"Input path does not exist: {args.input}")
    images = collect_images(args.input)
    if args.limit > 0:
        images = images[: args.limit]
    if not images:
        raise SystemExit(f"No input images found: {args.input}")

    strengths = _parse_float_option(parser, "--strengths", args.strengths)
    controlnet_scales = _parse_float_option(parser, "--controlnet-scales", args.controlnet_scales)
    guidance_scales = _parse_float_option(parser, "--guidance-scales", args.guidance_scales)

    prompt_keys = [key.strip() for key in args.prompt_presets.split(",") if key.strip()]
    if not prompt_keys:
        parser.error("--prompt-presets must contain at least one preset")
    unknown = [key for key in prompt_keys if key not in PROMPT_PRESETS]
    if unknown:
        raise SystemExit(f"Unknown prompt preset(s): {', '.join(unknown)}")

    # The sweep is identical for every image, so build it once.
    combos = list(product(prompt_keys, strengths, controlnet_scales, guidance_scales))

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the pipeline only after all cheap argument validation has passed.
    pipe = load_pipeline(args.lora_path)

    for image_path in images:
        _run_image(pipe, image_path, output_dir, combos, args)

    print(f"done: {output_dir}")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|