mongle-lora-v3-32bit / test_grid.py
Hadimeeee's picture
Upload test_grid.py with huggingface_hub
6b822ad verified
import argparse
import os
from itertools import product
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from pipeline import DEFAULT_NEGATIVE_PROMPT, DEFAULT_PROMPT, load_pipeline
# Named prompt variants that can be selected with --prompt-presets.
# Each custom preset is a comma-separated list of style tags; the tags are
# written one per entry and joined with ", " to form the final prompt text.
PROMPT_PRESETS = {
    "default": DEFAULT_PROMPT,
    "soft": ", ".join(
        (
            "a stuffed animal toy transformed into a cute 32-bit pixel art character sprite",
            "modern Korean mobile RPG style",
            "soft rounded chibi proportions",
            "large rounded head",
            "plush body",
            "tiny limbs",
            "warm pastel colors",
            "gentle 2-tone shading",
            "thin dark brown outline",
            "shiny dot eyes",
            "tiny nose",
            "soft smile",
            "pink blush cheeks",
            "white background",
            "full body",
            "front-facing",
        )
    ),
    "sprite": ", ".join(
        (
            "front-facing full body stuffed animal character sprite",
            "high quality 32-bit pixel art",
            "cozy mobile RPG game asset",
            "clean pixel edges",
            "readable silhouette",
            "compact chibi body",
            "soft fur shading",
            "warm colors",
            "subtle highlights",
            "thin brown outline",
            "cute face",
            "dot eyes with highlights",
            "pink blush",
            "white background",
        )
    ),
    "reference": ", ".join(
        (
            "cute plush toy",
            "high resolution pixel art character",
            "large visible pixels",
            "rounded chibi body",
            "front-facing full body",
            "warm soft shading",
            "brown pixel outline",
            "glossy dot eyes",
            "tiny nose",
            "small smile",
            "pink cheek blush",
            "clean white background",
        )
    ),
}
def parse_csv_floats(value: str) -> list[float]:
    """Parse a comma-separated string into a list of floats.

    Whitespace around each entry is ignored, and entries that are empty
    after stripping (for example from a trailing comma) are skipped.

    Raises:
        ValueError: If a non-empty entry cannot be converted by ``float``.
    """
    numbers = []
    for raw in value.split(","):
        token = raw.strip()
        if token:
            numbers.append(float(token))
    return numbers
def collect_images(input_path: str) -> list[Path]:
    """Return the image files designated by ``input_path``.

    Args:
        input_path: Path to a single file or to a folder of images.

    Returns:
        ``[path]`` when ``input_path`` is an existing file (its extension is
        not checked). For a folder, the sorted list of regular files directly
        inside it whose suffix is .jpg, .jpeg, .png or .webp
        (case-insensitive). An empty list when the path is neither a file
        nor a folder, so the caller can report "no images" instead of
        failing with a traceback from ``iterdir``.
    """
    path = Path(input_path)
    if path.is_file():
        return [path]
    if not path.is_dir():
        # Missing path (or something that is neither file nor folder).
        return []
    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    return sorted(
        entry
        for entry in path.iterdir()
        # is_file() skips sub-folders that happen to be named like images.
        if entry.is_file() and entry.suffix.lower() in extensions
    )
