# Spaces: Running
import base64
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import gradio as gr
import requests
from PIL import Image
# -------- Modal inference endpoint (dev) --------
# The dev deployment stays the default. Setting the MOONMATH_INFERENCE_URL
# environment variable to a non-empty value points the app at a different
# deployment without editing the source.
INFERENCE_URL = os.environ.get("MOONMATH_INFERENCE_URL") or (
    "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
| # -------- small helpers -------- | |
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write *data* to a new ``.mp4`` file under ``/tmp`` and return its path.

    The file name starts with ``{tag}_{unix_seconds}_`` followed by a unique
    suffix chosen by ``tempfile.mkstemp``. The timestamp alone only has
    one-second resolution, so without the suffix two saves with the same tag
    in the same second would overwrite each other.

    Args:
        data: Raw bytes to write.
        tag: Prefix used for the generated file name.

    Returns:
        Absolute path of the file that was written.
    """
    os.makedirs("/tmp", exist_ok=True)
    fd, path = tempfile.mkstemp(
        prefix=f"{tag}_{int(time.time())}_",
        suffix=".mp4",
        dir="/tmp",
    )
    # mkstemp hands back an already-open descriptor; wrap it so it is closed
    # even if the write fails.
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG in memory and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        encoded = stream.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """GET *url* and return the response body as bytes.

    The request is issued with ``timeout=180``. ``raise_for_status()`` is
    called on the response, so an HTTP error status raises instead of
    returning the error body.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    payload = response.content
    return payload
def stitch_call(
    start_img: Optional["Image.Image"],
    end_img: Optional["Image.Image"],
    prompt: str,
    seed: Optional[int],
) -> Optional[str]:
    """Request a video between two images from the inference endpoint.

    Both images are sent as PNG multipart fields (``image_bytes`` and
    ``image_bytes_end``) in a POST to ``INFERENCE_URL``; the prompt and seed
    travel in the query string. Three response shapes are accepted:

    * any non-JSON content type: the body is saved as the video;
    * JSON with an ``http...`` string under ``video_url``, ``url`` or
      ``result``: that URL is downloaded and saved;
    * JSON with a base64 string under ``video_b64``: decoded and saved.

    Args:
        start_img: First frame, or None.
        end_img: Last frame, or None.
        prompt: Text prompt; None or empty is sent as an empty string.
        seed: Seed to send. None, 0 and -1 are replaced by a random value in
            ``[1, 2**31 - 1]``.

    Returns:
        Path of the saved ``.mp4`` file, or None when an image is missing or
        the request/decoding fails. Failures are printed, never raised.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    # Deliberately broad: this is the boundary between the backend and the UI
    # callbacks, which expect None rather than an exception.
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status before looking at the body so that an error reply
        # carrying a JSON payload is reported instead of silently yielding None.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            if not resp.content:
                raise ValueError("empty response body")
            return _save_video_bytes(resp.content, "stitch")

        # JSON with url or base64
        data = resp.json()
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(data).__name__}"
            )

        video_url = data.get("video_url") or data.get("url") or data.get("result")
        if isinstance(video_url, str) and video_url.startswith("http"):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64")
        if isinstance(video_b64, str):
            # Restore '=' padding that may have been stripped by the sender.
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")

        raise ValueError(
            f"no video found in JSON response; keys: {sorted(data)}"
        )
    except Exception as e:
        print("stitch_call error:", e)
        return None
| # -------- FFmpeg-based concatenation -------- | |
def concat_videos(vid1: str, vid2: str) -> Optional[str]:
    """Concatenate two video files with ffmpeg's concat demuxer.

    The streams are copied (``-c copy``), not re-encoded. A temporary list
    file naming both inputs is written under ``/tmp`` and removed again
    before returning.

    Args:
        vid1: Path of the first video.
        vid2: Path of the second video.

    Returns:
        Path of the combined ``.mp4`` under ``/tmp``, or None when either
        path is empty/None or ffmpeg fails. Failures are printed, never
        raised.
    """
    if not vid1 or not vid2:
        return None

    list_file = None
    out_path = None
    try:
        os.makedirs("/tmp", exist_ok=True)
        stamp = int(time.time())

        # Unique names: the timestamp alone collides for calls made within
        # the same second.
        fd, out_path = tempfile.mkstemp(
            prefix=f"final_{stamp}_", suffix=".mp4", dir="/tmp"
        )
        os.close(fd)  # ffmpeg overwrites the placeholder via -y

        # Create a temporary file list for ffmpeg
        with tempfile.NamedTemporaryFile(
            "w",
            prefix=f"list_{stamp}_",
            suffix=".txt",
            dir="/tmp",
            delete=False,
            encoding="utf-8",
        ) as f:
            list_file = f.name
            for src in (vid1, vid2):
                # The concat demuxer resolves relative paths against the list
                # file's directory, so pin them to absolute paths first. A
                # literal single quote inside a quoted path is written '\''.
                escaped = os.path.abspath(src).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")

        # Run ffmpeg concat
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return out_path
    except Exception as e:
        print("concat_videos error:", e)
        # CalledProcessError carries ffmpeg's captured stderr; its str() does
        # not include it, so print the tail explicitly.
        stderr = getattr(e, "stderr", None)
        if isinstance(stderr, bytes) and stderr:
            print(stderr.decode("utf-8", errors="replace")[-2000:])
        # Do not leave an empty or partial output file behind.
        if out_path is not None:
            try:
                os.remove(out_path)
            except OSError:
                pass
        return None
    finally:
        if list_file is not None:
            try:
                os.remove(list_file)
            except OSError:
                pass
| # -------- Gradio callbacks -------- | |
def stitch_12(prompt12, seed, img1, img2):
    """Gradio callback for the "Stitch 1&2" button.

    Calls ``stitch_call`` with images 1 and 2. An empty prompt is sent as
    ``""`` and an empty seed as ``0``.

    Returns:
        Path of the generated video, or None (after showing a warning in the
        UI) when an image is missing or generation fails.
    """
    if img1 is None or img2 is None:
        gr.Warning("Upload Image 1 and Image 2 before stitching 1&2.")
        return None
    path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if path is None:
        # stitch_call only prints its errors; surface the failure to the user.
        gr.Warning("Stitching 1&2 failed; see the server log for details.")
    return path
def stitch_23(prompt23, seed, img2, img3):
    """Gradio callback for the "Stitch 2&3" button.

    Calls ``stitch_call`` with images 2 and 3. An empty prompt is sent as
    ``""`` and an empty seed as ``0``.

    Returns:
        Path of the generated video, or None (after showing a warning in the
        UI) when an image is missing or generation fails.
    """
    if img2 is None or img3 is None:
        gr.Warning("Upload Image 2 and Image 3 before stitching 2&3.")
        return None
    path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
    if path is None:
        # stitch_call only prints its errors; surface the failure to the user.
        gr.Warning("Stitching 2&3 failed; see the server log for details.")
    return path
def stitch_all(vid12, vid23):
    """Gradio callback that joins the two generated clips into one video.

    Returns the result of ``concat_videos`` when both clips are present;
    otherwise shows a warning in the UI and returns None.
    """
    both_ready = vid12 is not None and vid23 is not None
    if both_ready:
        return concat_videos(vid12, vid23)
    gr.Warning("Generate both videos first before stitching all.")
    return None
# -------- UI --------
# NOTE(review): the user-facing labels below previously showed a stray "β"
# wherever a dash or arrow was evidently intended (it looks like mis-decoded
# UTF-8). They are restored here as "—" and "→"; confirm the intended glyphs.
# NOTE(review): the nesting of the layout is reconstructed from the order of
# the `with` statements; confirm it matches the intended arrangement.
CSS = """
.gradio-container { padding: 24px; }
"""

with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images → Generate 1→2, 2→3, then combine.")

    with gr.Row():
        # Top: images
        with gr.Column():
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")

    with gr.Row():
        # Prompts + buttons
        with gr.Column():
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)")
            btn12 = gr.Button("Stitch 1&2")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)")
            btn23 = gr.Button("Stitch 2&3")
            btn_all = gr.Button("Stitch All (combine 1→2 and 2→3)")
        # Generated clips and the combined result
        with gr.Column():
            vid12 = gr.Video(label="Video (1→2)")
            vid23 = gr.Video(label="Video (2→3)")
            vid_all = gr.Video(label="Final Combined Video")

    # Wire: each button feeds its inputs to the matching callback and shows
    # the returned file path in the corresponding video player.
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])

if __name__ == "__main__":
    demo.queue().launch()