Spaces:
Sleeping
Sleeping
File size: 5,327 Bytes
c401d55 f1cd1b3 c401d55 f1cd1b3 0d7b5a8 5a6bbaa c401d55 f1cd1b3 aa92cac f1cd1b3 6a8b4f6 92be6ab 5a6bbaa f1cd1b3 5a6bbaa f1cd1b3 0a73965 f1cd1b3 5a6bbaa c4bcf4f f1cd1b3 26b0caf f1cd1b3 26b0caf f1cd1b3 26b0caf f1cd1b3 a71fa9d 5a6bbaa f1cd1b3 c401d55 f1cd1b3 c401d55 f1cd1b3 c401d55 f1cd1b3 c401d55 f1cd1b3 aa92cac 92be6ab f1cd1b3 92be6ab c401d55 f1cd1b3 c401d55 92be6ab f1cd1b3 26b0caf f1cd1b3 c401d55 92be6ab c401d55 92be6ab f1cd1b3 c401d55 f1cd1b3 e4082dc f1cd1b3 c401d55 f1cd1b3 92be6ab f1cd1b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import base64
import contextlib
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import gradio as gr
import requests
from PIL import Image
# -------- Modal inference endpoint (dev) --------
# Defaults to the dev deployment; set the INFERENCE_URL environment variable
# to point the app at a different backend without editing the code.
INFERENCE_URL = os.environ.get(
    "INFERENCE_URL",
    "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run",
)
# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write ``data`` to a new ``.mp4`` file under ``/tmp`` and return its path.

    The file name starts with ``{tag}_``. It is created with
    ``tempfile.mkstemp``, so two calls with the same ``tag`` in the same
    second (e.g. the 1->2 and 2->3 stitches finishing together) get distinct
    files instead of overwriting each other.

    Args:
        data: Encoded video bytes to store.
        tag: Short label used as the file-name prefix.

    Returns:
        Absolute path of the written file.
    """
    os.makedirs("/tmp", exist_ok=True)
    fd, path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".mp4", dir="/tmp")
    # mkstemp already opened the file; wrap that descriptor instead of reopening.
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path
def _png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` in PNG format and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        encoded = stream.getvalue()
    return encoded
def _download_to_bytes(url: str) -> bytes:
    """GET ``url`` and return the response body.

    Uses a 180 s requests timeout and raises ``requests.HTTPError`` for
    4xx/5xx responses.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body
def _video_from_response(resp: requests.Response) -> Optional[str]:
    """Save the clip carried by a successful endpoint response; ``None`` if it has none."""
    ctype = (resp.headers.get("content-type") or "").lower()
    # Raw video bytes
    if "application/json" not in ctype:
        return _save_video_bytes(resp.content, "stitch")

    # JSON with url or base64
    data = resp.json()
    if not isinstance(data, dict):
        print("stitch_call error: unexpected JSON payload type:", type(data).__name__)
        return None
    video_url = data.get("video_url") or data.get("url") or data.get("result")
    if isinstance(video_url, str) and video_url.startswith("http"):
        return _save_video_bytes(_download_to_bytes(video_url), "stitch")
    video_b64 = data.get("video_b64")
    if isinstance(video_b64, str):
        # Restore any "=" padding the server stripped; b64decode needs it.
        video_b64 += "=" * ((-len(video_b64)) % 4)
        return _save_video_bytes(base64.b64decode(video_b64), "stitch")
    print("stitch_call error: JSON response has no video field; keys:", sorted(data))
    return None


def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Generate a transition clip from ``start_img`` to ``end_img`` via ``INFERENCE_URL``.

    Both images are PNG-encoded and POSTed as the multipart fields
    ``image_bytes`` and ``image_bytes_end``; ``prompt`` and ``seed`` are sent
    as query parameters. The response may be raw video bytes, or JSON holding
    either an http(s) URL (``video_url`` / ``url`` / ``result``) or base64
    data (``video_b64``).

    Args:
        start_img: First frame of the transition; ``None`` short-circuits.
        end_img: Last frame of the transition; ``None`` short-circuits.
        prompt: Text prompt; ``None``/empty is sent as an empty string.
        seed: Generation seed; ``None``, ``0`` or ``-1`` picks a random one.

    Returns:
        Path of the saved ``.mp4`` under ``/tmp``, or ``None`` when an image
        is missing or anything fails (the error is printed, not raised).
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Let requests build and percent-encode the query string.
    params = {"prompt": prompt or "", "seed": seed}
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(INFERENCE_URL, params=params, files=files, headers=headers, timeout=600)
        # Check the status for every response, not just raw-bytes ones: error
        # replies typically carry a JSON body, which would otherwise be parsed
        # as if it were a successful result and then silently dropped.
        resp.raise_for_status()
        return _video_from_response(resp)
    except requests.HTTPError as e:
        # Include the start of the error body; the exception text alone only
        # carries the status line and URL.
        detail = e.response.text[:500] if e.response is not None else ""
        print("stitch_call error:", e, detail)
    except Exception as e:
        print("stitch_call error:", e)
    return None
