# StitchTool / app.py
# Shalmoni's picture
# Update app.py
# 92be6ab verified
# raw
# history blame
# 5.67 kB
import base64
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import gradio as gr
import requests
from PIL import Image
# -------- Modal inference endpoint --------
# Can be overridden with the INFERENCE_URL environment variable (e.g. to target a
# different deployment); an unset or empty variable falls back to this Modal endpoint.
INFERENCE_URL = os.environ.get("INFERENCE_URL") or (
    "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
# -------- Helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write *data* to a new ``.mp4`` file in ``/tmp`` and return its path.

    The file is named ``/tmp/<tag>_<unix-seconds>_<random>.mp4``. The random part
    comes from :func:`tempfile.mkstemp`, which creates the file atomically. Two saves
    in the same second therefore cannot overwrite each other, for example the 1→2
    and 2→3 stitches (both tagged "stitch") or two concurrent users. A name built
    only from the timestamp would collide in that case.

    Args:
        data: Encoded video bytes to store verbatim.
        tag: Prefix for the file name.

    Returns:
        Absolute path of the written file.
    """
    os.makedirs("/tmp", exist_ok=True)
    fd, path = tempfile.mkstemp(prefix=f"{tag}_{int(time.time())}_", suffix=".mp4", dir="/tmp")
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the encoded bytes (used as upload payloads)."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        encoded = stream.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """GET *url* (180 s timeout) and return the whole response body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body
def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Request a video that transitions from *start_img* to *end_img*.

    Both frames are uploaded as PNG multipart fields (``image_bytes`` and
    ``image_bytes_end``). The prompt and seed are sent in the query string. The
    backend may answer with raw video bytes, or with JSON that holds either an
    http(s) URL or a base64 payload (see ``_video_bytes_from_json``). The
    resulting video is written to ``/tmp`` via ``_save_video_bytes``.

    Args:
        start_img: First frame.
        end_img: Last frame.
        prompt: Text prompt. ``None`` or empty is sent as an empty string.
        seed: Generation seed. ``None``, ``0`` or ``-1`` select a random seed in
            ``[1, 2**31 - 1]``.

    Returns:
        Path of the saved video. Returns ``None`` if an image is missing or the
        request or decoding fails; the reason is printed.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every reply, not only raw-bytes ones. Previously a
        # JSON error body skipped this check and ended in an unexplained None.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()
        if "application/json" not in ctype:
            # Raw video bytes
            return _save_video_bytes(resp.content, "stitch")
        video = _video_bytes_from_json(resp.json())
        if video is None:
            print("Stitch call failed: JSON response has no video URL or base64 payload:", resp.text[:500])
            return None
        return _save_video_bytes(video, "stitch")
    except Exception as e:
        # Best-effort boundary for the UI: the Gradio callbacks turn None into a warning.
        print("Stitch call failed:", e)
        return None


def _video_bytes_from_json(data: object) -> Optional[bytes]:
    """Extract video bytes from a JSON reply of the inference backend.

    The first non-empty value among ``video_url``, ``url``, ``result`` and
    ``output`` is used if it is an http(s) URL; that URL is downloaded.
    Otherwise ``video_b64`` or ``videoBase64`` is decoded as base64. The payload
    may include a ``data:...;base64,`` prefix, embedded whitespace, or missing
    ``=`` padding.

    Returns:
        The video bytes, or ``None`` if *data* is not an object or holds neither field.
    """
    if not isinstance(data, dict):
        return None
    video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
    if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
        return _download_to_bytes(video_url)
    video_b64 = data.get("video_b64") or data.get("videoBase64")
    if not isinstance(video_b64, str):
        return None
    # A data-URI header would otherwise be decoded as payload: most of its
    # characters ("data", "video/mp4", "base64") are in the base64 alphabet.
    if video_b64.startswith("data:") and "," in video_b64:
        video_b64 = video_b64.split(",", 1)[1]
    # Remove line breaks and spaces first so the padding count below is correct.
    video_b64 = "".join(video_b64.split())
    video_b64 += "=" * (-len(video_b64) % 4)
    return base64.b64decode(video_b64)
# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Handle the "Stitch 1&2" button by generating the Image 1 → Image 2 video.

    Shows a Gradio warning and returns ``None`` when either image is missing or
    the backend call fails. Otherwise returns the saved video path.
    """
    missing_input = img1 is None or img2 is None
    if missing_input:
        gr.Warning("Please upload Image 1 and Image 2.")
        return None
    result = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if result is None:
        gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
    return result
def stitch_23(prompt23, seed, img2, img3):
    """Handle the "Stitch 2&3" button by generating the Image 2 → Image 3 video.

    Shows a Gradio warning and returns ``None`` when either image is missing or
    the backend call fails. Otherwise returns the saved video path.
    """
    missing_input = img2 is None or img3 is None
    if missing_input:
        gr.Warning("Please upload Image 2 and Image 3.")
        return None
    result = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
    if result is None:
        gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
    return result
def stitch_all(video12, video23):
    """Join the 1→2 and 2→3 clips end to end with ffmpeg's concat demuxer.

    Streams are copied (``-c copy``), not re-encoded. Both clips are therefore
    expected to share codec parameters, which is presumably true because both
    come from the same backend (TODO confirm).

    Args:
        video12: Path of the "Stitch 1&2" output (value of the ``vid12`` component).
        video23: Path of the "Stitch 2&3" output.

    Returns:
        Path of the concatenated MP4 in ``/tmp``. Returns ``None`` after showing a
        Gradio warning when an input is missing or ffmpeg fails.
    """
    if not video12 or not video23:
        gr.Warning("Please generate both stitched videos first.")
        return None
    stamp = int(time.time())
    out_path = None
    list_path = None
    try:
        # mkstemp gives unique names, so runs started within the same second
        # (double clicks, concurrent users) cannot overwrite each other's files.
        out_fd, out_path = tempfile.mkstemp(prefix=f"stitch_all_{stamp}_", suffix=".mp4", dir="/tmp")
        os.close(out_fd)
        list_fd, list_path = tempfile.mkstemp(prefix=f"concat_{stamp}_", suffix=".txt", dir="/tmp")
        with os.fdopen(list_fd, "w") as f:
            for clip in (video12, video23):
                # Concat-script quoting: inside '...', a literal ' must be written as '\''
                escaped = f"{clip}".replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", out_path]
        # Capture output so a failure is logged together with ffmpeg's own diagnostic.
        proc = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
        if proc.returncode != 0:
            raise RuntimeError(f"ffmpeg exited with status {proc.returncode}: {proc.stderr[-2000:]}")
        return out_path
    except Exception as e:
        print("Stitch all failed:", e)
        gr.Warning("Failed to stitch all videos together.")
        # Do not leave an empty or partial output file behind.
        if out_path is not None:
            _remove_quietly(out_path)
        return None
    finally:
        # The concat list is only needed while ffmpeg runs.
        if list_path is not None:
            _remove_quietly(list_path)


def _remove_quietly(path: str) -> None:
    """Delete *path* on a best-effort basis; errors such as an already-missing file are ignored."""
    try:
        os.remove(path)
    except OSError:
        pass
# -------- UI --------
# Layout: three image uploads in the left column. The right column holds the
# shared seed, one prompt/button/video trio per stitch, and the final concat.
# NOTE: the user-facing strings previously contained mojibake ("β€”" and "β†’",
# i.e. UTF-8 "—" and "→" mis-decoded); they are restored to the intended characters.
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
"""
with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches, concat") as demo:
    gr.Markdown("## Stitch — Upload 3 images, generate videos between 1→2 and 2→3, then merge them.")
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")
        with gr.Column(scale=1, min_width=320):
            # One seed is shared by both stitches; 0 means "pick a random seed" in stitch_call.
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)", elem_classes=["rounded"])
            btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
            vid12 = gr.Video(label="Video (image 1+2) output")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)", elem_classes=["rounded"])
            btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
            vid23 = gr.Video(label="Video (image 2+3) output")
            btn_all = gr.Button("Stitch All", elem_classes=["pill"])
            vid_all = gr.Video(label="Final concatenated video")
    # Wire buttons
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    # The final concat reads the two generated videos back from their output components.
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
if __name__ == "__main__":
    demo.queue().launch()