# SPM / app.py
# Commit 7f1eca4 (verified) by prasannareddyp: "Update app.py"
import gradio as gr
from PIL import Image, ImageDraw
import numpy as np
import os, zipfile, tempfile, time
from spm import spm_augment
# Page heading and the Markdown blurb rendered under it.
TITLE: str = "Shuffle PatchMix (SPM) Augmentation"
DESC: str = """
Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
Optionally enable **overlap (as % of patch size)** with feathered blending for smooth seams.
For batch processing, upload a .zip of images (PNG/JPG/JPEG) and download the outputs as a .zip.
"""

# Folder scanned (including subfolders) for the "Try these" example images.
EXAMPLES_DIR: str = "examples"
# When EXAMPLES_DIR contains no images, write a few synthetic demo images into it.
# Set to False to never auto-generate examples.
CREATE_DEFAULTS_IF_EMPTY: bool = True
# ---------- Examples handling ----------
def _make_default_examples():
    """Write three synthetic demo images into EXAMPLES_DIR.

    Produces ``checkerboard.png`` (512x512), ``gradient.png`` (640x360) and
    ``shapes.png`` (512x384). Any of these files that already exists is left
    untouched, so existing images are never overwritten.
    """
    os.makedirs(EXAMPLES_DIR, exist_ok=True)

    def checkerboard():
        # 8x8 grid of 64 px tiles: dark where (column + row) is even, white elsewhere.
        side, tile = 512, 64
        canvas = Image.new("RGB", (side, side), "white")
        pen = ImageDraw.Draw(canvas)
        for top in range(0, side, tile):
            for left in range(0, side, tile):
                if (left // tile + top // tile) % 2:
                    continue
                pen.rectangle([left, top, left + tile - 1, top + tile - 1], fill=(30, 30, 30))
        return canvas

    def gradient():
        # Red ramps 0..255 left-to-right, green ramps 0..255 top-to-bottom,
        # blue is constant. astype() truncates exactly like int() on these
        # non-negative floats.
        height, width = 360, 640
        red = (255 * np.arange(width) / (width - 1)).astype(np.uint8)
        green = (255 * np.arange(height) / (height - 1)).astype(np.uint8)
        pixels = np.empty((height, width, 3), dtype=np.uint8)
        pixels[:, :, 0] = red[np.newaxis, :]
        pixels[:, :, 1] = green[:, np.newaxis]
        pixels[:, :, 2] = 160
        return Image.fromarray(pixels)

    def shapes():
        # Five black-outlined rectangles on top, six outlined ellipses below,
        # with colours cycling through the same palette.
        palette = [(220, 20, 60), (65, 105, 225), (60, 179, 113), (255, 165, 0), (148, 0, 211)]
        canvas = Image.new("RGB", (512, 384), "white")
        pen = ImageDraw.Draw(canvas)
        for idx, colour in enumerate(palette):
            left = 20 + 90 * idx
            pen.rectangle([left, 30, left + 60, 180], fill=colour, outline=(0, 0, 0), width=3)
        for idx in range(6):
            left = 40 + 80 * idx
            pen.ellipse(
                [left, 200, left + 50, 350],
                fill=palette[idx % len(palette)],
                outline=(0, 0, 0),
                width=3,
            )
        return canvas

    # Built in a fixed order; each builder only runs when its file is missing.
    for filename, build in (
        ("checkerboard.png", checkerboard),
        ("gradient.png", gradient),
        ("shapes.png", shapes),
    ):
        target = os.path.join(EXAMPLES_DIR, filename)
        if not os.path.exists(target):
            build().save(target)
def _list_example_images():
    """Collect every image file under EXAMPLES_DIR, descending into subfolders.

    Returns a list of one-element lists (``[[path], [path], ...]``), the row
    shape this file passes to ``gr.Examples``. Paths are ordered
    case-insensitively so the example order is stable between runs. A missing
    EXAMPLES_DIR yields an empty list.
    """
    if not os.path.isdir(EXAMPLES_DIR):
        return []
    image_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
    found = [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(EXAMPLES_DIR)
        for name in names
        if os.path.splitext(name)[1].lower() in image_suffixes
    ]
    return [[path] for path in sorted(found, key=str.lower)]
def _get_examples():
    """Return the example rows, generating default images first if there are none.

    Default images are only generated when CREATE_DEFAULTS_IF_EMPTY is true;
    otherwise an empty result is returned as-is.
    """
    found = _list_example_images()
    if found or not CREATE_DEFAULTS_IF_EMPTY:
        return found
    _make_default_examples()
    return _list_example_images()
# ---------- App logic ----------
def _parse_grid(grid_choice: str) -> int:
# Expect strings like "2x2", "4x4", "8x8", "16x16"
try:
n = int(grid_choice.lower().split("x")[0])
return max(1, n)
except Exception:
return 4
def run_single(image, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, num_augs, seed):
    """Generate ``num_augs`` SPM-augmented variants of a single image.

    Args:
        image: the input image (from the ``gr.Image(type="pil")`` widget) or None.
        grid_choice: grid label such as ``"8x8"``; parsed with ``_parse_grid``.
        use_overlap: when false, ``overlap_pct`` is ignored and 0.0 is used.
        overlap_pct: patch overlap, as a percentage of patch size.
        mix_prob, beta_a, beta_b: forwarded to ``spm_augment`` as floats.
        num_augs: number of variants to produce; converted with ``int()``, and
            None or values below 1 produce no variants.
        seed: base seed or None. Variant ``i`` uses ``seed + i``, so variants
            differ from one another while the whole set is reproducible.

    Returns:
        A list of the results of ``spm_augment``, one per variant (empty when
        there is no image or nothing to generate).
    """
    # A number widget may hand over a float such as 4.0; range() needs an int.
    count = int(num_augs) if num_augs is not None else 0
    if image is None or count <= 0:
        return []
    base_seed = int(seed) if seed is not None else None
    # Every variant uses the same parameters, so convert them once up front.
    options = {
        "num_patches": _parse_grid(grid_choice),
        "mix_prob": float(mix_prob),
        "beta_a": float(beta_a),
        "beta_b": float(beta_b),
        "overlap_pct": float(overlap_pct) if use_overlap else 0.0,
    }
    return [
        spm_augment(image, seed=None if base_seed is None else base_seed + i, **options)
        for i in range(count)
    ]
def _iter_batch_images(root, exts=(".png", ".jpg", ".jpeg")):
    """Yield paths of files under ``root`` whose name ends (any case) with one of ``exts``."""
    for folder, _subdirs, names in os.walk(root):
        for name in names:
            if name.lower().endswith(exts):
                yield os.path.join(folder, name)


def _zip_directory(src_dir, zip_path):
    """Write every file under ``src_dir`` into a deflate-compressed archive at ``zip_path``.

    Member names are the file paths relative to ``src_dir``.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for folder, _subdirs, names in os.walk(src_dir):
            for name in names:
                path = os.path.join(folder, name)
                zf.write(path, arcname=os.path.relpath(path, src_dir))


def run_batch(zip_file, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, seed):
    """Apply SPM augmentation once to every PNG/JPG/JPEG inside an uploaded .zip.

    The archive is extracted into a fresh temporary directory. Each readable
    image is augmented and saved under the same relative path, and the results
    are packed into a new .zip.

    Args:
        zip_file: path of the uploaded archive, an object exposing ``.name``
            (a tempfile wrapper), or None.
        grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b: same
            meaning as in ``run_single``.
        seed: seed passed unchanged to every image, or None.

    Returns:
        ``(zip_path, "Processed X/Y files.")``, where Y counts images that could
        be opened and X those successfully augmented and saved. Returns
        ``(None, message)`` when no usable archive was supplied.
    """
    if zip_file is None:
        return None, "Please upload a .zip file with images."
    # Some Gradio versions pass a tempfile wrapper rather than a path string.
    zip_path = getattr(zip_file, "name", zip_file)
    try:
        archive = zipfile.ZipFile(zip_path, "r")
    except zipfile.BadZipFile:
        return None, "The uploaded file is not a valid .zip archive."

    workdir = tempfile.mkdtemp()
    # Inputs and outputs live in sibling folders. Walking the extraction root
    # while saving results beneath it would pick those results up again as
    # inputs (double augmentation, inflated counts).
    indir = os.path.join(workdir, "inputs")
    outdir = os.path.join(workdir, "outputs")
    os.makedirs(indir)
    os.makedirs(outdir)
    with archive:
        archive.extractall(indir)

    options = {
        "num_patches": _parse_grid(grid_choice),
        "mix_prob": float(mix_prob),
        "beta_a": float(beta_a),
        "beta_b": float(beta_b),
        "overlap_pct": float(overlap_pct) if use_overlap else 0.0,
    }
    # The same seed is used for every image (run_single instead offsets it per variant).
    seed_value = int(seed) if seed is not None else None

    count_in = count_out = 0
    for in_path in _iter_batch_images(indir):
        try:
            # Context manager closes the underlying file; convert() returns an
            # independent, fully loaded image.
            with Image.open(in_path) as src:
                img = src.convert("RGB")
        except Exception:
            # Best effort: files that only look like images (corrupt data,
            # macOS "._" resource forks, ...) are skipped and not counted.
            continue
        count_in += 1
        out_path = os.path.join(outdir, os.path.relpath(in_path, indir))
        try:
            out_img = spm_augment(img, seed=seed_value, **options)
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            out_img.save(out_path)
        except Exception:
            # One bad image must not abort the whole batch; the failure is
            # reported through count_out < count_in in the status message.
            continue
        count_out += 1

    out_zip = os.path.join(workdir, f"spm_outputs_{int(time.time())}.zip")
    _zip_directory(outdir, out_zip)
    return out_zip, f"Processed {count_out}/{count_in} files."
# ---------- UI ----------
GRID_CHOICES = ["2x2", "4x4", "8x8", "16x16"]


def _spm_controls():
    """Create the SPM parameter widgets shared by both tabs, in display order.

    Call it inside an open layout context (e.g. a ``gr.Column``) at the point
    where the widgets should appear. Returns
    ``(grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b)``.
    """
    grid_choice = gr.Radio(choices=GRID_CHOICES, value="8x8", label="Grid (N×N)")
    use_overlap = gr.Checkbox(value=True, label="Enable Overlap Patch Blend")
    overlap_pct = gr.Slider(0, 49, value=20, step=1, label="Overlap (% of patch)")
    mix_prob = gr.Slider(0, 1, value=0.8, step=0.05, label="Mix probability (per patch)")
    with gr.Row():
        beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), α =")
        beta_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), β =")
    return grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b


with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)
    examples = _get_examples()
    with gr.Tabs():
        with gr.TabItem("Single Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    inp = gr.Image(label="Input image", type="pil")
                    # No Examples widget when there is nothing to show (possible
                    # when CREATE_DEFAULTS_IF_EMPTY is False and examples/ is empty).
                    if examples:
                        gr.Examples(examples, inputs=[inp], label="Try these")
                    grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b = _spm_controls()
                    num_augs = gr.Slider(1, 12, value=4, step=1, label="Number of variants")
                    seed = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_btn = gr.Button("Generate")
                with gr.Column(scale=1):
                    gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
            run_btn.click(
                fn=run_single,
                inputs=[inp, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, num_augs, seed],
                outputs=[gallery],
            )
        with gr.TabItem("Batch (.zip)"):
            with gr.Row():
                with gr.Column(scale=1):
                    zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
                    (
                        grid_choice_b,
                        use_overlap_b,
                        overlap_pct_b,
                        mix_prob_b,
                        beta_a_b,
                        beta_b_b,
                    ) = _spm_controls()
                    seed_b = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_b = gr.Button("Process Zip")
                with gr.Column(scale=1):
                    zip_out = gr.File(label="Download results (.zip)")
                    status = gr.Markdown()
            run_b.click(
                fn=run_batch,
                inputs=[zip_in, grid_choice_b, use_overlap_b, overlap_pct_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
                outputs=[zip_out, status],
            )

if __name__ == "__main__":
    demo.launch()