File size: 5,327 Bytes
c401d55
f1cd1b3
 
c401d55
 
f1cd1b3
0d7b5a8
5a6bbaa
c401d55
f1cd1b3
 
aa92cac
f1cd1b3
6a8b4f6
92be6ab
5a6bbaa
f1cd1b3
5a6bbaa
 
f1cd1b3
 
 
 
0a73965
f1cd1b3
 
5a6bbaa
 
c4bcf4f
f1cd1b3
 
 
 
 
 
26b0caf
f1cd1b3
26b0caf
 
f1cd1b3
 
26b0caf
f1cd1b3
a71fa9d
5a6bbaa
f1cd1b3
 
 
c401d55
f1cd1b3
 
 
 
c401d55
f1cd1b3
c401d55
 
f1cd1b3
 
 
c401d55
f1cd1b3
 
aa92cac
92be6ab
f1cd1b3
 
 
92be6ab
c401d55
f1cd1b3
 
 
c401d55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92be6ab
f1cd1b3
 
 
26b0caf
f1cd1b3
 
 
 
c401d55
 
 
92be6ab
c401d55
92be6ab
f1cd1b3
 
 
 
 
c401d55
 
f1cd1b3
e4082dc
 
 
 
 
 
 
 
f1cd1b3
c401d55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1cd1b3
 
92be6ab
f1cd1b3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import base64
import io
import os
import random
import subprocess
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import gradio as gr
import requests
from PIL import Image

# -------- Modal inference endpoint (dev) --------
# Base URL that stitch_call() POSTs the start/end images to; the prompt and
# seed are appended to it as query parameters.
# NOTE(review): the host name says "dev" — confirm this is the deployment
# that should be used outside development.
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Write ``data`` to a new ``.mp4`` file under ``/tmp`` and return its path.

    The file name starts with ``{tag}_{unix_seconds}_`` and ends with a
    random component, so two saves that share a tag within the same second
    land in distinct files instead of overwriting one another.

    Args:
        data: Raw bytes of the video.
        tag: Short label used as the file-name prefix.

    Returns:
        Path of the file that was written.
    """
    os.makedirs("/tmp", exist_ok=True)
    # mkstemp creates the file atomically under a unique name; a bare
    # timestamp is not unique when two requests finish in the same second.
    fd, path = tempfile.mkstemp(
        prefix=f"{tag}_{int(time.time())}_", suffix=".mp4", dir="/tmp"
    )
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path

def _png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` as PNG in memory and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        # Copy the bytes out before the buffer is closed on exit.
        return stream.getvalue()

def _download_to_bytes(url: str) -> bytes:
    """Fetch ``url`` with a GET request and return the response body.

    Uses a 180-second timeout. Raises whatever ``raise_for_status`` raises
    when the server answers with an error status.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body

