import os, io, time, base64, random, subprocess
from typing import Optional, List
from urllib.parse import urlencode
import requests
from PIL import Image
import gradio as gr
# -------- Modal inference endpoint (dev) --------
# Single fixed URL for the Modal-hosted i2v backend; query params are
# appended by stitch_call(). Dev deployment — swap for prod before release.
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
# -------- settings --------
MAX_SLOTS = 12  # max image slots user can reveal via the "+ Add image" button
# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
os.makedirs("/tmp", exist_ok=True)
path = f"/tmp/{tag}_{int(time.time())}.mp4"
with open(path, "wb") as f:
f.write(data)
return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image to PNG and return the encoded bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* (3-minute timeout) and return the body; raises on HTTP error."""
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    return response.content
def _video_from_json_payload(data) -> Optional[str]:
    """Pull a video out of a JSON inference response.

    Accepts either a downloadable URL (under several key aliases) or an
    inline base64 payload. Returns the saved file path, or None when the
    payload contains no recognizable video.
    """
    candidate_url = (
        data.get("video_url") or data.get("url")
        or data.get("result") or data.get("output")
    )
    if isinstance(candidate_url, str) and candidate_url.startswith(("http://", "https://")):
        return _save_video_bytes(_download_to_bytes(candidate_url), "stitch")
    encoded = data.get("video_b64") or data.get("videoBase64")
    if isinstance(encoded, str):
        # Some backends strip the trailing '=' padding; restore it so
        # b64decode does not raise.
        missing = (-len(encoded)) % 4
        return _save_video_bytes(base64.b64decode(encoded + "=" * missing), "stitch")
    return None


def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
    negative_prompt: Optional[str] = None,
    frames_per_second: int = 24,
    video_length: int = 4,
    num_inference_steps: Optional[int] = None,
) -> Optional[str]:
    """Run one start→end transition on the Modal inference endpoint.

    The two frames travel in the multipart body (`image_bytes`,
    `image_bytes_end`); every other option is passed in the URL query
    string. Seed values of None/0/-1 mean "randomize".

    Returns:
        Path of the saved MP4 on success, or None on any failure
        (failures are logged to stdout, never raised — best-effort).
    """
    if start_img is None or end_img is None:
        return None
    # Default seed behavior: 0 / -1 / None → pick a fresh random seed.
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)
    query = {
        "prompt": prompt or "",
        "seed": int(seed),
        "frames_per_second": int(frames_per_second),
        "video_length": int(video_length),
    }
    if negative_prompt:
        query["negative_prompt"] = negative_prompt
    if num_inference_steps is not None:
        query["num_inference_steps"] = int(num_inference_steps)
    url = f"{INFERENCE_URL}?{urlencode(query)}"
    # Images go in the body.
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()
        if "application/json" not in ctype:
            # Non-JSON responses are treated as raw video bytes.
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")
        path = _video_from_json_payload(resp.json())
        if path is None:
            # Previously this fell through silently; log so failed
            # generations are diagnosable from the server logs.
            print("stitch_call: no video in JSON response, status", resp.status_code)
        return path
    except Exception as e:
        print("stitch_call error:", e)
        return None
# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    """Losslessly concatenate clips with ffmpeg's concat demuxer.

    Args:
        videos: ordered list of clip paths; falsy entries are skipped.

    Returns:
        Path of the combined MP4, or None when fewer than two real clips
        were supplied or ffmpeg failed (failure details go to stdout).
    """
    clips = [v for v in videos if v]
    if len(clips) < 2:
        return None
    try:
        os.makedirs("/tmp", exist_ok=True)
        stamp = int(time.time())
        out_path = f"/tmp/final_{stamp}.mp4"
        list_file = f"/tmp/list_{stamp}.txt"
        with open(list_file, "w") as f:
            for clip in clips:
                # The concat demuxer's quoted syntax requires embedded
                # single quotes to be escaped as '\'' — unescaped quotes
                # in a path would previously break the list file.
                escaped = clip.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
        return out_path
    except subprocess.CalledProcessError as e:
        # Surface ffmpeg's own diagnostics (previously discarded).
        tail = (e.stderr or b"").decode(errors="replace")[-500:]
        print("concat_many ffmpeg error:", tail)
        return None
    except Exception as e:
        print("concat_many error:", e)
        return None
# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]) -> str:
    """Render the ordered timeline clips as an HTML grid of <video> tiles.

    NOTE(review): the original markup in this copy was garbled (invalid
    string literals); rebuilt from the stylesheet classes it targets
    (.tl-grid / .tl-label / .tl-empty). Confirm the video `src` scheme
    (`/file={path}`) matches the Gradio version in use.
    """
    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = []
    for i, p in enumerate(vids, 1):
        items.append(
            "<div>"
            f"<video src='/file={p}' controls muted></video>"
            f"<div class='tl-label'>Clip {i}</div>"
            "</div>"
        )
    return f"<div class='tl-grid'>{''.join(items)}</div>"
# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int):
    """Reveal one more upload slot, capped at MAX_SLOTS."""
    bumped = int(visible_slots) + 1
    return bumped if bumped < MAX_SLOTS else MAX_SLOTS
def _reveal_slots(n, *imgs):
    """Return one visibility update per slot: indices below *n* are shown.

    The *imgs values are unused; they are wired in only so Gradio passes
    the image components through unchanged.
    """
    count = int(n)
    return [gr.update(visible=idx < count) for idx in range(MAX_SLOTS)]
def collect_choices(*imgs):
    """Build dropdown choices (1-based string labels) for the non-empty slots.

    Returns a pair of identical gr.update objects — one for the Start
    dropdown, one for the End dropdown.
    """
    filled = [str(idx) for idx, image in enumerate(imgs, start=1) if image is not None]
    return gr.update(choices=filled), gr.update(choices=filled)
def stitch_selected(
    prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
):
    """Validate the start/end selection, then run inference for that pair.

    start_idx_str / end_idx_str are 1-based slot labels from the dropdowns.
    Returns the generated clip's file path for the preview player, or None
    after surfacing a gr.Warning on any validation/generation problem.
    """
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None
    try:
        start_pos = int(start_idx_str) - 1
        end_pos = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None
    within = 0 <= start_pos < len(imgs) and 0 <= end_pos < len(imgs)
    if not within:
        gr.Warning("Start/End out of range.")
        return None
    start_img, end_img = imgs[start_pos], imgs[end_pos]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None
    clip = stitch_call(
        start_img=start_img,
        end_img=end_img,
        prompt=prompt or "",
        seed=int(seed or 0),
        negative_prompt=(negative_prompt or "").strip() or None,
        frames_per_second=int(str(fps)) if fps else 24,
        video_length=int(str(length_sec)) if length_sec else 4,
        num_inference_steps=None,
    )
    if not clip:
        gr.Warning("Generation failed.")
        return None
    return clip  # path for preview
def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append the previewed clip to the timeline.

    Returns the updated path list (new state) plus a gr.update carrying the
    re-rendered timeline HTML; warns and leaves the list unchanged when no
    preview exists yet.
    """
    clips = list(timeline_paths or [])
    if preview_path:
        clips.append(preview_path)
    else:
        gr.Warning("Generate a clip first.")
    return clips, gr.update(value=render_timeline_html(clips))
