"""StitchTool — Gradio app that generates transition clips between uploaded
keyframe images via a remote inference endpoint, collects clips on a timeline,
and concatenates them into one final video with FFmpeg."""

import base64
import io
import os
import random
import subprocess
import time
from typing import List, Optional
from urllib.parse import urlencode

import requests
from PIL import Image
import gradio as gr

# -------- Modal inference endpoint (dev) --------
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# -------- settings --------
MAX_SLOTS = 12  # max image slots user can reveal


# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write raw video bytes to a unique /tmp path and return that path."""
    os.makedirs("/tmp", exist_ok=True)
    # time_ns() (not int(time.time())) so two clips saved within the same
    # second cannot clobber each other.
    path = f"/tmp/{tag}_{time.time_ns()}.mp4"
    with open(path, "wb") as f:
        f.write(data)
    return path


def _png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image as PNG and return the raw bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* and return the response body; raises on HTTP errors."""
    r = requests.get(url, timeout=180)
    r.raise_for_status()
    return r.content


def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
    negative_prompt: Optional[str] = None,
    frames_per_second: int = 24,
    video_length: int = 4,
    num_inference_steps: Optional[int] = None,
) -> Optional[str]:
    """
    Call the inference endpoint and return a local path to the generated clip.

    Required (in body): image_bytes (+ image_bytes_end)
    In URL query: prompt, negative_prompt, frames_per_second, video_length,
                  seed, num_inference_steps

    Returns None on any failure (best-effort: errors are printed, not raised).
    """
    if start_img is None or end_img is None:
        return None

    # default seed behavior: 0 / -1 / None mean "pick a random seed"
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Build query string
    q = {
        "prompt": prompt or "",
        "seed": int(seed),
        "frames_per_second": int(frames_per_second),
        "video_length": int(video_length),
    }
    if negative_prompt:
        q["negative_prompt"] = negative_prompt
    if num_inference_steps is not None:
        q["num_inference_steps"] = int(num_inference_steps)
    url = f"{INFERENCE_URL}?{urlencode(q)}"

    # Images go in the body as a multipart upload
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()

        # Backend may answer with raw video bytes…
        if "application/json" not in ctype:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")

        # …or JSON carrying a URL or a base64 payload (key names vary).
        data = resp.json()
        video_url = (
            data.get("video_url")
            or data.get("url")
            or data.get("result")
            or data.get("output")
        )
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            # Repair missing base64 padding before decoding.
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")
    except Exception as e:
        # Deliberate best-effort: UI surfaces a warning, so just log here.
        print("stitch_call error:", e)
    return None


# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    """Concatenate 2+ clip files (stream copy, no re-encode) into one mp4.

    Returns the output path, or None if fewer than two clips or ffmpeg fails.
    """
    vids = [v for v in videos if v]
    if len(vids) < 2:
        return None
    try:
        os.makedirs("/tmp", exist_ok=True)
        stamp = time.time_ns()  # unique names even for back-to-back calls
        out_path = f"/tmp/final_{stamp}.mp4"
        list_file = f"/tmp/list_{stamp}.txt"
        # ffmpeg concat demuxer wants a "file '<path>'" manifest.
        with open(list_file, "w") as f:
            for v in vids:
                f.write(f"file '{v}'\n")
        try:
            subprocess.run(
                ["ffmpeg", "-y", "-f", "concat", "-safe", "0",
                 "-i", list_file, "-c", "copy", out_path],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
        finally:
            # Don't leave manifest files piling up in /tmp.
            if os.path.exists(list_file):
                os.remove(list_file)
        return out_path
    except Exception as e:
        print("concat_many error:", e)
        return None


# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]) -> str:
    """Render the timeline clip list as an HTML grid.

    NOTE(review): the original markup was garbled in this copy of the file;
    the tags below are reconstructed from the .tl-grid / .tl-label / .tl-empty
    CSS rules defined in this module. The `/file=` prefix is Gradio's local
    file route — confirm it matches the Gradio version in use.
    """
    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = []
    for i, p in enumerate(vids, 1):
        items.append(
            f"<div>"
            f"<video src='/file={p}' controls muted></video>"
            f"<div class='tl-label'>Clip {i}</div>"
            f"</div>"
        )
    return f"<div class='tl-grid'>{''.join(items)}</div>"


# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int) -> int:
    """Reveal one more upload slot (up to MAX_SLOTS)."""
    return min(MAX_SLOTS, int(visible_slots) + 1)


def _reveal_slots(n, *imgs):
    """Update visibility of image upload components based on visible_slots state.

    *imgs* is unused but required: the event wiring passes every image
    component so output ordering matches the component list.
    """
    n = int(n)
    return [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]


def collect_choices(*imgs):
    """Build dropdown choices of available indices (1-based labels) based on
    non-empty slots; returned twice, once per dropdown (start/end)."""
    choices = [str(i) for i, img in enumerate(imgs, start=1) if img is not None]
    return gr.update(choices=choices), gr.update(choices=choices)


def stitch_selected(
    prompt, negative_prompt, fps, length_sec, seed,
    start_idx_str, end_idx_str, *imgs
):
    """Run inference for selected start/end indices (1-based strings) + options.

    Returns the generated clip's local path for the preview player, or None
    (after a gr.Warning) on any validation/generation failure.
    """
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None
    try:
        s = int(start_idx_str) - 1
        e = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None
    if s < 0 or e < 0 or s >= len(imgs) or e >= len(imgs):
        gr.Warning("Start/End out of range.")
        return None
    start_img = imgs[s]
    end_img = imgs[e]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None

    # Dropdowns deliver strings; fall back to defaults when unset.
    fps_val = int(str(fps)) if fps else 24
    len_val = int(str(length_sec)) if length_sec else 4

    vid = stitch_call(
        start_img=start_img,
        end_img=end_img,
        prompt=prompt or "",
        seed=int(seed or 0),
        negative_prompt=(negative_prompt or "").strip() or None,
        frames_per_second=fps_val,
        video_length=len_val,
        num_inference_steps=None,
    )
    if not vid:
        gr.Warning("Generation failed.")
        return None
    return vid  # path for preview


def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append preview to timeline; return updated state and HTML."""
    tl = list(timeline_paths or [])
    if not preview_path:
        gr.Warning("Generate a clip first.")
        return tl, gr.update(value=render_timeline_html(tl))
    tl.append(preview_path)
    return tl, gr.update(value=render_timeline_html(tl))


