Spaces:
Running
Running
File size: 5,669 Bytes
92be6ab f1cd1b3 5a6bbaa f1cd1b3 0d7b5a8 5a6bbaa 92be6ab f1cd1b3 92be6ab f1cd1b3 6a8b4f6 92be6ab 5a6bbaa f1cd1b3 5a6bbaa f1cd1b3 0a73965 f1cd1b3 5a6bbaa c4bcf4f f1cd1b3 26b0caf f1cd1b3 26b0caf f1cd1b3 26b0caf f1cd1b3 a71fa9d 5a6bbaa f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 26b0caf f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 92be6ab f1cd1b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os, io, time, base64, random, subprocess
import tempfile
from typing import Optional
from urllib.parse import quote

import requests
from PIL import Image
import gradio as gr
# -------- Modal inference endpoint --------
# Web endpoint that stitch_call POSTs the two images to (prompt and seed are
# sent in the query string).
INFERENCE_URL = (
    "https://moonmath-ai-dev--moonmath-i2v-backend-"
    "moonmathinference-run.modal.run"
)
# -------- Helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
os.makedirs("/tmp", exist_ok=True)
path = f"/tmp/{tag}_{int(time.time())}.mp4"
with open(path, "wb") as f:
f.write(data)
return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the encoded bytes (used for the upload)."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        encoded = stream.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """Download *url* and return the whole response body.

    Uses a 180-second timeout (requests applies it to connecting and to each
    read, not to the total transfer). Raises ``requests.HTTPError`` on a
    4xx/5xx status via ``raise_for_status``.
    """
    reply = requests.get(url, timeout=180)
    reply.raise_for_status()
    body: bytes = reply.content
    return body
def _video_bytes_from_json(data) -> Optional[bytes]:
    """Extract video bytes from a decoded JSON response, or return None.

    Accepted shapes:
      * the first non-empty value among ``video_url``, ``url``, ``result``,
        ``output``, if it is an http(s) URL -- it is downloaded;
      * otherwise a base64 string under ``video_b64`` or ``videoBase64``,
        optionally wrapped as a ``data:...;base64,`` URI.
    """
    if not isinstance(data, dict):
        # A JSON list/string/number cannot carry the keys below.
        return None
    video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
    if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
        return _download_to_bytes(video_url)
    video_b64 = data.get("video_b64") or data.get("videoBase64")
    if not isinstance(video_b64, str):
        return None
    if video_b64.startswith("data:") and "," in video_b64:
        # Drop a data-URI header; its letters would otherwise decode as payload.
        video_b64 = video_b64.split(",", 1)[1]
    # Remove whitespace/newlines first so padding is computed on real characters.
    video_b64 = "".join(video_b64.split())
    video_b64 += "=" * (-len(video_b64) % 4)
    return base64.b64decode(video_b64)


def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Request a video bridging *start_img* to *end_img* from the inference endpoint.

    Both images are uploaded as PNG multipart fields (``image_bytes`` and
    ``image_bytes_end``); the prompt and seed go in the query string. A seed of
    None, 0 or -1 is replaced by a random value in [1, 2**31 - 1].

    The endpoint may reply with raw video bytes or with JSON carrying a URL or
    base64 payload (see ``_video_bytes_from_json``); either way the video is
    saved locally.

    Returns:
        Path of the saved video, or None if an image is missing or the request
        fails. Failures after the request starts are printed, not raised.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)
    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every content type: otherwise a JSON error body
        # (e.g. HTTP 500 with {"detail": ...}) would be searched for a video
        # and the call would fail without any diagnostic.
        if not resp.ok:
            print(f"Stitch call failed: HTTP {resp.status_code}: {resp.text[:500]}")
            return None
        ctype = (resp.headers.get("content-type") or "").lower()
        if "application/json" not in ctype:
            # Raw video bytes.
            return _save_video_bytes(resp.content, "stitch")
        data = resp.json()
        video = _video_bytes_from_json(data)
        if video is None:
            got = sorted(data) if isinstance(data, dict) else type(data).__name__
            print("Stitch call failed: JSON response contained no video; got:", got)
            return None
        return _save_video_bytes(video, "stitch")
    except Exception as e:
        # Best-effort UI call: log and let the Gradio handler show a warning.
        print("Stitch call failed:", e)
        return None
# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Handle the "Stitch 1&2" button: build the video from Image 1 to Image 2.

    Returns the saved video path, or None after showing a Gradio warning when
    an image is missing or the inference call fails. A missing seed is sent
    as 0, which stitch_call turns into a random seed.
    """
    missing_input = img1 is None or img2 is None
    if missing_input:
        gr.Warning("Please upload Image 1 and Image 2.")
        return None
    result = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if result is None:
        gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
    return result
def stitch_23(prompt23, seed, img2, img3):
    """Handle the "Stitch 2&3" button: build the video from Image 2 to Image 3.

    Returns the saved video path, or None after showing a Gradio warning when
    an image is missing or the inference call fails. A missing seed is sent
    as 0, which stitch_call turns into a random seed.
    """
    if img2 is not None and img3 is not None:
        video_path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
        if video_path is None:
            gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
        return video_path
    gr.Warning("Please upload Image 2 and Image 3.")
    return None
