"""Gradio front-end for an image->video (I2V) backend.

Upload up to eight images, generate a short clip per image via the backend,
and stitch adjacent pairs together with an optional crossfade (MoviePy).
"""
import base64
import io
import os
import random
import time
from typing import Optional, Tuple
from urllib.parse import quote_plus

import requests
import gradio as gr

# -----------------------------
# Config
# -----------------------------
# You can set your Modal endpoint via env var MM_I2V_URL
DEFAULT_API_URL = os.getenv(
    "MM_I2V_URL",
    "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run",
)

SAVE_DIR = "outputs"
os.makedirs(SAVE_DIR, exist_ok=True)


# -----------------------------
# Helpers
# -----------------------------
def _save_bytes_to_mp4(buf: bytes, name_prefix: str) -> str:
    """Write raw video bytes to a timestamped .mp4 under SAVE_DIR; return its path."""
    ts = int(time.time() * 1000)
    path = os.path.join(SAVE_DIR, f"{name_prefix}-{ts}.mp4")
    with open(path, "wb") as f:
        f.write(buf)
    return path


def _download(url: str) -> bytes:
    """GET *url* and return the body; raises requests.HTTPError on bad status."""
    r = requests.get(url, timeout=600)
    r.raise_for_status()
    return r.content


def call_i2v(
    image_path: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Call the image->video backend and return (video_path, error_message).

    Exactly one element of the tuple is non-None. Tries to handle several
    common response types:
      1) raw mp4 bytes
      2) JSON with {"video": "<base64>"} (mp4 base64)
      3) JSON with {"video_url": "https://..."} (or "result_url" / "url")
    """
    api = (api_url or DEFAULT_API_URL).strip().rstrip("/")
    # A blank/None seed means "pick one at random" so repeated runs differ.
    used_seed = (
        seed
        if (seed is not None and str(seed).strip() != "")
        else random.randint(0, 2**31 - 1)
    )
    url = f"{api}?prompt={quote_plus(prompt)}&seed={used_seed}"
    headers = {"accept": "application/json"}

    try:
        # FIX: open the upload inside a context manager so the file handle is
        # always closed (the original left it open for the process lifetime).
        with open(image_path, "rb") as img:
            files = {
                "image_bytes": (
                    os.path.basename(image_path),
                    img,
                    "application/octet-stream",
                )
            }
            resp = requests.post(url, headers=headers, files=files, timeout=1200)

        # Try to accommodate various backends.
        ctype = resp.headers.get("Content-Type", "")
        if "application/json" in ctype:
            data = resp.json()
            # base64 payload (length check filters out trivially-short junk)
            if "video" in data and isinstance(data["video"], str) and len(data["video"]) > 50:
                try:
                    raw = base64.b64decode(data["video"], validate=True)
                    return _save_bytes_to_mp4(raw, "clip"), None
                except Exception as e:
                    return None, f"Could not decode base64 video: {e}"
            # url payload
            for key in ("video_url", "result_url", "url"):
                if key in data and isinstance(data[key], str) and data[key].startswith("http"):
                    raw = _download(data[key])
                    return _save_bytes_to_mp4(raw, "clip"), None
            return None, 'JSON response did not include "video" (base64) or a known url key.'
        # Raw bytes (ideally mp4)
        elif "video" in ctype or "octet-stream" in ctype:
            return _save_bytes_to_mp4(resp.content, "clip"), None
        else:
            # Some backends still reply bytes with missing/odd content-type
            if resp.content and len(resp.content) > 1024:
                return _save_bytes_to_mp4(resp.content, "clip"), None
            return None, f"Unexpected content type: {ctype}"
    except requests.RequestException as e:
        return None, f"Request failed: {e}"


def stitch_pair(
    image_a: str,
    image_b: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str],
    crossfade: float,
) -> Tuple[Optional[str], str]:
    """Generate a clip for each image, then join them with an optional crossfade.

    Strategy:
      - Generate a short clip from image A
      - Generate a short clip from image B (same prompt/seed unless user changes)
      - Concatenate with a short crossfade in Python (moviepy)

    Returns (stitched_video_path, error_message); on success error_message is "".
    If you already have a backend endpoint that does stitching directly,
    replace this function body with a single backend call.
    """
    if not image_a or not image_b:
        return None, "Please upload both images."

    # First generate both clips
    clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
    if err1:
        return None, f"Clip 1 failed: {err1}"
    clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
    if err2:
        return None, f"Clip 2 failed: {err2}"

    # MoviePy is imported lazily so the app still starts without it installed.
    try:
        from moviepy.editor import VideoFileClip, concatenate_videoclips
    except Exception as e:
        # FIX: the original f-string was broken across a physical line break.
        return None, (
            "MoviePy import failed. Add moviepy & imageio-ffmpeg to "
            f"requirements.txt.\nError: {e}"
        )

    try:
        c1 = VideoFileClip(clip1_path)
        c2 = VideoFileClip(clip2_path)
        # method="compose" handles size/fps mismatches between the two clips.
        if crossfade and crossfade > 0:
            # Overlap the clips by `crossfade` seconds (fade out / fade in).
            c2 = c2.crossfadein(crossfade)
            c1 = c1.crossfadeout(crossfade)
            merged = concatenate_videoclips([c1, c2], method="compose", padding=-crossfade)
        else:
            merged = concatenate_videoclips([c1, c2], method="compose")
        out_path = os.path.join(SAVE_DIR, f"stitched-{int(time.time()*1000)}.mp4")
        merged.write_videofile(out_path, codec="libx264", audio_codec="aac", verbose=False, logger=None)
        c1.close(); c2.close(); merged.close()
        return out_path, ""
    except Exception as e:
        return None, f"Stitching failed: {e}"


# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="Image Stitch to Video", css="""
/* Rounded tiles like the sketch */
.rounded { border-radius: 24px; }
.tile { background: #f7f7ff; padding: 12px; }
.tile-blue { background: #e8f0ff; }
.tile-yellow { background: #fff7d6; }
.small-btn button { padding: 6px 10px; border-radius: 999px; }
.label-center label { text-align:center; width: 100%; }
""") as demo:
    gr.Markdown("### Image → Video (Stitch Adjacent Pairs)\nUpload 3 images, enter prompts for each stitch, then click the stitch buttons.")

    with gr.Row(equal_height=True):
        # Left column: images + add image
        with gr.Column(scale=1):
            gr.Markdown("**Images**")
            img1 = gr.Image(type="filepath", label="Image 1", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img2 = gr.Image(type="filepath", label="Image 2", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img3 = gr.Image(type="filepath", label="Image 3", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            # Optional extra slots (hidden until added)
            extra_imgs = []
            for i in range(4, 9):
                comp = gr.Image(type="filepath", label=f"Image {i}", height=220, visible=False,
                                elem_classes=["rounded", "tile", "tile-blue"])
                extra_imgs.append(comp)
            add_btn = gr.Button("Add Image", variant="secondary")
            # FIX: per-session counter of revealed extra slots. The original
            # mutated `comp.visible` on the component objects, which is shared
            # across all sessions and not tracked by Gradio at runtime.
            extra_count = gr.State(0)

        # Middle column: prompts + stitch buttons
        with gr.Column(scale=1):
            gr.Markdown("**Prompts**")
            prompt12 = gr.Textbox(label="Prompt for Stitch 1 & 2", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed12 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch12 = gr.Button("Stitch 1 & 2", elem_classes=["small-btn"])

            prompt23 = gr.Textbox(label="Prompt for Stitch 2 & 3", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed23 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch23 = gr.Button("Stitch 2 & 3", elem_classes=["small-btn"])

            with gr.Accordion("Advanced (API & Stitch)", open=False):
                api_url = gr.Textbox(label="Backend API URL", value=DEFAULT_API_URL)
                crossfade = gr.Slider(0.0, 1.5, value=0.4, step=0.1, label="Crossfade seconds")

            clear_btn = gr.Button("Clear All")

        # Right column: video outputs
        with gr.Column(scale=1):
            gr.Markdown("**Outputs**")
            vid12 = gr.Video(label="Video (image 1 + 2) output", elem_classes=["rounded", "tile", "tile-yellow"])
            vid23 = gr.Video(label="Video (image 2 + 3) output", elem_classes=["rounded", "tile", "tile-yellow"])
            # FIX: hidden sinks for stitch_pair's error string, created once
            # (originally new hidden Textboxes were built inside each .click).
            err12 = gr.Textbox(visible=False)
            err23 = gr.Textbox(visible=False)

    # Wire up actions
    def _on_add(n_visible):
        # Reveal the next hidden uploader; cap at the number of extra slots.
        n_visible = min(n_visible + 1, len(extra_imgs))
        updates = [gr.update(visible=(i < n_visible)) for i in range(len(extra_imgs))]
        return updates + [n_visible]

    add_btn.click(
        _on_add,
        inputs=[extra_count],
        outputs=[*extra_imgs, extra_count],
    )

    stitch12.click(
        stitch_pair,
        inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
        outputs=[vid12, err12],
    )
    stitch23.click(
        stitch_pair,
        inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
        outputs=[vid23, err23],
    )

    def _on_clear():
        # Base images stay visible but empty; extra slots are hidden again.
        updates = [gr.update(value=None, visible=True) for _ in (img1, img2, img3)]
        updates += [gr.update(value=None, visible=False) for _ in extra_imgs]
        # videos, prompts, api url, crossfade, and the add-image counter
        return updates + [None, None, "", "", gr.update(value=DEFAULT_API_URL), 0.4, 0]

    clear_btn.click(
        _on_clear,
        inputs=None,
        outputs=[img1, img2, img3, *extra_imgs, vid12, vid23,
                 prompt12, prompt23, api_url, crossfade, extra_count],
    )

if __name__ == "__main__":
    demo.launch()