Spaces:
Sleeping
Sleeping
| import os, io, time, base64, random, subprocess | |
| from typing import Optional | |
| from urllib.parse import quote | |
| import requests | |
| from PIL import Image | |
| import gradio as gr | |
# -------- Modal inference endpoint --------
# Base URL of the image-to-video inference service. Set the INFERENCE_URL
# environment variable to point the app at another deployment; an unset or
# empty variable falls back to the default Modal endpoint below.
INFERENCE_URL = (
    os.environ.get("INFERENCE_URL")
    or "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
# -------- Helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write encoded video bytes to a new, uniquely named ``.mp4`` file in /tmp.

    Args:
        data: Complete contents of the video file.
        tag: Filename prefix identifying the producer (e.g. ``"stitch"``).

    Returns:
        Path of the written file. Every call returns a distinct path.
    """
    import tempfile

    os.makedirs("/tmp", exist_ok=True)
    # mkstemp atomically creates a file with a unique name. The previous
    # f"/tmp/{tag}_{int(time.time())}.mp4" scheme let two calls within the same
    # second (concurrent users, or the 1&2 and 2&3 stitches, which share a tag)
    # overwrite each other's video.
    fd, path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".mp4", dir="/tmp")
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG in memory and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        encoded = stream.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """GET *url* (180 s timeout) and return the full response body.

    Raises the ``requests`` HTTP error on a 4xx/5xx status.
    """
    with requests.get(url, timeout=180) as response:
        response.raise_for_status()
        body = response.content
    return body
def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Ask the inference endpoint for a video that transitions between two frames.

    Both images are PNG-encoded and POSTed as the multipart fields
    ``image_bytes`` (start) and ``image_bytes_end`` (end). The prompt and seed
    are sent in the query string. Three response shapes are accepted:
      * a non-JSON body, which is taken as the raw video bytes;
      * JSON with an http(s) URL under ``video_url``/``url``/``result``/``output``,
        which is then downloaded;
      * JSON with base64 video under ``video_b64``/``videoBase64``.

    Args:
        start_img: First frame of the transition.
        end_img: Last frame of the transition.
        prompt: Text prompt; ``None``/empty is sent as an empty string.
        seed: Sampling seed; ``None``, ``0`` or ``-1`` select a random seed.

    Returns:
        Local path of the saved ``.mp4``. Returns ``None`` if either image is
        missing, the request fails, or the response holds no usable video. The
        reason is printed to stdout.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every response. Previously only non-JSON
        # responses were checked, so an error response with a JSON body was
        # parsed as if it had succeeded, and the HTTP error was never reported.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            return _save_video_bytes(resp.content, "stitch")

        # JSON with a download URL or base64-encoded video
        data = resp.json()
        if not isinstance(data, dict):
            print("Stitch call failed: unexpected JSON payload type:", type(data).__name__)
            return None

        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            # Restore any '=' padding the server stripped, so the length is a multiple of 4.
            video_b64 += "=" * (-len(video_b64) % 4)
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")

        # Previously this case fell through and returned None with no trace.
        print("Stitch call failed: no video in JSON response; keys:", sorted(data))
    except Exception as e:
        # UI boundary: log and report failure as None; callers show a warning.
        print("Stitch call failed:", e)
    return None
# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Click handler for "Stitch 1&2": build the video from Image 1 to Image 2.

    Returns the saved video path for the output component. Returns ``None``
    after showing a Gradio warning if an image is missing or generation fails.
    """
    if any(image is None for image in (img1, img2)):
        gr.Warning("Please upload Image 1 and Image 2.")
        return None
    result = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if result is not None:
        return result
    gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
    return None
def stitch_23(prompt23, seed, img2, img3):
    """Click handler for "Stitch 2&3": build the video from Image 2 to Image 3.

    Returns the saved video path for the output component. Returns ``None``
    after showing a Gradio warning if an image is missing or generation fails.
    """
    both_present = img2 is not None and img3 is not None
    if not both_present:
        gr.Warning("Please upload Image 2 and Image 3.")
        return None
    text = prompt23 or ""
    video_path = stitch_call(img2, img3, text, int(seed or 0))
    if video_path is None:
        gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
    return video_path
def _concat_file_line(path) -> str:
    """Return one ffmpeg concat-demuxer ``file`` directive for *path*.

    The path is single-quoted. An embedded single quote cannot appear inside
    quotes, so it is written as ``'\\''`` (close quote, escaped quote, reopen).
    """
    return "file '" + str(path).replace("'", "'\\''") + "'\n"


def _remove_quietly(path) -> None:
    """Best-effort delete of *path*; does nothing if *path* is falsy or deletion fails."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def stitch_all(video12, video23):
    """Concatenate the 1→2 and 2→3 videos into one clip with ffmpeg.

    Uses the concat demuxer with stream copy (``-c copy``), so nothing is
    re-encoded. This assumes both clips have matching stream parameters —
    TODO confirm the backend always produces them.

    Args:
        video12: Path of the video stitched from Image 1 to Image 2.
        video23: Path of the video stitched from Image 2 to Image 3.

    Returns:
        Path of the concatenated ``.mp4``. Returns ``None`` after showing a
        Gradio warning if an input is missing or ffmpeg fails.
    """
    if not video12 or not video23:
        gr.Warning("Please generate both stitched videos first.")
        return None

    import tempfile

    out_path = None
    list_path = None
    try:
        os.makedirs("/tmp", exist_ok=True)
        # Unique names: int(time.time()) collided between concurrent requests
        # and between repeated clicks within the same second.
        fd, out_path = tempfile.mkstemp(prefix="stitch_all_", suffix=".mp4", dir="/tmp")
        os.close(fd)  # ffmpeg (-y) overwrites this placeholder file
        fd, list_path = tempfile.mkstemp(prefix="concat_", suffix=".txt", dir="/tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(_concat_file_line(video12))
            f.write(_concat_file_line(video23))
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", out_path]
        subprocess.run(cmd, check=True)
        return out_path
    except Exception as e:
        print("Stitch all failed:", e)
        gr.Warning("Failed to stitch all videos together.")
        _remove_quietly(out_path)  # don't leave an empty or partial output behind
        return None
    finally:
        # The concat list is only needed while ffmpeg runs; previously it leaked.
        _remove_quietly(list_path)
# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
"""

# Layout: three image uploads on the left. On the right: a shared seed, a
# prompt/button/video trio for each adjacent image pair, and the final
# concatenation. The title, markdown and labels previously contained
# mis-decoded characters ("β"); they are restored to "—" and "→".
with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches, concat") as demo:
    gr.Markdown("## Stitch — Upload 3 images, generate videos between 1→2 and 2→3, then merge them.")
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")
        with gr.Column(scale=1, min_width=320):
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)", elem_classes=["rounded"])
            btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
            vid12 = gr.Video(label="Video (image 1+2) output")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)", elem_classes=["rounded"])
            btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
            vid23 = gr.Video(label="Video (image 2+3) output")
            btn_all = gr.Button("Stitch All", elem_classes=["pill"])
            vid_all = gr.Video(label="Final concatenated video")
    # Wire buttons. This stays inside the Blocks context, where listeners are registered.
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the server.
    app = demo.queue()
    app.launch()