def stitch_all_from_timeline(timeline_paths: List[str]):
    """Concatenate every timeline clip in order; warn and return None on failure."""
    vids = list(timeline_paths or [])
    if len(vids) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    out = concat_many(vids)
    if not out:
        gr.Warning("Failed to concatenate clips.")
    return out


# =========================
# UI
# =========================
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.tl-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(180px, 1fr)); gap: 12px; }
.stitch-box { background-color: #f0f4ff; /* pick any color you like */ border-radius: 12px; padding: 16px; }
.tl-grid video { width: 100%; height: 120px; object-fit: cover; border-radius: 12px; display: block; }
.tl-label { font-size: 12px; color: #9aa0a6; margin-top: 4px; text-align: center; }
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""

with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")

    # --- State ---
    visible_slots = gr.State(value=3)   # number of visible image slots
    timeline_state = gr.State(value=[]) # list[str] of video file paths (timeline)

    # --- Image gallery (horizontal, grows on demand) ---
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)

    add_btn = gr.Button("+ Add image")

    # clicking add → reveal one more slot
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )
    # reflect visibility changes whenever visible_slots changes
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps,
    )

    # Seed + Start/End selection + Prompt + options + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3, label="Prompt", elem_classes=["rounded"],
            )
            negative = gr.Textbox(
                placeholder="Optional: things to avoid (e.g., 'bad quality, extra fingers, etc.')",
                lines=2, label="Negative prompt", elem_classes=["rounded"],
            )
            with gr.Row():
                fps = gr.Dropdown(
                    label="Frame rate",
                    choices=["16", "24", "32"],
                    value="24",
                    interactive=True,
                )
                length_sec = gr.Dropdown(
                    label="Video length (sec)",
                    choices=["2", "4"],
                    value="4",
                    interactive=True,
                )
            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])

        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # keep start/end dropdowns up to date based on which slots have images
    for comp in img_comps:
        comp.change(fn=collect_choices, inputs=img_comps, outputs=[start_dd, end_dd])

    # stitch action → preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
        outputs=[preview],
    )

    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))

    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html],
    )

    # final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)

    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid],
    )


if __name__ == "__main__":
    demo.queue().launch()