# StitchTool / app.py
# Author: Shalmoni — commit c401d55 ("Update app.py", verified), 5.22 kB
import os, io, time, base64, random, subprocess
import tempfile
from typing import Optional
from urllib.parse import quote

import requests
from PIL import Image
import gradio as gr
# -------- Modal inference endpoint (dev) --------
# Defaults to the dev deployment. Set the INFERENCE_URL environment variable
# (non-empty) to point the app at another deployment without editing code.
INFERENCE_URL = (
    os.environ.get("INFERENCE_URL")
    or "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str, out_dir: str = "/tmp") -> str:
    """Write *data* to a new, uniquely named ``.mp4`` file and return its path.

    The file name starts with ``{tag}_{unix_seconds}_`` and ends with a random
    suffix chosen by ``tempfile.mkstemp``. Two saves with the same tag in the
    same second therefore no longer overwrite each other.

    Args:
        data: Encoded video bytes, written verbatim.
        tag: Short label used as the file-name prefix (e.g. ``"stitch"``).
        out_dir: Directory to write into; created if missing. Defaults to ``/tmp``.

    Returns:
        Path of the written file.
    """
    os.makedirs(out_dir, exist_ok=True)
    fd, path = tempfile.mkstemp(prefix=f"{tag}_{int(time.time())}_", suffix=".mp4", dir=out_dir)
    # mkstemp returns an already-open descriptor; wrap it so it is closed even if the write fails.
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG in memory and return the encoded file bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        encoded = sink.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """GET *url* and return the response body as bytes.

    Uses a 180-second requests timeout. Raises ``requests.HTTPError`` when the
    server answers with a 4xx/5xx status.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body
def _video_bytes_from_json(data) -> Optional[bytes]:
    """Extract the video from a JSON payload returned by the inference endpoint.

    The first non-empty value among ``video_url``, ``url`` and ``result`` is
    downloaded when it is a string starting with ``http``. Otherwise a base64
    string under ``video_b64`` is decoded. A leading ``data:...;base64,``
    prefix and missing ``=`` padding are both tolerated.

    Returns:
        The video bytes, or None if *data* is not a dict or carries neither form.
    """
    if not isinstance(data, dict):
        return None
    video_url = data.get("video_url") or data.get("url") or data.get("result")
    if isinstance(video_url, str) and video_url.startswith("http"):
        return _download_to_bytes(video_url)
    video_b64 = data.get("video_b64")
    if not isinstance(video_b64, str):
        return None
    if video_b64.startswith("data:") and "," in video_b64:
        # Strip a data-URI header; ':' is not a base64 character, so this never
        # misfires on a plain base64 payload.
        video_b64 = video_b64.split(",", 1)[1]
    # Restore '=' padding the server may have stripped.
    video_b64 += "=" * (-len(video_b64) % 4)
    return base64.b64decode(video_b64)


def stitch_call(start_img: "Image.Image", end_img: "Image.Image", prompt: str, seed: Optional[int]) -> Optional[str]:
    """Request a video that transitions from *start_img* to *end_img*.

    Both frames are PNG-encoded and posted as multipart form data to
    ``INFERENCE_URL``, with the prompt and seed in the query string.

    Args:
        start_img: First frame (PIL image) or None.
        end_img: Last frame (PIL image) or None.
        prompt: Text prompt, URL-encoded into the query string; None is sent as "".
        seed: Generation seed. None, 0 or -1 are replaced by a random seed in
            [1, 2**31 - 1].

    Returns:
        Path of the saved ``.mp4``. Returns None if an image is missing, the
        request fails, or the response holds no video. Request and response
        failures are printed rather than raised.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)
    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every content type. Previously a JSON error body
        # (4xx/5xx) was parsed as if it were a successful result.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()
        if "application/json" not in ctype:
            # Raw video bytes in the body.
            if not resp.content:
                print("stitch_call error: empty response body")
                return None
            return _save_video_bytes(resp.content, "stitch")
        video = _video_bytes_from_json(resp.json())
        if video is None:
            print("stitch_call error: JSON response contained no video_url/url/result or video_b64")
            return None
        return _save_video_bytes(video, "stitch")
    except Exception as e:
        # Best-effort UI boundary: report the failure and let the caller show nothing.
        print("stitch_call error:", e)
        return None
# -------- FFmpeg-based concatenation --------
def _concat_list_line(path: str) -> str:
    """Format *path* as one ``file`` line of an ffmpeg concat-demuxer list.

    The demuxer resolves relative entries against the list file's directory,
    not the working directory, so the path is made absolute first. Inside
    single quotes, a literal ``'`` must be written as ``'\\''``.
    """
    escaped = os.path.abspath(path).replace("'", "'\\''")
    return f"file '{escaped}'\n"