def stitch_all_from_timeline(timeline_paths: List[str]):
    """Concatenate every timeline clip in order into one final video.

    Returns the combined MP4 path, or None (with a gr.Warning) when the
    timeline holds fewer than two clips or concatenation fails.
    """
    clips = list(timeline_paths or [])
    if len(clips) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    final_path = concat_many(clips)
    if not final_path:
        gr.Warning("Failed to concatenate clips.")
    return final_path
# =========================
# UI
# =========================
# Stylesheet injected into gr.Blocks below: pill-shaped buttons, rounded
# text areas, the horizontal image gallery, and the timeline grid whose
# classes (.tl-grid / .tl-label / .tl-empty) are emitted by
# render_timeline_html().
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.tl-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
gap: 12px;
}
.stitch-box {
background-color: #f0f4ff; /* pick any color you like */
border-radius: 12px;
padding: 16px;
}
.tl-grid video {
width: 100%;
height: 120px;
object-fit: cover;
border-radius: 12px;
display: block;
}
.tl-label {
font-size: 12px;
color: #9aa0a6;
margin-top: 4px;
text-align: center;
}
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""
# NOTE(review): indentation in this copy was stripped; the nesting below is
# reconstructed from the Gradio context-manager idioms (with gr.Row()/
# gr.Column()) — confirm against the original layout.
with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")
    # --- State ---
    visible_slots = gr.State(value=3)   # number of visible image slots
    timeline_state = gr.State(value=[]) # list[str] of video file paths (timeline)
    # --- Image gallery (horizontal, grows on demand) ---
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        # All MAX_SLOTS uploaders are created up front; only the first 3 are
        # visible, the rest are revealed by _reveal_slots as the user adds.
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)
        # NOTE(review): assuming the add button sits inside the gallery row —
        # confirm placement.
        add_btn = gr.Button("+ Add image")
    # clicking add → reveal one more slot
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )
    # reflect visibility changes whenever visible_slots changes
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps
    )
    # Seed + Start/End selection + Prompt + options + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3,
                label="Prompt",
                elem_classes=["rounded"]
            )
            negative = gr.Textbox(
                placeholder="Optional: things to avoid (e.g., 'bad quality, extra fingers, etc.')",
                lines=2,
                label="Negative prompt",
                elem_classes=["rounded"]
            )
            with gr.Row():
                fps = gr.Dropdown(
                    label="Frame rate",
                    choices=["16", "24", "32"],
                    value="24",
                    interactive=True,
                )
                length_sec = gr.Dropdown(
                    label="Video length (sec)",
                    choices=["2", "4"],
                    value="4",
                    interactive=True,
                )
            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)
    # keep start/end dropdowns up to date based on which slots have images
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd]
        )
    # stitch action → preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
        outputs=[preview]
    )
    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))
    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html]
    )
    # final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)
    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid]
    )

if __name__ == "__main__":
    # queue() enables request queuing so long-running generations don't block.
    demo.queue().launch()