StitchTool / app.py
Shalmoni's picture
Update app.py
13a051d verified
raw
history blame
12.1 kB
import base64
import html
import io
import os
import random
import subprocess
import tempfile
import time
from typing import List, Optional
from urllib.parse import urlencode

import requests
from PIL import Image

import gradio as gr
# -------- Modal inference endpoint (dev) --------
# Scalar options are sent to this endpoint as query parameters; the two
# frames are POSTed as multipart file parts (see stitch_call).
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
# -------- settings --------
MAX_SLOTS = 12  # max image slots user can reveal via "+ Add image"
# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
os.makedirs("/tmp", exist_ok=True)
path = f"/tmp/{tag}_{int(time.time())}.mp4"
with open(path, "wb") as f:
f.write(data)
return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image as PNG and return the raw bytes."""
    out = io.BytesIO()
    img.save(out, format="PNG")
    return out.getvalue()
def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* (3-minute timeout) and return the body; raises on HTTP errors."""
    resp = requests.get(url, timeout=180)
    resp.raise_for_status()
    return resp.content
def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
    negative_prompt: Optional[str] = None,
    frames_per_second: int = 24,
    video_length: int = 4,
    num_inference_steps: Optional[int] = None,
) -> Optional[str]:
    """
    Ask the Modal inference endpoint for a transition clip between two frames.

    Required (in body): image_bytes (+ image_bytes_end)
    In URL query: prompt, negative_prompt, frames_per_second, video_length,
    seed, num_inference_steps.

    Returns a local path to the saved MP4, or None on any failure
    (missing images, HTTP error, unrecognized response shape).
    """
    if start_img is None or end_img is None:
        return None

    # A seed of None/0/-1 means "pick one at random".
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Scalar options travel in the query string.
    query = {
        "prompt": prompt or "",
        "seed": int(seed),
        "frames_per_second": int(frames_per_second),
        "video_length": int(video_length),
    }
    if negative_prompt:
        query["negative_prompt"] = negative_prompt
    if num_inference_steps is not None:
        query["num_inference_steps"] = int(num_inference_steps)
    endpoint = f"{INFERENCE_URL}?{urlencode(query)}"

    # The two frames are uploaded as multipart file parts.
    payload = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }

    try:
        resp = requests.post(
            endpoint, files=payload, headers={"accept": "application/json"}, timeout=600
        )
        content_type = (resp.headers.get("content-type") or "").lower()

        # A non-JSON response is taken to be the raw MP4 itself.
        if "application/json" not in content_type:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")

        # Otherwise look for a download URL or base64 payload in the JSON body.
        body = resp.json()
        link = body.get("video_url") or body.get("url") or body.get("result") or body.get("output")
        if isinstance(link, str) and link.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(link), "stitch")

        encoded = body.get("video_b64") or body.get("videoBase64")
        if isinstance(encoded, str):
            # Some backends strip base64 '=' padding; restore it before decoding.
            missing = (-len(encoded)) % 4
            if missing:
                encoded += "=" * missing
            return _save_video_bytes(base64.b64decode(encoded), "stitch")
    except Exception as exc:
        # Best-effort UI path: log and fall through to None so Gradio shows a warning.
        print("stitch_call error:", exc)
    return None
# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    """
    Losslessly concatenate clips in order using ffmpeg's concat demuxer.

    Falsy entries are dropped; at least two real clips are required.

    Returns the path of the stitched MP4, or None when there is nothing to
    concatenate or ffmpeg fails (including ffmpeg not being installed).

    Fixes over the previous version: unique temp names instead of a
    same-second time.time() stamp (collision-prone), single quotes in clip
    paths are escaped for the concat list format, and the scratch list file
    is always removed.
    """
    vids = [v for v in videos if v]
    if len(vids) < 2:
        return None
    list_file = None
    try:
        os.makedirs("/tmp", exist_ok=True)
        fd, out_path = tempfile.mkstemp(prefix="final_", suffix=".mp4", dir="/tmp")
        os.close(fd)
        fd, list_file = tempfile.mkstemp(prefix="list_", suffix=".txt", dir="/tmp")
        with os.fdopen(fd, "w") as f:
            for v in vids:
                # concat-demuxer quoting: a ' inside a quoted path becomes '\''
                escaped = v.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
        return out_path
    except Exception as e:
        print("concat_many error:", e)
        return None
    finally:
        # The list file is scratch; remove it regardless of outcome.
        if list_file and os.path.exists(list_file):
            os.remove(list_file)
# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]):
    """
    Render the timeline strip as an HTML grid of <video> cards.

    Falsy entries are skipped; an empty timeline renders a hint message.

    Fix: clip paths are now HTML-escaped so quotes or angle brackets in a
    filename cannot break out of the src attribute or inject markup.
    """
    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-grid tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = []
    for i, p in enumerate(vids, 1):
        src = html.escape(p, quote=True)
        items.append(
            f"""
        <div class="tl-item">
          <video src="{src}" controls playsinline></video>
          <div class="tl-label">Clip {i}</div>
        </div>
        """
        )
    return f"<div class='tl-grid'>{''.join(items)}</div>"
# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int):
    """Reveal one more upload slot, capped at MAX_SLOTS."""
    current = int(visible_slots)
    return current + 1 if current + 1 < MAX_SLOTS else MAX_SLOTS
def _reveal_slots(n, *imgs):
    """Show the first *n* image-upload components and hide the remainder."""
    limit = int(n)
    return [gr.update(visible=(idx < limit)) for idx in range(MAX_SLOTS)]
def collect_choices(*imgs):
    """Build 1-based dropdown choices (as strings) for every slot that holds an image.

    Returns one update per dropdown (start frame, end frame) with the same choices.
    """
    filled = [str(pos) for pos, img in enumerate(imgs, start=1) if img is not None]
    return gr.update(choices=filled), gr.update(choices=filled)
def stitch_selected(
    prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
):
    """Generate a transition clip between the selected start/end slots.

    start_idx_str / end_idx_str are 1-based slot labels from the dropdowns.
    Emits a gr.Warning and returns None on any invalid selection or failure;
    otherwise returns the preview video path.
    """
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None
    try:
        start_pos = int(start_idx_str) - 1
        end_pos = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None
    total = len(imgs)
    if not (0 <= start_pos < total) or not (0 <= end_pos < total):
        gr.Warning("Start/End out of range.")
        return None
    first, last = imgs[start_pos], imgs[end_pos]
    if first is None or last is None:
        gr.Warning("Selected slots are empty.")
        return None
    # Dropdowns deliver strings; fall back to 24 fps / 4 s when unset.
    result = stitch_call(
        start_img=first,
        end_img=last,
        prompt=prompt or "",
        seed=int(seed or 0),
        negative_prompt=(negative_prompt or "").strip() or None,
        frames_per_second=int(str(fps)) if fps else 24,
        video_length=int(str(length_sec)) if length_sec else 4,
        num_inference_steps=None,
    )
    if not result:
        gr.Warning("Generation failed.")
        return None
    return result  # path for preview
def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append the previewed clip to the timeline; return new state and rendered HTML."""
    clips = list(timeline_paths or [])
    if preview_path:
        clips.append(preview_path)
    else:
        gr.Warning("Generate a clip first.")
    return clips, gr.update(value=render_timeline_html(clips))
def stitch_all_from_timeline(timeline_paths: List[str]):
    """Concatenate every timeline clip into a single video; warn and return None on failure."""
    clips = list(timeline_paths or [])
    if len(clips) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    merged = concat_many(clips)
    if not merged:
        gr.Warning("Failed to concatenate clips.")
    return merged
# =========================
# UI
# =========================
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.tl-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
gap: 12px;
}
.stitch-box {
background-color: #f0f4ff; /* pick any color you like */
border-radius: 12px;
padding: 16px;
}
.tl-grid video {
width: 100%;
height: 120px;
object-fit: cover;
border-radius: 12px;
display: block;
}
.tl-label {
font-size: 12px;
color: #9aa0a6;
margin-top: 4px;
text-align: center;
}
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""
with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")

    # --- State ---
    visible_slots = gr.State(value=3)    # number of visible image slots
    timeline_state = gr.State(value=[])  # list[str] of video file paths (timeline)

    # --- Image gallery (horizontal, grows on demand) ---
    # All MAX_SLOTS components exist up front; only the first 3 start visible.
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)

    add_btn = gr.Button("+ Add image")
    # clicking add → reveal one more slot (bumps the counter state)
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )
    # reflect visibility changes whenever visible_slots changes
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps
    )

    # Seed + Start/End selection + Prompt + options + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3,
                label="Prompt",
                elem_classes=["rounded"]
            )
            negative = gr.Textbox(
                placeholder="Optional: things to avoid (e.g., 'no cuts, no angle switch, no text overlays')",
                lines=2,
                label="Negative prompt",
                elem_classes=["rounded"]
            )
            # Dropdown values are strings; stitch_selected converts them to ints.
            with gr.Row():
                fps = gr.Dropdown(
                    label="Frame rate",
                    choices=["16", "24", "32"],
                    value="24",
                    interactive=True,
                )
                length_sec = gr.Dropdown(
                    label="Video length (sec)",
                    choices=["2", "4"],
                    value="4",
                    interactive=True,
                )
            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # keep start/end dropdowns up to date based on which slots have images
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd]
        )

    # stitch action → preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
        outputs=[preview]
    )

    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))
    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html]
    )

    # final stitch all (concatenate timeline clips in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)
    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid]
    )

if __name__ == "__main__":
    demo.queue().launch()