"""Gradio demo: stitch three uploaded images into two image-to-video clips.

Clip 1→2 and clip 2→3 are generated by a remote Modal inference endpoint;
"Stitch All" joins the two clips end-to-end with ffmpeg.
"""

import base64
import contextlib
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import requests
from PIL import Image
import gradio as gr

# -------- Modal inference endpoint (dev) --------
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# Directory where generated and combined videos (and ffmpeg list files) are written.
TMP_DIR = "/tmp"


# -------- small helpers --------
def _unique_path(tag: str, suffix: str) -> str:
    """Create an empty, uniquely named file in TMP_DIR and return its path.

    Names built only from ``int(time.time())`` collide when two calls land in
    the same second (e.g. "Stitch 1&2" and "Stitch 2&3" on a queued app), so
    one result would overwrite the other. ``mkstemp`` guarantees a fresh file;
    the timestamp is kept in the prefix so files still sort chronologically.
    """
    os.makedirs(TMP_DIR, exist_ok=True)
    fd, path = tempfile.mkstemp(
        prefix=f"{tag}_{int(time.time())}_", suffix=suffix, dir=TMP_DIR
    )
    os.close(fd)
    return path


def _remove_quietly(path: str) -> None:
    """Best-effort delete of *path*; a missing or locked file is ignored."""
    with contextlib.suppress(OSError):
        os.remove(path)


def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write *data* to a new uniquely named ``.mp4`` in TMP_DIR and return its path."""
    path = _unique_path(tag, ".mp4")
    with open(path, "wb") as f:
        f.write(data)
    return path


def _png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image as PNG and return the raw bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _download_to_bytes(url: str) -> bytes:
    """GET *url* and return the body; raises ``requests.HTTPError`` on non-2xx."""
    r = requests.get(url, timeout=180)
    r.raise_for_status()
    return r.content


def _video_from_json(data: object) -> Optional[str]:
    """Save the video described by a JSON response body; return None if absent.

    Accepts either a downloadable http(s) URL under ``video_url`` / ``url`` /
    ``result`` (checked in that order) or base64-encoded bytes under
    ``video_b64``.
    """
    if not isinstance(data, dict):
        print("stitch_call error: unexpected JSON payload type:", type(data).__name__)
        return None

    video_url = data.get("video_url") or data.get("url") or data.get("result")
    if isinstance(video_url, str) and video_url.startswith("http"):
        return _save_video_bytes(_download_to_bytes(video_url), "stitch")

    video_b64 = data.get("video_b64")
    if isinstance(video_b64, str):
        # Restore any missing "=" padding; b64decode rejects unpadded input.
        video_b64 += "=" * (-len(video_b64) % 4)
        return _save_video_bytes(base64.b64decode(video_b64), "stitch")

    print(
        "stitch_call error: JSON response has no video_url/url/result or video_b64; keys:",
        list(data),
    )
    return None


def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Generate a clip transitioning from *start_img* to *end_img*.

    Args:
        start_img: First frame image; ``None`` short-circuits to ``None``.
        end_img: Last frame image; ``None`` short-circuits to ``None``.
        prompt: Text prompt forwarded to the endpoint as a query parameter.
        seed: RNG seed; ``None``, ``0`` or ``-1`` picks a random positive seed.

    Returns:
        Local path of the saved ``.mp4``, or ``None`` on missing inputs or any
        failure (the error is printed, not raised, so the UI keeps running).
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for both response shapes so a JSON error body is
        # reported as an HTTP error instead of being parsed as a result.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()
        if "application/json" not in ctype:
            # Raw video bytes
            return _save_video_bytes(resp.content, "stitch")
        # JSON with url or base64
        return _video_from_json(resp.json())
    except Exception as e:  # UI boundary: log and degrade to "no video"
        print("stitch_call error:", e)
        return None


# -------- FFmpeg-based concatenation --------
def _concat_list_entry(path: str) -> str:
    """Return one ``file '...'`` line for ffmpeg's concat demuxer list.

    A literal single quote inside the path is written as ``'\\''`` (close the
    quote, backslash-escaped quote, reopen) so such paths still parse.
    """
    return "file '" + path.replace("'", "'\\''") + "'\n"


def concat_videos(vid1: str, vid2: str) -> Optional[str]:
    """Join *vid1* then *vid2* into a new ``.mp4`` using ffmpeg stream copy.

    Returns the output path, or ``None`` if either input is empty or ffmpeg
    fails (the error and ffmpeg's stderr tail are printed). The temporary
    list file is always removed; a failed output file is removed too.
    """
    if not vid1 or not vid2:
        return None

    list_file: Optional[str] = None
    out_path: Optional[str] = None
    try:
        out_path = _unique_path("final", ".mp4")
        # Temporary file list for ffmpeg's concat demuxer.
        list_file = _unique_path("list", ".txt")
        with open(list_file, "w", encoding="utf-8") as f:
            f.write(_concat_list_entry(vid1))
            f.write(_concat_list_entry(vid2))

        # -safe 0 permits absolute paths in the list; -c copy skips
        # re-encoding (assumes both clips share codec parameters — TODO
        # confirm with the backend's output settings).
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or b"").decode("utf-8", errors="replace")
        print("concat_videos error:", e, stderr[-2000:])
    except Exception as e:
        print("concat_videos error:", e)
    else:
        return out_path
    finally:
        if list_file is not None:
            _remove_quietly(list_file)

    # Failure path: don't leave an empty or partial output file behind.
    if out_path is not None:
        _remove_quietly(out_path)
    return None


# -------- Gradio callbacks --------
def _stitch_pair(start_img, end_img, prompt, seed, label: str) -> Optional[str]:
    """Shared body of the pairwise callbacks; surfaces failures as UI warnings."""
    if start_img is None or end_img is None:
        gr.Warning(f"Upload both images for {label} before stitching.")
        return None
    path = stitch_call(start_img, end_img, prompt or "", int(seed or 0))
    if path is None:
        gr.Warning(f"Video generation for {label} failed; check the server logs.")
    return path


def stitch_12(prompt12, seed, img1, img2):
    """Button callback: generate the 1→2 clip."""
    return _stitch_pair(img1, img2, prompt12, seed, "1→2")


def stitch_23(prompt23, seed, img2, img3):
    """Button callback: generate the 2→3 clip."""
    return _stitch_pair(img2, img3, prompt23, seed, "2→3")


def stitch_all(vid12, vid23):
    """Button callback: concatenate the 1→2 and 2→3 clips into one video."""
    if vid12 is None or vid23 is None:
        gr.Warning("Generate both videos first before stitching all.")
        return None
    return concat_videos(vid12, vid23)


# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
"""

with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images → Generate 1→2, 2→3, then combine.")

    with gr.Row():  # Top: images
        with gr.Column():
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")

    with gr.Row():  # Prompts + buttons
        with gr.Column():
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)")
            btn12 = gr.Button("Stitch 1&2")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)")
            btn23 = gr.Button("Stitch 2&3")
            btn_all = gr.Button("Stitch All (combine 1→2 and 2→3)")
        with gr.Column():
            vid12 = gr.Video(label="Video (1→2)")
            vid23 = gr.Video(label="Video (2→3)")
            vid_all = gr.Video(label="Final Combined Video")

    # Wire
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])

if __name__ == "__main__":
    demo.queue().launch()