def _concat_entry(path) -> str:
    """Return one ffmpeg concat-demuxer ``file`` line for *path*."""
    # A single quote cannot appear inside a '...' token; ffmpeg's quoting rules
    # require closing the quote, emitting an escaped quote (\'), and reopening.
    return "file '" + str(path).replace("'", "'\\''") + "'\n"


def _remove_quietly(path) -> None:
    """Delete *path* if given, ignoring errors (best-effort temp-file cleanup)."""
    if path:
        try:
            os.remove(path)
        except OSError:
            pass


def stitch_all(video12, video23):
    """Concatenate the 1&2 and 2&3 clips into one MP4 using ffmpeg.

    Uses the concat demuxer with stream copy (``-c copy``), so both inputs must
    share codec parameters -- presumably true because both come from the same
    endpoint; TODO confirm.

    Args:
        video12: Value of the first Video output (presumably a file path).
        video23: Value of the second Video output (presumably a file path).

    Returns:
        Path of the concatenated MP4, or None after a Gradio warning when an
        input is missing or ffmpeg fails.
    """
    if not video12 or not video23:
        gr.Warning("Please generate both stitched videos first.")
        return None
    out_path = list_path = None
    try:
        # mkstemp yields unique names, so same-second requests cannot overwrite
        # each other's list or output file (timestamp-only names could).
        fd, out_path = tempfile.mkstemp(prefix="stitch_all_", suffix=".mp4", dir="/tmp")
        os.close(fd)
        fd, list_path = tempfile.mkstemp(prefix="concat_", suffix=".txt", dir="/tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(_concat_entry(video12))
            f.write(_concat_entry(video23))
        # -y overwrites the placeholder created above; -safe 0 permits
        # absolute paths in the list; -c copy joins without re-encoding.
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", out_path]
        subprocess.run(cmd, check=True, capture_output=True)
        return out_path
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or b"").decode("utf-8", errors="replace")
        print(f"Stitch all failed: ffmpeg exited with code {e.returncode}:\n{stderr[-2000:]}")
    except Exception as e:
        print("Stitch all failed:", e)
    finally:
        # The list file is only needed while ffmpeg runs.
        _remove_quietly(list_path)
    # Reached only on failure: drop the empty/partial output and notify the user.
    _remove_quietly(out_path)
    gr.Warning("Failed to stitch all videos together.")
    return None
# -------- UI --------
# Extra page styling: container padding, pill-shaped buttons (elem_classes
# "pill") and rounded prompt textareas (elem_classes "rounded").
CSS = "\n".join(
    (
        "",
        ".gradio-container { padding: 24px; }",
        ".pill button { border-radius: 999px !important; padding: 10px 18px; }",
        ".rounded textarea { border-radius: 16px !important; }",
        "",
    )
)
# Gradio app: upload three images, stitch 1->2 and 2->3 through the inference
# endpoint, then concatenate both clips with ffmpeg.
# NOTE(review): indentation was lost in this copy, so which `with` block each
# component belongs to (e.g. btn_all / vid_all) cannot be recovered from here;
# restore the nesting from the original file before running.
# NOTE(review): "β" in the title, Markdown and prompt labels looks like a
# mis-encoded dash/arrow (e.g. an em dash or "->"); confirm and fix the source
# file's encoding.
with gr.Blocks(css=CSS, title="Stitch β 3 uploads, 2 stitches, concat") as demo:
gr.Markdown("## Stitch β Upload 3 images, generate videos between 1β2 and 2β3, then merge them.")
with gr.Row():
# First column: the three source images. type="pil" means handlers receive
# PIL images, which stitch_call encodes to PNG for upload.
with gr.Column(scale=1, min_width=320):
img1 = gr.Image(label="Image 1 upload", type="pil")
img2 = gr.Image(label="Image 2 upload", type="pil")
img3 = gr.Image(label="Image 3 upload", type="pil")
# Second column: seed, per-pair prompts/buttons and video outputs (exact
# membership uncertain, see the note above).
with gr.Column(scale=1, min_width=320):
# One seed shared by both stitch buttons; 0 (or empty) makes stitch_call
# pick a random seed.
seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β2", lines=2, label="Prompt (1β2)", elem_classes=["rounded"])
btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
vid12 = gr.Video(label="Video (image 1+2) output")
prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β3", lines=2, label="Prompt (2β3)", elem_classes=["rounded"])
btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
vid23 = gr.Video(label="Video (image 2+3) output")
btn_all = gr.Button("Stitch All", elem_classes=["pill"])
vid_all = gr.Video(label="Final concatenated video")
# Wire buttons
btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
# stitch_all consumes whatever the two Video outputs hold (presumably file
# paths) and writes them into an ffmpeg concat list.
btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    queued_app = demo.queue()
    queued_app.launch()
|