File size: 5,669 Bytes
92be6ab
f1cd1b3
 
5a6bbaa
 
f1cd1b3
0d7b5a8
5a6bbaa
92be6ab
f1cd1b3
 
92be6ab
f1cd1b3
6a8b4f6
92be6ab
5a6bbaa
f1cd1b3
5a6bbaa
 
f1cd1b3
 
 
 
0a73965
f1cd1b3
 
5a6bbaa
 
c4bcf4f
f1cd1b3
 
 
 
 
 
26b0caf
f1cd1b3
26b0caf
 
f1cd1b3
 
26b0caf
f1cd1b3
a71fa9d
5a6bbaa
f1cd1b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92be6ab
 
f1cd1b3
 
 
92be6ab
 
f1cd1b3
 
 
92be6ab
f1cd1b3
 
 
 
 
 
 
 
26b0caf
f1cd1b3
 
 
 
 
 
 
 
 
92be6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1cd1b3
 
 
 
 
 
 
92be6ab
 
f1cd1b3
 
92be6ab
f1cd1b3
 
 
 
92be6ab
f1cd1b3
 
 
92be6ab
 
f1cd1b3
 
 
 
92be6ab
 
 
f1cd1b3
 
 
92be6ab
f1cd1b3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os, io, time, base64, random, subprocess
import tempfile
from typing import Optional
from urllib.parse import quote

import requests
from PIL import Image
import gradio as gr

# -------- Modal inference endpoint --------
# HTTP endpoint that stitch_call() POSTs the start/end frames to (multipart
# fields "image_bytes" / "image_bytes_end"); prompt and seed travel as query
# parameters appended to this URL.
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# -------- Helpers --------
def _save_video_bytes(data: bytes, tag: str, directory: str = "/tmp") -> str:
    """Write ``data`` to a new, uniquely named ``.mp4`` file and return its path.

    Args:
        data: Encoded video bytes to write verbatim.
        tag: Filename prefix; the file is named ``{tag}_{unix_time}_<random>.mp4``.
        directory: Target directory, created if missing (default ``/tmp``).

    Returns:
        Absolute path of the written file.
    """
    os.makedirs(directory, exist_ok=True)
    # A timestamp alone is not unique: two saves within the same second (e.g.
    # concurrent queue workers) would overwrite each other. mkstemp atomically
    # creates a file with a random suffix, so names can never collide.
    fd, path = tempfile.mkstemp(prefix=f"{tag}_{int(time.time())}_", suffix=".mp4", dir=directory)
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    return path

def _png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image as PNG and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        # getvalue() returns an independent bytes copy, so it survives closing the stream.
        encoded = stream.getvalue()
    return encoded

def _download_to_bytes(url: str) -> bytes:
    """Fetch ``url`` with an HTTP GET and return the raw response body.

    Uses a 180-second requests timeout and raises ``requests.HTTPError`` for
    4xx/5xx responses.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body

def _decode_video_b64(payload: str) -> bytes:
    """Decode a base64 video payload into raw bytes.

    Tolerates a ``data:...;base64,`` URI prefix, embedded whitespace/newlines
    and missing ``=`` padding. Without stripping the prefix, b64decode would
    keep its letters ("data", "video", "mp4", ...) and corrupt the output.
    """
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    payload = "".join(payload.split())
    payload += "=" * (-len(payload) % 4)
    return base64.b64decode(payload)


def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Generate a transition video between two frames via the inference endpoint.

    Args:
        start_img: First frame of the clip.
        end_img: Last frame of the clip.
        prompt: Text prompt forwarded to the backend (falsy -> empty string).
        seed: Generation seed; ``None``, 0 or -1 selects a random seed.

    Returns:
        Path to the saved ``.mp4`` file, or ``None`` if either image is missing
        or the request / response decoding failed (the reason is printed).

    The backend may answer with raw video bytes, or with JSON carrying either
    an http(s) URL to download or a base64-encoded video.
    """
    if start_img is None or end_img is None:
        return None

    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"

    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        # Check the status for every response type: previously JSON error
        # bodies skipped this and were silently searched for a video field.
        resp.raise_for_status()
        ctype = (resp.headers.get("content-type") or "").lower()

        # Raw video bytes
        if "application/json" not in ctype:
            return _save_video_bytes(resp.content, "stitch")

        # JSON with url or base64
        data = resp.json()
        if not isinstance(data, dict):
            print("Stitch call failed: unexpected JSON payload type", type(data).__name__)
            return None

        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            return _save_video_bytes(_decode_video_b64(video_b64), "stitch")

        # Previously this case returned None without any diagnostic.
        print("Stitch call failed: response has no usable video field; keys =", list(data))

    except Exception as e:
        print("Stitch call failed:", e)

    return None