# -------- FFmpeg-based concatenation --------
def _concat_list_entry(path: str) -> str:
    """Return one ``file '...'`` line for an ffmpeg concat-demuxer list file.

    The path is made absolute because the demuxer resolves relative entries
    against the list file's directory. Each single quote is written as
    ``'\\''`` (close quote, backslash-escaped quote, reopen), so paths that
    contain quotes survive ffmpeg's quoting rules.
    """
    quoted = os.path.abspath(path).replace("'", "'\\''")
    return f"file '{quoted}'\n"


def concat_videos(vid1: str, vid2: str) -> Optional[str]:
    """Join two video files end to end using ffmpeg's concat demuxer.

    Streams are copied (``-c copy``) rather than re-encoded. Both inputs
    therefore need matching codecs and parameters. This is presumably true for
    two clips from the same backend; verify this if other sources are allowed.

    Args:
        vid1: Path of the first clip.
        vid2: Path of the clip appended after ``vid1``.

    Returns:
        Path of a new ``final_*.mp4`` under ``/tmp``. Returns ``None`` when either
        path is empty or ffmpeg fails; the error and the tail of ffmpeg's
        stderr are printed.
    """
    if not vid1 or not vid2:
        return None
    list_path = None
    out_path = None
    try:
        os.makedirs("/tmp", exist_ok=True)
        # mkstemp gives unique names, so concurrent runs cannot clobber each other.
        out_fd, out_path = tempfile.mkstemp(prefix="final_", suffix=".mp4", dir="/tmp")
        os.close(out_fd)
        list_fd, list_path = tempfile.mkstemp(prefix="list_", suffix=".txt", dir="/tmp")
        with os.fdopen(list_fd, "w", encoding="utf-8") as f:
            f.write(_concat_list_entry(vid1))
            f.write(_concat_list_entry(vid2))
        # "-safe 0" lets the demuxer accept absolute paths; "-y" overwrites the
        # placeholder file mkstemp created.
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return out_path
    except subprocess.CalledProcessError as e:
        # str(e) omits ffmpeg's own diagnostics, which live in stderr.
        stderr_tail = (e.stderr or b"").decode("utf-8", "replace")[-2000:]
        print("concat_videos error:", e, stderr_tail)
    except Exception as e:
        print("concat_videos error:", e)
    finally:
        # The list file is only needed while ffmpeg runs.
        if list_path is not None:
            with contextlib.suppress(OSError):
                os.remove(list_path)
    # Failure path: don't leave an empty or partial output file behind.
    if out_path is not None:
        with contextlib.suppress(OSError):
            os.remove(out_path)
    return None
# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Handle the "Stitch 1&2" button by building a clip from image 1 to image 2.

    An empty seed box (``None``) is treated as 0, which ``stitch_call`` turns
    into a random seed. Returns the video path, or ``None`` on failure.
    """
    normalized_seed = int(seed or 0)
    return stitch_call(img1, img2, prompt12 or "", normalized_seed)
def stitch_23(prompt23, seed, img2, img3):
    """Handle the "Stitch 2&3" button by building a clip from image 2 to image 3.

    An empty seed box (``None``) is treated as 0, which ``stitch_call`` turns
    into a random seed. Returns the video path, or ``None`` on failure.
    """
    chosen_prompt = prompt23 or ""
    return stitch_call(img2, img3, chosen_prompt, int(seed or 0))
def stitch_all(vid12, vid23):
    """Handle the "Stitch All" button by concatenating the 1->2 and 2->3 clips.

    Shows a Gradio warning when either clip is missing, or when ffmpeg fails
    to combine them, instead of returning an empty result with no feedback.

    Returns:
        Path of the combined video, or ``None``.
    """
    # Treat empty strings like None: concat_videos would reject them anyway.
    if not vid12 or not vid23:
        gr.Warning("Generate both videos first before stitching all.")
        return None
    combined = concat_videos(vid12, vid23)
    if combined is None:
        gr.Warning("Combining the videos failed; see the server log for details.")
    return combined
# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
"""

# Layout: the three image uploads sit side by side. Below them, a controls
# column (shared seed, one prompt + button per stitch, and a combine button)
# sits next to a column with the three result players.
with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images → Generate 1→2, 2→3, then combine.")

    # --- Uploads row (side-by-side) ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            img1 = gr.Image(label="Image 1 upload", type="pil")
        with gr.Column(scale=1, min_width=280):
            img2 = gr.Image(label="Image 2 upload", type="pil")
        with gr.Column(scale=1, min_width=280):
            img3 = gr.Image(label="Image 3 upload", type="pil")

    with gr.Row():
        # Prompts + buttons; the single seed box feeds both stitches.
        with gr.Column():
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)")
            btn12 = gr.Button("Stitch 1&2")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)")
            btn23 = gr.Button("Stitch 2&3")
            btn_all = gr.Button("Stitch All (combine 1→2 and 2→3)")
        # Result players, filled by the buttons above.
        with gr.Column():
            vid12 = gr.Video(label="Video (1→2)")
            vid23 = gr.Video(label="Video (2→3)")
            vid_all = gr.Video(label="Final Combined Video")

    # Wire: image 2 is the end frame of the first clip and the start frame of
    # the second; the combine step reads both generated clips.
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
if __name__ == "__main__":
    # queue() turns on Gradio's request queue and returns the Blocks instance,
    # which is then served.
    app = demo.queue()
    app.launch()
|