def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Request a video from the inference endpoint and save it locally.

    Both images are PNG-encoded and POSTed as multipart fields
    (``image_bytes`` and ``image_bytes_end``); the prompt and seed travel as
    query parameters. The response is accepted in three shapes:

    * a non-JSON body, treated as the raw video bytes;
    * JSON with an ``http...`` link under ``video_url``, ``url`` or
      ``result``, which is then downloaded;
    * JSON with a base64 string under ``video_b64``.

    Args:
        start_img: First frame; ``None`` makes the call a no-op.
        end_img: Last frame; ``None`` makes the call a no-op.
        prompt: Text prompt; ``None`` or empty is sent as an empty string.
        seed: Seed to send. ``None``, ``0`` and ``-1`` are replaced by a
            random value in ``[1, 2**31 - 1]``.

    Returns:
        Path of the saved ``.mp4`` file, or ``None`` when an image is
        missing, the request fails, or the response holds no usable video.
        Failures are printed, never raised.
    """
    if start_img is None or end_img is None:
        return None

    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"

    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every response. Previously only the raw-bytes
        # branch did, so a JSON error body fell through and returned None
        # without any diagnostic.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            return _save_video_bytes(resp.content, "stitch")

        # JSON with url or base64
        data = resp.json()
        if not isinstance(data, dict):
            print("stitch_call error: unexpected JSON payload of type", type(data).__name__)
            return None

        video_url = data.get("video_url") or data.get("url") or data.get("result")
        if isinstance(video_url, str) and video_url.startswith("http"):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64")
        if isinstance(video_b64, str):
            # Restore any '=' padding that was stripped from the payload.
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")

        print("stitch_call error: no video in response, keys:", sorted(data))

    except Exception as e:
        # Best-effort boundary: the Gradio callback expects None on failure.
        print("stitch_call error:", e)

    return None

# -------- FFmpeg-based concatenation --------
def concat_videos(vid1: str, vid2: str) -> Optional[str]:
    """Join two video files back to back with ffmpeg's concat demuxer.

    The streams are copied (``-c copy``), not re-encoded. The temporary list
    file handed to ffmpeg is always removed, and a failed run leaves no
    output file behind.

    Args:
        vid1: Path of the first clip.
        vid2: Path of the second clip.

    Returns:
        Path of the combined ``.mp4`` under ``/tmp``, or ``None`` when either
        path is empty or ffmpeg could not be run successfully. Failures are
        printed, never raised.
    """
    if not vid1 or not vid2:
        return None

    list_file = None
    out_path = None
    succeeded = False
    try:
        os.makedirs("/tmp", exist_ok=True)
        # Unique names: whole-second timestamps collide under concurrency.
        out_fd, out_path = tempfile.mkstemp(prefix="final_", suffix=".mp4", dir="/tmp")
        os.close(out_fd)

        with tempfile.NamedTemporaryFile(
            "w", prefix="list_", suffix=".txt", dir="/tmp", delete=False, encoding="utf-8"
        ) as f:
            list_file = f.name
            for src in (vid1, vid2):
                # Relative entries would be resolved against the list file's
                # directory, so make them absolute; a single quote inside a
                # quoted entry has to be written as '\''.
                entry = os.path.abspath(src).replace("'", "'\\''")
                f.write(f"file '{entry}'\n")

        # Run ffmpeg concat (-y overwrites the empty placeholder output).
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        succeeded = True
        return out_path
    except subprocess.CalledProcessError as e:
        # The exception text only carries the exit code; the reason is in stderr.
        detail = (e.stderr or b"").decode("utf-8", errors="replace")[-2000:]
        print("concat_videos error:", e, detail)
        return None
    except Exception as e:
        print("concat_videos error:", e)
        return None
    finally:
        leftovers = [list_file] if succeeded else [list_file, out_path]
        for leftover in leftovers:
            if leftover:
                try:
                    os.remove(leftover)
                except OSError:
                    pass

# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Gradio callback for the "Stitch 1&2" button.

    Passes images 1 and 2, the prompt and the integer seed to
    ``stitch_call``. An empty prompt becomes ``""`` and an empty seed
    becomes ``0``.

    Returns:
        Path of the generated video, or ``None`` when nothing was produced;
        in that case a warning is shown in the UI instead of failing
        silently.
    """
    path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if path is None:
        gr.Warning("Stitch 1&2 produced no video. Upload both images and try again.")
    return path

def stitch_23(prompt23, seed, img2, img3):
    """Gradio callback for the "Stitch 2&3" button.

    Passes images 2 and 3, the prompt and the integer seed to
    ``stitch_call``. An empty prompt becomes ``""`` and an empty seed
    becomes ``0``.

    Returns:
        Path of the generated video, or ``None`` when nothing was produced;
        in that case a warning is shown in the UI instead of failing
        silently.
    """
    path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
    if path is None:
        gr.Warning("Stitch 2&3 produced no video. Upload both images and try again.")
    return path

def stitch_all(vid12, vid23):
    """Gradio callback for the "Stitch All" button.

    Concatenates the two generated clips with ``concat_videos``.

    Args:
        vid12: Path of the 1-to-2 video, or an empty value when not generated.
        vid23: Path of the 2-to-3 video, or an empty value when not generated.

    Returns:
        Path of the combined video, or ``None`` (with a UI warning) when a
        clip is missing or the concatenation failed.
    """
    # Falsy check so an empty path counts as "not generated", like None.
    if not vid12 or not vid23:
        gr.Warning("Generate both videos first before stitching all.")
        return None

    combined = concat_videos(vid12, vid23)
    if combined is None:
        gr.Warning("Combining the two videos failed.")
    return combined

# -------- UI --------
CSS = """
.gradio-container { padding: 24px; }
"""

# Every component and event listener is declared inside the Blocks context:
# components created outside it are not part of `demo`, and `.click()` has to
# be called within the Blocks scope.
with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images → Generate 1→2, 2→3, then combine.")

    # --- Uploads row (side-by-side) ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            img1 = gr.Image(label="Image 1 upload", type="pil")
        with gr.Column(scale=1, min_width=280):
            img2 = gr.Image(label="Image 2 upload", type="pil")
        with gr.Column(scale=1, min_width=280):
            img3 = gr.Image(label="Image 3 upload", type="pil")

    # --- Controls (left) and results (right) ---
    with gr.Row():
        # Prompts + buttons
        with gr.Column():
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)")
            btn12 = gr.Button("Stitch 1&2")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)")
            btn23 = gr.Button("Stitch 2&3")
            btn_all = gr.Button("Stitch All (combine 1→2 and 2→3)")

        with gr.Column():
            vid12 = gr.Video(label="Video (1→2)")
            vid23 = gr.Video(label="Video (2→3)")
            vid_all = gr.Video(label="Final Combined Video")

    # Wire
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])

if __name__ == "__main__":
    # Enable the request queue, then launch the app when run as a script.
    queued_app = demo.queue()
    queued_app.launch()