# -------- Gradio callbacks --------
def stitch_12(prompt12, seed, img1, img2):
    """Gradio handler for the "Stitch 1&2" button.

    Returns the generated video path for the 1→2 transition, or None after
    showing a UI warning when an image is missing or generation fails.
    """
    have_both = img1 is not None and img2 is not None
    if not have_both:
        gr.Warning("Please upload Image 1 and Image 2.")
        return None
    prompt_text = prompt12 or ""
    seed_value = int(seed or 0)
    video_path = stitch_call(img1, img2, prompt_text, seed_value)
    if video_path is None:
        gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
    return video_path

def stitch_23(prompt23, seed, img2, img3):
    """Gradio handler for the "Stitch 2&3" button.

    Returns the generated video path for the 2→3 transition, or None after
    showing a UI warning when an image is missing or generation fails.
    """
    have_both = img2 is not None and img3 is not None
    if not have_both:
        gr.Warning("Please upload Image 2 and Image 3.")
        return None
    prompt_text = prompt23 or ""
    seed_value = int(seed or 0)
    video_path = stitch_call(img2, img3, prompt_text, seed_value)
    if video_path is None:
        gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
    return video_path

def _ffmpeg_concat_line(path: str) -> str:
    """Format one ``file`` directive for an ffmpeg concat-demuxer list.

    The path is made absolute, because relative entries are resolved against
    the list file's directory rather than the working directory. Any single
    quote is written as ``'\\''`` (close quote, escaped quote, reopen) so it
    cannot terminate the quoted token early.
    """
    escaped = os.path.abspath(path).replace("'", "'\\''")
    return f"file '{escaped}'\n"


def _remove_quietly(path: str) -> None:
    """Best-effort delete of a temp file; a missing or locked file is ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def stitch_all(video12, video23):
    """Concatenate the 1→2 and 2→3 clips into one MP4 using ffmpeg.

    Args:
        video12: Path of the clip produced by "Stitch 1&2".
        video23: Path of the clip produced by "Stitch 2&3".

    Returns:
        Path of the concatenated video, or None (after showing a Gradio
        warning) if either clip is missing or ffmpeg fails.
    """
    if not video12 or not video23:
        gr.Warning("Please generate both stitched videos first.")
        return None

    list_path = None
    out_path = None
    try:
        # mkstemp gives unique names, so concurrent sessions or repeated clicks
        # within the same second cannot overwrite each other's files.
        fd, list_path = tempfile.mkstemp(prefix="concat_", suffix=".txt", dir="/tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(_ffmpeg_concat_line(video12))
            f.write(_ffmpeg_concat_line(video23))

        fd, out_path = tempfile.mkstemp(prefix="stitch_all_", suffix=".mp4", dir="/tmp")
        os.close(fd)  # ffmpeg writes the file itself; -y lets it overwrite this empty placeholder

        # -safe 0 permits absolute paths in the list; -c copy joins the streams
        # without re-encoding, which presumes both clips share codec parameters.
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", out_path]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            # Surface ffmpeg's own diagnostics instead of only the exit status.
            tail = result.stderr.decode("utf-8", errors="replace")[-2000:]
            raise RuntimeError(f"ffmpeg exited with code {result.returncode}:\n{tail}")

        return out_path
    except Exception as e:
        print("Stitch all failed:", e)
        if out_path:
            _remove_quietly(out_path)  # don't leave a truncated/empty output behind
        gr.Warning("Failed to stitch all videos together.")
        return None
    finally:
        # The concat list is only needed while ffmpeg runs.
        if list_path:
            _remove_quietly(list_path)

# -------- UI --------
# Custom styling: rounded "pill" buttons and rounded prompt textareas, applied
# through the elem_classes set on the components below.
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
"""

with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches, concat") as demo:
    gr.Markdown("## Stitch — Upload 3 images, generate videos between 1→2 and 2→3, then merge them.")

    with gr.Row():
        # Left column: the three keyframes, handed to the callbacks as PIL images.
        with gr.Column(scale=1, min_width=320):
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")

        # Right column: shared seed, per-pair prompt/button/output, then the final merge.
        with gr.Column(scale=1, min_width=320):
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)", elem_classes=["rounded"])
            btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
            vid12 = gr.Video(label="Video (image 1+2) output")

            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)", elem_classes=["rounded"])
            btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
            vid23 = gr.Video(label="Video (image 2+3) output")

            btn_all = gr.Button("Stitch All", elem_classes=["pill"])
            vid_all = gr.Video(label="Final concatenated video")

    # Wire buttons. Image 2 is shared: it is the end frame of the 1→2 clip and
    # the start frame of the 2→3 clip; "Stitch All" consumes both video outputs.
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
    btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()