def save_grid(items: list[tuple[str, Image.Image]], output_path: Path, columns: int = 4):
    """Render titled images into a grid figure and save it to ``output_path``.

    Args:
        items: ``(title, image)`` pairs, drawn left-to-right, top-to-bottom.
            Nothing is written when the list is empty.
        output_path: Destination file passed to ``Figure.savefig``.
        columns: Number of grid columns; the row count is derived from it.

    Raises:
        ValueError: If ``columns`` is smaller than 1.
    """
    if not items:
        return
    if columns < 1:
        raise ValueError(f"columns must be at least 1, got {columns}")

    rows = (len(items) + columns - 1) // columns  # ceiling division
    # squeeze=False always yields a 2-D array of axes, so single-row and
    # single-column layouts need no special casing (a one-column grid with
    # several rows previously failed when iterating an Axes object).
    fig, axes = plt.subplots(
        rows, columns, figsize=(columns * 4, rows * 4), squeeze=False
    )
    try:
        cells = list(axes.flat)  # row-major order
        for ax, (title, image) in zip(cells, items):
            ax.imshow(image)
            ax.set_title(title, fontsize=8)
            ax.axis("off")
        # Hide the unused cells in the last row.
        for ax in cells[len(items):]:
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path, dpi=130)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def main():
    """Run the parameter grid from the command line.

    For every input image, the pipeline is called once per combination of
    prompt preset, strength, ControlNet scale and guidance scale. The results
    are collected into ``<output>/<stem>/<stem>_grid.png`` next to a resized
    copy of the input; with ``--save-individual`` each result is also written
    as its own PNG.

    Raises:
        SystemExit: If no input images are found, a prompt preset is unknown,
            or one of the comma-separated parameter lists is empty.
    """
    parser = argparse.ArgumentParser(description="Run a parameter grid for Mongle 32-bit LoRA.")
    parser.add_argument("--input", default="image", help="Input image file or folder.")
    parser.add_argument("--output", default="outputs/grid_test", help="Output folder.")
    parser.add_argument("--lora-path", default=os.getenv("LORA_PATH", "."), help="LoRA repo/path.")
    parser.add_argument("--strengths", default="0.45,0.55,0.65,0.75")
    parser.add_argument("--controlnet-scales", default="0.6,0.8,1.0")
    parser.add_argument("--guidance-scales", default="7.5")
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt-presets", default="default,soft,sprite")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int, default=0, help="Limit number of input images. 0 means all.")
    parser.add_argument("--save-individual", action="store_true", help="Also save every generated image.")
    args = parser.parse_args()

    images = collect_images(args.input)
    if args.limit > 0:
        images = images[: args.limit]
    if not images:
        raise SystemExit(f"No input images found: {args.input}")

    strengths = parse_csv_floats(args.strengths)
    controlnet_scales = parse_csv_floats(args.controlnet_scales)
    guidance_scales = parse_csv_floats(args.guidance_scales)
    prompt_keys = [key.strip() for key in args.prompt_presets.split(",") if key.strip()]

    unknown = [key for key in prompt_keys if key not in PROMPT_PRESETS]
    if unknown:
        raise SystemExit(f"Unknown prompt preset(s): {', '.join(unknown)}")

    # An empty list would make the product below empty and silently produce
    # grids containing only the input image, so fail before loading the model.
    empty = [
        flag
        for flag, values in (
            ("--prompt-presets", prompt_keys),
            ("--strengths", strengths),
            ("--controlnet-scales", controlnet_scales),
            ("--guidance-scales", guidance_scales),
        )
        if not values
    ]
    if empty:
        raise SystemExit(f"No values given for: {', '.join(empty)}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    pipe = load_pipeline(args.lora_path)

    for image_path in images:
        # Close the file handle as soon as the pixels are copied by convert().
        with Image.open(image_path) as opened:
            original = opened.convert("RGB")
        stem = image_path.stem
        image_output_dir = output_dir / stem
        image_output_dir.mkdir(parents=True, exist_ok=True)

        # The first grid cell shows the input for visual comparison.
        grid_items = [("input", original.resize((512, 512)))]

        for prompt_key, strength, controlnet_scale, guidance_scale in product(
            prompt_keys,
            strengths,
            controlnet_scales,
            guidance_scales,
        ):
            # The same seed is passed for every combination.
            result = pipe(
                original,
                prompt=PROMPT_PRESETS[prompt_key],
                negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                num_inference_steps=args.steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_scale,
                strength=strength,
                seed=args.seed,
            )
            generated = result["image"]

            if args.save_individual:
                filename = (
                    f"{stem}_{prompt_key}"
                    f"_st{strength:.2f}_cn{controlnet_scale:.2f}_g{guidance_scale:.1f}.png"
                )
                output_path = image_output_dir / filename
                generated.save(output_path)
                print(f"saved {output_path}")

            title = f"{prompt_key}\nst={strength:.2f} cn={controlnet_scale:.2f} g={guidance_scale:.1f}"
            grid_items.append((title, generated))

        grid_path = image_output_dir / f"{stem}_grid.png"
        save_grid(grid_items, grid_path, columns=4)
        print(f"saved grid {grid_path}")

    print(f"done: {output_dir}")
# Run the grid only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()