import os, io, time, base64, random, subprocess
from typing import Optional, List
from urllib.parse import urlencode

import requests
from PIL import Image
import gradio as gr

# -------- Modal inference endpoint (dev) --------
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# -------- settings --------
MAX_SLOTS = 12  # max image slots user can reveal

# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    os.makedirs("/tmp", exist_ok=True)
    path = f"/tmp/{tag}_{int(time.time())}.mp4"
    with open(path, "wb") as f:
        f.write(data)
    return path


def _png_bytes(img: Image.Image) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _download_to_bytes(url: str) -> bytes:
    r = requests.get(url, timeout=180)
    r.raise_for_status()
    return r.content

def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
    negative_prompt: Optional[str] = None,
    frames_per_second: int = 24,
    video_length: int = 4,
    num_inference_steps: Optional[int] = None,
) -> Optional[str]:
    """
    Required (in body): image_bytes (+ image_bytes_end)
    In URL query: prompt, negative_prompt, frames_per_second, video_length, seed, num_inference_steps
    """
    if start_img is None or end_img is None:
        return None

    # default seed behavior
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Build query string
    q = {
        "prompt": prompt or "",
        "seed": int(seed),
        "frames_per_second": int(frames_per_second),
        "video_length": int(video_length),
    }
    if negative_prompt:
        q["negative_prompt"] = negative_prompt
    if num_inference_steps is not None:
        q["num_inference_steps"] = int(num_inference_steps)
    url = f"{INFERENCE_URL}?{urlencode(q)}"

    # Images go in the body
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")

        # JSON with url or base64
        data = resp.json()
        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")
    except Exception as e:
        print("stitch_call error:", e)
    return None
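
# Illustrative only (not part of the app flow): stitch_call can be exercised
# directly with two local PNGs; "start.png" / "end.png" are placeholder filenames.
#
#   clip_path = stitch_call(
#       start_img=Image.open("start.png"),
#       end_img=Image.open("end.png"),
#       prompt="slow dissolve from the first frame to the second",
#       seed=0,  # 0 -> a random seed is chosen
#   )
#   print(clip_path)  # e.g. /tmp/stitch_<timestamp>.mp4, or None on failure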

# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    vids = [v for v in videos if v]
    if len(vids) < 2:
        return None
    try:
        os.makedirs("/tmp", exist_ok=True)
        out_path = f"/tmp/final_{int(time.time())}.mp4"
        list_file = f"/tmp/list_{int(time.time())}.txt"
        with open(list_file, "w") as f:
            for v in vids:
                f.write(f"file '{v}'\n")
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        return out_path
    except Exception as e:
        print("concat_many error:", e)
        return None
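
# Note: "-c copy" stream-copies without re-encoding, so it assumes the clips have
# compatible codec parameters; clips generated with different frame rates or
# resolutions may concatenate poorly. Re-encoding is the safer fallback, e.g.:
#   ffmpeg -y -f concat -safe 0 -i list.txt -c:v libx264 -pix_fmt yuv420p out.mp4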

# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]):
    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-grid tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = []
    for i, p in enumerate(vids, 1):
        # Local files referenced from raw HTML must be served through Gradio's file
        # route ("/file=" on Gradio 3.x/4.x; "/gradio_api/file=" on 5.x) and the
        # directory must be whitelisted via launch(allowed_paths=[...]).
        items.append(
            f"""
            <div class="tl-item">
              <video src="/file={p}" controls playsinline></video>
              <div class="tl-label">Clip {i}</div>
            </div>
            """
        )
    return f"<div class='tl-grid'>{''.join(items)}</div>"

# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int):
    """Reveal one more upload slot (up to MAX_SLOTS)."""
    return min(MAX_SLOTS, int(visible_slots) + 1)


def _reveal_slots(n, *imgs):
    """Update visibility of image upload components based on visible_slots state."""
    n = int(n)
    updates = []
    for i in range(MAX_SLOTS):
        updates.append(gr.update(visible=(i < n)))
    return updates


def collect_choices(*imgs):
    """Build dropdown choices of available indices (1-based labels) based on non-empty slots."""
    choices = []
    for i, img in enumerate(imgs, start=1):
        if img is not None:
            choices.append(str(i))
    return gr.update(choices=choices), gr.update(choices=choices)


def stitch_selected(
    prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
):
    """Run inference for selected start/end indices (1-based strings) + options."""
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None
    try:
        s = int(start_idx_str) - 1
        e = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None
    if s < 0 or e < 0 or s >= len(imgs) or e >= len(imgs):
        gr.Warning("Start/End out of range.")
        return None
    start_img = imgs[s]
    end_img = imgs[e]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None
    fps_val = int(str(fps)) if fps else 24
    len_val = int(str(length_sec)) if length_sec else 4
    vid = stitch_call(
        start_img=start_img,
        end_img=end_img,
        prompt=prompt or "",
        seed=int(seed or 0),
        negative_prompt=(negative_prompt or "").strip() or None,
        frames_per_second=fps_val,
        video_length=len_val,
        num_inference_steps=None,
    )
    if not vid:
        gr.Warning("Generation failed.")
        return None
    return vid  # path for preview


def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append preview to timeline; return updated state and HTML."""
    tl = list(timeline_paths or [])
    if not preview_path:
        gr.Warning("Generate a clip first.")
        return tl, gr.update(value=render_timeline_html(tl))
    tl.append(preview_path)
    return tl, gr.update(value=render_timeline_html(tl))


def stitch_all_from_timeline(timeline_paths: List[str]):
    vids = list(timeline_paths or [])
    if len(vids) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    out = concat_many(vids)
    if not out:
        gr.Warning("Failed to concatenate clips.")
    return out

# =========================
# UI
# =========================
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }

.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }

.tl-grid {
  display: grid;
  grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
  gap: 12px;
}
.stitch-box {
  background-color: #f0f4ff; /* pick any color you like */
  border-radius: 12px;
  padding: 16px;
}
.tl-grid video {
  width: 100%;
  height: 120px;
  object-fit: cover;
  border-radius: 12px;
  display: block;
}
.tl-label {
  font-size: 12px;
  color: #9aa0a6;
  margin-top: 4px;
  text-align: center;
}
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""

with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")

    # --- State ---
    visible_slots = gr.State(value=3)    # number of visible image slots
    timeline_state = gr.State(value=[])  # list[str] of video file paths (timeline)

    # --- Image gallery (horizontal, grows on demand) ---
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)
        add_btn = gr.Button("+ Add image")

    # clicking add → reveal one more slot
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )
    # reflect visibility changes whenever visible_slots changes
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps
    )

    # Seed + Start/End selection + Prompt + options + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3,
                label="Prompt",
                elem_classes=["rounded"]
            )
            negative = gr.Textbox(
                placeholder="Optional: things to avoid (e.g., 'no cuts, no angle switch, no text overlays')",
                lines=2,
                label="Negative prompt",
                elem_classes=["rounded"]
            )
            with gr.Row():
                fps = gr.Dropdown(
                    label="Frame rate",
                    choices=["16", "24", "32"],
                    value="24",
                    interactive=True,
                )
                length_sec = gr.Dropdown(
                    label="Video length (sec)",
                    choices=["2", "4"],
                    value="4",
                    interactive=True,
                )
            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])

        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # keep start/end dropdowns up to date based on which slots have images
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd]
        )

    # stitch action → preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
        outputs=[preview]
    )

    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))

    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html]
    )

    # final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)

    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid]
    )

if __name__ == "__main__":
    # allowed_paths lets the timeline's raw <video> tags load files saved under /tmp.
    demo.queue().launch(allowed_paths=["/tmp"])