"""Gradio app: upload three images, generate 1->2 and 2->3 videos, then merge.

Each "stitch" sends a start and an end image to a remote inference endpoint
and stores the returned video under ``/tmp``. "Stitch All" concatenates the
two generated clips with ffmpeg's concat demuxer.
"""

import base64
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import gradio as gr
import requests
from PIL import Image

# -------- Modal inference endpoint --------
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# Directory that receives every generated file (videos and ffmpeg list files).
TMP_DIR = "/tmp"


# -------- Helpers --------
def _new_tmp_path(tag: str, suffix: str) -> str:
    """Create an empty, uniquely named file in ``TMP_DIR`` and return its path.

    The name keeps the ``<tag>_<unix-seconds>`` prefix but adds a random part
    from ``tempfile.mkstemp``, so two calls within the same second can never
    collide and overwrite each other's output.
    """
    os.makedirs(TMP_DIR, exist_ok=True)
    fd, path = tempfile.mkstemp(
        prefix=f"{tag}_{int(time.time())}_", suffix=suffix, dir=TMP_DIR
    )
    os.close(fd)
    return path


def _remove_quietly(path: Optional[str]) -> None:
    """Delete *path* if it is set; a missing or undeletable file is ignored."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write *data* to a fresh ``.mp4`` file in ``TMP_DIR`` and return its path.

    Raises:
        ValueError: if *data* is empty; an empty file is not a playable video,
            so it is reported as a failure instead of being saved.
    """
    if not data:
        raise ValueError("empty video payload")
    path = _new_tmp_path(tag, ".mp4")
    with open(path, "wb") as f:
        f.write(data)
    return path


def _png_bytes(img: Image.Image) -> bytes:
    """Return *img* encoded as PNG bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _download_to_bytes(url: str) -> bytes:
    """GET *url* and return the response body; raises on HTTP error status."""
    r = requests.get(url, timeout=180)
    r.raise_for_status()
    return r.content


def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
) -> Optional[str]:
    """Request a video between two images and save it locally.

    Args:
        start_img: First frame image; ``None`` makes the call return ``None``.
        end_img: Last frame image; ``None`` makes the call return ``None``.
        prompt: Text prompt, sent URL-quoted in the query string.
        seed: Seed sent to the endpoint. ``None``, ``0`` and ``-1`` are
            replaced with a random value in ``[1, 2**31 - 1]``.

    Returns:
        Path of the saved ``.mp4`` file, or ``None`` if an image is missing
        or the request/decoding failed (the error is printed).
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every content type: an error reply that is
        # JSON would otherwise be parsed as if it were a result and the
        # failure would go unreported.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            return _save_video_bytes(resp.content, "stitch")

        # JSON with url or base64
        data = resp.json()
        if not isinstance(data, dict):
            raise ValueError(
                f"unexpected JSON payload of type {type(data).__name__}"
            )

        video_url = (
            data.get("video_url")
            or data.get("url")
            or data.get("result")
            or data.get("output")
        )
        if isinstance(video_url, str) and video_url.startswith(
            ("http://", "https://")
        ):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            # Restore padding that may have been stripped from the string.
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")

        # Nothing usable in the reply: surface it through the handler below
        # rather than returning None without any trace.
        raise ValueError(
            f"no video URL or base64 data in response (keys: {sorted(data)})"
        )
    except Exception as e:
        # Deliberate best-effort boundary: the UI callbacks turn None into a
        # user-facing warning.
        print("Stitch call failed:", e)
        return None


# -------- Gradio callbacks --------
def _stitch_pair(prompt, seed, first_img, second_img, first_no: int, second_no: int):
    """Shared body of the per-pair callbacks; the numbers only label messages."""
    if first_img is None or second_img is None:
        gr.Warning(f"Please upload Image {first_no} and Image {second_no}.")
        return None
    path = stitch_call(first_img, second_img, prompt or "", int(seed or 0))
    if path is None:
        gr.Warning(
            f"Stitch {first_no}&{second_no} failed. Try again or adjust the prompt."
        )
    return path


def stitch_12(prompt12, seed, img1, img2):
    """Generate the video between image 1 and image 2; returns its path or None."""
    return _stitch_pair(prompt12, seed, img1, img2, 1, 2)


def stitch_23(prompt23, seed, img2, img3):
    """Generate the video between image 2 and image 3; returns its path or None."""
    return _stitch_pair(prompt23, seed, img2, img3, 2, 3)


def _concat_entry(path: str) -> str:
    """Return one ``file '...'`` line for an ffmpeg concat list.

    A single quote inside the path is written as ``'\\''`` (close the quote,
    escaped quote, reopen), otherwise it would end the quoted string early.
    The path is made absolute because the concat demuxer resolves relative
    entries against the list file's directory, not the working directory.
    """
    escaped = os.path.abspath(path).replace("'", "'\\''")
    return f"file '{escaped}'\n"


def stitch_all(video12, video23):
    """Concatenate the two generated videos with ffmpeg (stream copy).

    Args:
        video12: Path of the 1->2 video.
        video23: Path of the 2->3 video.

    Returns:
        Path of the merged ``.mp4`` file, or ``None`` if an input is missing
        or ffmpeg failed (a warning is shown in the UI in both cases).
    """
    if not video12 or not video23:
        gr.Warning("Please generate both stitched videos first.")
        return None

    out_path = None
    list_path = None
    try:
        # Final output path
        out_path = _new_tmp_path("stitch_all", ".mp4")

        # Concatenate with ffmpeg
        list_path = _new_tmp_path("concat", ".txt")
        with open(list_path, "w", encoding="utf-8") as f:
            f.write(_concat_entry(video12))
            f.write(_concat_entry(video23))

        # "-y" lets ffmpeg overwrite the empty placeholder created above.
        cmd = [
            "ffmpeg", "-y",
            "-f", "concat", "-safe", "0",
            "-i", list_path,
            "-c", "copy",
            out_path,
        ]
        subprocess.run(cmd, check=True)
        return out_path
    except Exception as e:
        print("Stitch all failed:", e)
        gr.Warning("Failed to stitch all videos together.")
        # Do not leave an empty or partial output file behind.
        _remove_quietly(out_path)
        return None
    finally:
        # The list file is only needed while ffmpeg runs.
        _remove_quietly(list_path)


# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
"""

with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches, concat") as demo:
    gr.Markdown("## Stitch — Upload 3 images, generate videos between 1→2 and 2→3, then merge them.")

    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")

        with gr.Column(scale=1, min_width=320):
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)", elem_classes=["rounded"])
            btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
            vid12 = gr.Video(label="Video (image 1+2) output")

            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)", elem_classes=["rounded"])
            btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
            vid23 = gr.Video(label="Video (image 2+3) output")

            btn_all = gr.Button("Stitch All", elem_classes=["pill"])
            vid_all = gr.Video(label="Final concatenated video")

    # Wire buttons
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])

if __name__ == "__main__":
    demo.queue().launch()