def concat_videos(vid1: str, vid2: str) -> Optional[str]:
    """Join two video files end to end using ffmpeg's concat demuxer.

    Streams are copied (``-c copy``), not re-encoded, so both clips need
    matching codec parameters for the result to play correctly.

    Args:
        vid1: Path of the clip that plays first.
        vid2: Path of the clip appended after it.

    Returns:
        Path of the combined ``.mp4`` in ``/tmp``. Returns None when either path
        is empty/None or ffmpeg fails; the error and the tail of ffmpeg's
        stderr are printed.
    """
    if not vid1 or not vid2:
        return None
    stamp = int(time.time())
    list_file = None
    out_path = None
    succeeded = False
    try:
        os.makedirs("/tmp", exist_ok=True)
        # Unique names: two calls in the same second used to share, and
        # clobber, the same list file and output path.
        fd, out_path = tempfile.mkstemp(prefix=f"final_{stamp}_", suffix=".mp4", dir="/tmp")
        os.close(fd)
        fd, list_file = tempfile.mkstemp(prefix=f"list_{stamp}_", suffix=".txt", dir="/tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(_concat_list_line(vid1))
            f.write(_concat_list_line(vid2))
        # -safe 0 allows absolute paths in the list. -y overwrites the empty
        # placeholder that mkstemp created for the output.
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        succeeded = True
        return out_path
    except Exception as e:
        print("concat_videos error:", e)
        if isinstance(e, subprocess.CalledProcessError) and e.stderr:
            # The CalledProcessError message omits ffmpeg's diagnostics.
            print(e.stderr.decode(errors="replace")[-2000:])
        return None
    finally:
        # The list file is only needed while ffmpeg runs. A failed run may
        # also leave an empty or partial output file behind.
        for leftover in (list_file, None if succeeded else out_path):
            if leftover:
                try:
                    os.remove(leftover)
                except OSError:
                    pass
# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Gradio handler for "Stitch 1&2": generate the Image 1 -> Image 2 video.

    Args:
        prompt12: Prompt text; None/empty is sent as "".
        seed: Value of the Seed box; empty becomes 0, which stitch_call treats
            as "pick a random seed".
        img1: Start frame (PIL image from the upload box) or None.
        img2: End frame (PIL image from the upload box) or None.

    Returns:
        Path of the generated video, or None after showing a warning.
    """
    if img1 is None or img2 is None:
        gr.Warning("Upload Image 1 and Image 2 before stitching 1&2.")
        return None
    path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if path is None:
        # stitch_call only prints its errors; surface the failure in the UI too.
        gr.Warning("Stitch 1&2 failed; see server logs for details.")
    return path
def stitch_23(prompt23, seed, img2, img3):
    """Gradio handler for "Stitch 2&3": generate the Image 2 -> Image 3 video.

    Args:
        prompt23: Prompt text; None/empty is sent as "".
        seed: Value of the Seed box; empty becomes 0, which stitch_call treats
            as "pick a random seed".
        img2: Start frame (PIL image from the upload box) or None.
        img3: End frame (PIL image from the upload box) or None.

    Returns:
        Path of the generated video, or None after showing a warning.
    """
    if img2 is None or img3 is None:
        gr.Warning("Upload Image 2 and Image 3 before stitching 2&3.")
        return None
    path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
    if path is None:
        # stitch_call only prints its errors; surface the failure in the UI too.
        gr.Warning("Stitch 2&3 failed; see server logs for details.")
    return path
def stitch_all(vid12, vid23):
    """Gradio handler for "Stitch All": concatenate the 1->2 and 2->3 videos.

    Args:
        vid12: Path of the 1->2 video from its Video component, or None.
        vid23: Path of the 2->3 video from its Video component, or None.

    Returns:
        Path of the combined video, or None after showing a warning.
    """
    if vid12 is None or vid23 is None:
        gr.Warning("Generate both videos first before stitching all.")
        return None
    combined = concat_videos(vid12, vid23)
    if combined is None:
        # concat_videos only prints its errors; surface the failure in the UI too.
        gr.Warning("Combining the videos failed; see server logs for details.")
    return combined
# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
"""
# User-visible text uses real "—" and "→". The previous literals held
# mis-decoded UTF-8 ("β€”", "β†’") that rendered as garbage in the UI.
with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images → Generate 1→2, 2→3, then combine.")
    with gr.Row():
        # Top: images
        with gr.Column():
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")
    with gr.Row():
        # Prompts + buttons
        with gr.Column():
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)")
            btn12 = gr.Button("Stitch 1&2")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)")
            btn23 = gr.Button("Stitch 2&3")
            btn_all = gr.Button("Stitch All (combine 1→2 and 2→3)")
        # Generated and combined videos
        with gr.Column():
            vid12 = gr.Video(label="Video (1→2)")
            vid23 = gr.Video(label="Video (2→3)")
            vid_all = gr.Video(label="Final Combined Video")
    # Wire buttons to callbacks (listeners must be attached inside the Blocks context).
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the server.
    queued_app = demo.queue()
    queued_app.launch()