Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| import numpy as np | |
| import os, zipfile, tempfile, time | |
| from spm import spm_augment | |
# Page heading; rendered as Markdown ("# {TITLE}") at the top of the UI.
TITLE = "Shuffle PatchMix (SPM) Augmentation"
# Introductory Markdown shown under the heading.
DESC = """
Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
Optionally enable **overlap (as % of patch size)** with feathered blending for smooth seams.
For batch processing, upload a .zip of images (PNG/JPG/JPEG) and download the outputs as a .zip.
"""
# Folder scanned recursively for example images. It is a relative path, so it
# resolves against the process's current working directory.
EXAMPLES_DIR = "examples"
# When True and EXAMPLES_DIR contains no images, _get_examples() writes a few
# synthetic example images into it. Set False if you never want auto-generated examples.
CREATE_DEFAULTS_IF_EMPTY = True
| # ---------- Examples handling ---------- | |
def _make_default_examples():
    """Populate EXAMPLES_DIR with three synthetic demo images.

    The directory is created if missing. Each image is rendered only when its
    file does not already exist, so user-provided files are never overwritten:

    * ``checkerboard.png``: 512x512, white with dark (30, 30, 30) 64-px tiles
      wherever the tile row + tile column is even.
    * ``gradient.png``: 640x360, red ramps 0..255 left-to-right, green ramps
      0..255 top-to-bottom, blue fixed at 160.
    * ``shapes.png``: 512x384 on white, five outlined rectangles above six
      outlined ellipses.
    """
    os.makedirs(EXAMPLES_DIR, exist_ok=True)

    def _missing(filename):
        # Destination path for `filename`, or None if it is already on disk.
        target = os.path.join(EXAMPLES_DIR, filename)
        return None if os.path.exists(target) else target

    # --- checkerboard ---
    board_file = _missing("checkerboard.png")
    if board_file is not None:
        side, cell = 512, 64
        board = Image.new("RGB", (side, side), "white")
        pen = ImageDraw.Draw(board)
        for row in range(side // cell):
            for col in range(side // cell):
                if (row + col) % 2:
                    continue  # odd squares stay white
                left, top = col * cell, row * cell
                pen.rectangle(
                    [left, top, left + cell - 1, top + cell - 1], fill=(30, 30, 30)
                )
        board.save(board_file)

    # --- gradient ---
    ramp_file = _missing("gradient.png")
    if ramp_file is not None:
        height, width = 360, 640
        pixels = np.zeros((height, width, 3), dtype=np.uint8)
        # Same values as int(255 * i / (n - 1)): exact float division, then
        # truncation toward zero by the uint8 cast (all values are >= 0).
        red = (255 * np.arange(width) / (width - 1)).astype(np.uint8)
        green = (255 * np.arange(height) / (height - 1)).astype(np.uint8)
        pixels[:, :, 0] = red[np.newaxis, :]
        pixels[:, :, 1] = green[:, np.newaxis]
        pixels[:, :, 2] = 160
        Image.fromarray(pixels).save(ramp_file)

    # --- shapes ---
    shapes_file = _missing("shapes.png")
    if shapes_file is not None:
        palette = [
            (220, 20, 60),
            (65, 105, 225),
            (60, 179, 113),
            (255, 165, 0),
            (148, 0, 211),
        ]
        black = (0, 0, 0)
        canvas = Image.new("RGB", (512, 384), "white")
        pen = ImageDraw.Draw(canvas)
        for idx, colour in enumerate(palette):
            x0 = 20 + 90 * idx
            pen.rectangle([x0, 30, x0 + 60, 180], fill=colour, outline=black, width=3)
        for idx in range(6):
            x0 = 40 + 80 * idx
            pen.ellipse(
                [x0, 200, x0 + 50, 350],
                fill=palette[idx % len(palette)],
                outline=black,
                width=3,
            )
        canvas.save(shapes_file)
| def _list_example_images(): | |
| """Return [[path], [path], ...] for all images under examples/ (recursive).""" | |
| exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp"} | |
| items = [] | |
| if os.path.isdir(EXAMPLES_DIR): | |
| for root, _, files in os.walk(EXAMPLES_DIR): | |
| for f in files: | |
| if os.path.splitext(f)[1].lower() in exts: | |
| items.append([os.path.join(root, f)]) | |
| # sort by path for stable order | |
| items.sort(key=lambda x: x[0].lower()) | |
| return items | |
def _get_examples():
    """Return example rows for the UI, seeding default images on first use.

    Scans EXAMPLES_DIR. If that finds nothing and CREATE_DEFAULTS_IF_EMPTY is
    set, the synthetic defaults are written and the folder is scanned again.
    Otherwise whatever was found (possibly an empty list) is returned.
    """
    rows = _list_example_images()
    if rows or not CREATE_DEFAULTS_IF_EMPTY:
        return rows
    _make_default_examples()
    return _list_example_images()
| # ---------- App logic ---------- | |
| def _parse_grid(grid_choice: str) -> int: | |
| # Expect strings like "2x2", "4x4", "8x8", "16x16" | |
| try: | |
| n = int(grid_choice.lower().split("x")[0]) | |
| return max(1, n) | |
| except Exception: | |
| return 4 | |
def run_single(image, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, num_augs, seed):
    """Generate ``num_augs`` SPM-augmented variants of one image.

    Variant ``i`` is produced with seed ``seed + i``, so variants differ from
    each other but the whole set is reproducible. With ``seed=None``,
    ``spm_augment`` receives ``seed=None`` for every variant.

    Args:
        image: Input image (a PIL image from ``gr.Image(type="pil")``), or None.
        grid_choice: Grid label such as ``"8x8"``, parsed with ``_parse_grid``.
        use_overlap: When falsy, overlap is forced to 0 regardless of ``overlap_pct``.
        overlap_pct: Overlap as a percentage of the patch size.
        mix_prob, beta_a, beta_b: Forwarded to ``spm_augment`` as floats.
        num_augs: Number of variants. Floats such as 4.0 are accepted.
        seed: Base integer seed, or None.

    Returns:
        A list of augmented images for the gallery; ``[]`` when no image is given.
    """
    if image is None:
        return []
    # Match the batch path, which converts every input to RGB before augmenting.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    N = _parse_grid(grid_choice)
    pct = float(overlap_pct) if use_overlap else 0.0
    mix, a, b = float(mix_prob), float(beta_a), float(beta_b)
    base_seed = int(seed) if seed is not None else None
    # Sliders / API clients may deliver 4.0 here; range() rejects floats.
    count = int(num_augs)

    outs = []
    for i in range(count):
        s = (base_seed + i) if base_seed is not None else None
        outs.append(
            spm_augment(
                image,
                num_patches=N,
                mix_prob=mix,
                beta_a=a,
                beta_b=b,
                overlap_pct=pct,
                seed=s,
            )
        )
    return outs
def _zip_directory(src_dir, zip_path):
    """Write every file under src_dir into a new deflate-compressed zip at zip_path.

    Archive member names are the files' paths relative to src_dir.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root_dir, _, files in os.walk(src_dir):
            for f in files:
                p = os.path.join(root_dir, f)
                zf.write(p, arcname=os.path.relpath(p, src_dir))


def run_batch(zip_file, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, seed):
    """Apply SPM augmentation once to every PNG/JPG/JPEG inside an uploaded zip.

    The archive is extracted into a fresh temporary directory. Each readable
    image is converted to RGB, augmented, and saved under the same relative
    path in a separate output tree. That tree is then zipped.

    Args:
        zip_file: The uploaded archive. Either a filesystem path, or an object
            whose ``.name`` is the path (some Gradio versions may pass such a
            tempfile wrapper).
        grid_choice: Grid label such as ``"8x8"``, parsed with ``_parse_grid``.
        use_overlap: When falsy, overlap is forced to 0 regardless of ``overlap_pct``.
        overlap_pct: Overlap as a percentage of the patch size.
        mix_prob, beta_a, beta_b: Forwarded to ``spm_augment`` as floats.
        seed: Integer seed or None. The same seed is used for every image.

    Returns:
        ``(output_zip_path, status_message)``, or ``(None, message)`` when no
        archive was uploaded or it is not a valid zip file.
    """
    if zip_file is None:
        return None, "Please upload a .zip file with images."
    # Keep str / PathLike values as-is: Path.name would be only the basename.
    if isinstance(zip_file, (str, os.PathLike)):
        zip_path = zip_file
    else:
        zip_path = getattr(zip_file, "name", zip_file)

    tempdir = tempfile.mkdtemp()
    # Inputs and outputs live in sibling folders. Previously the outputs folder
    # sat inside the tree being walked, so os.walk picked up freshly written
    # results and augmented them a second time (duplicated entries, inflated counts).
    indir = os.path.join(tempdir, "inputs")
    outdir = os.path.join(tempdir, "outputs")
    os.makedirs(indir)
    os.makedirs(outdir)

    # NOTE(review): the archive is user-supplied. zipfile drops absolute and ".."
    # member paths on extraction, but nothing here limits the total
    # uncompressed size (zip bombs).
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(indir)
    except zipfile.BadZipFile:
        return None, "The uploaded file is not a valid .zip archive."

    valid_exts = (".png", ".jpg", ".jpeg")
    N = _parse_grid(grid_choice)
    pct = float(overlap_pct) if use_overlap else 0.0
    mix, a, b = float(mix_prob), float(beta_a), float(beta_b)
    run_seed = int(seed) if seed is not None else None
    count_in, count_out = 0, 0

    for root_dir, dirs, files in os.walk(indir):
        dirs.sort()  # visit sub-folders in a stable order
        for f in sorted(files):
            if not f.lower().endswith(valid_exts):
                continue
            in_path = os.path.join(root_dir, f)
            try:
                # `with` releases the file handle; convert() loads the pixels first.
                with Image.open(in_path) as src:
                    img = src.convert("RGB")
            except Exception:
                # Best-effort: skip files Pillow cannot decode (not counted).
                continue
            count_in += 1
            out_img = spm_augment(
                img,
                num_patches=N,
                mix_prob=mix,
                beta_a=a,
                beta_b=b,
                overlap_pct=pct,
                seed=run_seed,
            )
            out_path = os.path.join(outdir, os.path.relpath(in_path, indir))
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            out_img.save(out_path)
            count_out += 1

    # The zip is written next to (not inside) outdir, so it is not archived into itself.
    out_zip = os.path.join(tempdir, f"spm_outputs_{int(time.time())}.zip")
    _zip_directory(outdir, out_zip)
    msg = f"Processed {count_out}/{count_in} files."
    return out_zip, msg
| # ---------- UI ---------- | |
# Build the Gradio interface at import time; `demo` is the resulting app object.
with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)
    # [[path], ...] rows for the example picker. This may create default images
    # on disk the first time (see _get_examples / CREATE_DEFAULTS_IF_EMPTY).
    examples = _get_examples()
    with gr.Tabs():
        # ----- Tab 1: augment a single image into several variants -----
        with gr.TabItem("Single Image"):
            with gr.Row():
                # Left column: input image and augmentation controls.
                with gr.Column(scale=1):
                    inp = gr.Image(label="Input image", type="pil")
                    # NOTE(review): if `examples` is empty (no files and
                    # CREATE_DEFAULTS_IF_EMPTY=False), confirm the installed
                    # Gradio version accepts an empty examples list.
                    gr.Examples(examples, inputs=[inp], label="Try these")
                    # Labels are parsed by _parse_grid, which reads the number before "x".
                    grid_choice = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="8x8", label="Grid (N×N)")
                    # When unchecked, run_single passes overlap 0.0 regardless of the slider.
                    use_overlap = gr.Checkbox(value=True, label="Enable Overlap Patch Blend")
                    # Capped at 49%; presumably so overlaps from opposite sides cannot
                    # cover a whole patch. TODO confirm against spm_augment.
                    overlap_pct = gr.Slider(0, 49, value=20, step=1, label="Overlap (% of patch)")
                    mix_prob = gr.Slider(0, 1, value=0.8, step=0.05, label="Mix probability (per patch)")
                    # Beta-distribution parameters, forwarded as spm_augment(beta_a=..., beta_b=...).
                    with gr.Row():
                        beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), α =")
                        beta_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), β =")
                    num_augs = gr.Slider(1, 12, value=4, step=1, label="Number of variants")
                    # run_single uses seed + i for variant i, so a given seed reproduces the set.
                    seed = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_btn = gr.Button("Generate")
                # Right column: generated variants.
                with gr.Column(scale=1):
                    gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
            # Input order must match run_single's positional parameters.
            run_btn.click(
                fn=run_single,
                inputs=[inp, grid_choice, use_overlap, overlap_pct, mix_prob, beta_a, beta_b, num_augs, seed],
                outputs=[gallery]
            )
        # ----- Tab 2: apply the same settings to every image in an uploaded zip -----
        # These controls mirror Tab 1 but are separate components with their own state.
        with gr.TabItem("Batch (.zip)"):
            with gr.Row():
                with gr.Column(scale=1):
                    zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
                    grid_choice_b = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="8x8", label="Grid (N×N)")
                    use_overlap_b = gr.Checkbox(value=True, label="Enable Overlap Patch Blend")
                    overlap_pct_b = gr.Slider(0, 49, value=20, step=1, label="Overlap (% of patch)")
                    mix_prob_b = gr.Slider(0, 1, value=0.8, step=0.05, label="Mix probability (per patch)")
                    with gr.Row():
                        beta_a_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), α =")
                        beta_b_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta(α, β), β =")
                    # run_batch passes this same seed for every image in the archive.
                    seed_b = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_b = gr.Button("Process Zip")
                # Right column: downloadable result archive plus a status line.
                with gr.Column(scale=1):
                    zip_out = gr.File(label="Download results (.zip)")
                    status = gr.Markdown()
            # Input order must match run_batch's positional parameters; it returns (zip path, status text).
            run_b.click(
                fn=run_batch,
                inputs=[zip_in, grid_choice_b, use_overlap_b, overlap_pct_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
                outputs=[zip_out, status]
            )
# Start the Gradio server only when this file is executed directly; importing
# the module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()