File size: 9,624 Bytes
26b0caf
 
 
 
 
 
 
5a6bbaa
 
0d7b5a8
5a6bbaa
26b0caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a6bbaa
26b0caf
5a6bbaa
 
0a73965
26b0caf
 
5a6bbaa
 
c4bcf4f
26b0caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a71fa9d
5a6bbaa
26b0caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a6bbaa
26b0caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a6bbaa
26b0caf
 
 
 
5a6bbaa
 
26b0caf
 
 
 
5a6bbaa
 
26b0caf
 
 
 
 
 
 
 
 
 
 
0d7b5a8
26b0caf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import os
import io
import time
import random
import base64
from urllib.parse import quote_plus
from typing import Optional, Tuple

import requests
import gradio as gr

# -----------------------------
# Config
# -----------------------------
# Backend endpoint for the image->video service; override by setting the
# MM_I2V_URL environment variable.
DEFAULT_API_URL = os.environ.get(
    "MM_I2V_URL",
    "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run",
)

# All generated clips and stitched videos are written here.
SAVE_DIR = "outputs"
os.makedirs(SAVE_DIR, exist_ok=True)


# -----------------------------
# Helpers
# -----------------------------
def _save_bytes_to_mp4(buf: bytes, name_prefix: str) -> str:
    """Write raw video bytes to a millisecond-timestamped .mp4 in SAVE_DIR and return its path."""
    stamp_ms = int(time.time() * 1000)
    out_path = os.path.join(SAVE_DIR, f"{name_prefix}-{stamp_ms}.mp4")
    with open(out_path, "wb") as handle:
        handle.write(buf)
    return out_path


def _download(url: str) -> bytes:
    """Fetch *url* and return the response body; raises requests.HTTPError on a bad status."""
    response = requests.get(url, timeout=600)
    response.raise_for_status()
    return response.content


def call_i2v(
    image_path: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Call the image->video backend and return (video_path, error_message).

    Exactly one element of the tuple is non-None. Tries to handle several
    common response types:
      1) raw mp4 bytes
      2) JSON with {"video": "<base64>"} (mp4 base64)
      3) JSON with {"video_url": "https://..."} (or "result_url" / "url")
    """
    api = (api_url or DEFAULT_API_URL).strip().rstrip("/")
    # Fall back to a random seed when none (or an empty value) was supplied.
    used_seed = (
        seed
        if (seed is not None and str(seed).strip() != "")
        else random.randint(0, 2**31 - 1)
    )
    url = f"{api}?prompt={quote_plus(prompt)}&seed={used_seed}"
    headers = {"accept": "application/json"}

    try:
        # Open the upload inside `with` so the handle is always closed
        # (the original leaked one file handle per call).
        with open(image_path, "rb") as img:
            files = {
                "image_bytes": (os.path.basename(image_path), img, "application/octet-stream")
            }
            resp = requests.post(url, headers=headers, files=files, timeout=1200)
        # Fail fast on HTTP errors; otherwise a >1KB HTML error page would be
        # mistaken for video bytes below and saved as a broken .mp4.
        resp.raise_for_status()

        # Try to accommodate various backends
        ctype = resp.headers.get("Content-Type", "")
        if "application/json" in ctype:
            data = resp.json()
            # base64 payload
            video = data.get("video")
            if isinstance(video, str) and len(video) > 50:
                try:
                    raw = base64.b64decode(video, validate=True)
                    return _save_bytes_to_mp4(raw, "clip"), None
                except Exception as e:
                    return None, f"Could not decode base64 video: {e}"
            # url payload
            for key in ("video_url", "result_url", "url"):
                link = data.get(key)
                if isinstance(link, str) and link.startswith("http"):
                    raw = _download(link)
                    return _save_bytes_to_mp4(raw, "clip"), None
            return None, 'JSON response did not include "video" (base64) or a known url key.'
        # Raw bytes (ideally mp4)
        elif "video" in ctype or "octet-stream" in ctype:
            return _save_bytes_to_mp4(resp.content, "clip"), None
        else:
            # Some backends still reply bytes with missing/odd content-type
            if resp.content and len(resp.content) > 1024:
                return _save_bytes_to_mp4(resp.content, "clip"), None
            return None, f"Unexpected content type: {ctype}"
    except OSError as e:
        # Unreadable/missing image file: report as an error tuple instead of
        # crashing the UI callback (the original raised here).
        return None, f"Could not read image file: {e}"
    except requests.RequestException as e:
        return None, f"Request failed: {e}"


def stitch_pair(
    image_a: str,
    image_b: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str],
    crossfade: float,
) -> Tuple[Optional[str], str]:
    """
    Generate one clip per image via the backend, then join the two clips.

    Strategy:
      - Generate a short clip from image A
      - Generate a short clip from image B (same prompt/seed unless user changes)
      - Concatenate with a short crossfade in Python (moviepy)

    Returns (stitched_video_path, error_message); on success the message is "".

    If you already have a backend endpoint that does stitching directly,
    replace this function body with a single backend call.
    """
    if not image_a or not image_b:
        return None, "Please upload both images."

    # First generate both clips
    clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
    if err1:
        return None, f"Clip 1 failed: {err1}"
    clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
    if err2:
        return None, f"Clip 2 failed: {err2}"

    # Imported lazily so the app can start without moviepy installed.
    try:
        from moviepy.editor import VideoFileClip, concatenate_videoclips
    except Exception as e:
        return None, f"MoviePy import failed. Add moviepy & imageio-ffmpeg to requirements.txt. Error: {e}"

    c1 = c2 = merged = None
    try:
        c1 = VideoFileClip(clip1_path)
        c2 = VideoFileClip(clip2_path)

        # method="compose" handles size/fps mismatches between the clips.
        if crossfade and crossfade > 0:
            # Overlap the clips: clip 1 fades out while clip 2 fades in;
            # negative padding makes the fades coincide.
            c2 = c2.crossfadein(crossfade)
            c1 = c1.crossfadeout(crossfade)
            merged = concatenate_videoclips([c1, c2], method="compose", padding=-crossfade)
        else:
            merged = concatenate_videoclips([c1, c2], method="compose")

        out_path = os.path.join(SAVE_DIR, f"stitched-{int(time.time()*1000)}.mp4")
        merged.write_videofile(out_path, codec="libx264", audio_codec="aac", verbose=False, logger=None)
        return out_path, ""
    except Exception as e:
        return None, f"Stitching failed: {e}"
    finally:
        # Close readers/writer on every path: the original closed them only
        # on success, leaking ffmpeg subprocesses/handles when stitching failed.
        for clip in (c1, c2, merged):
            if clip is not None:
                try:
                    clip.close()
                except Exception:
                    pass


# -----------------------------
# UI
# -----------------------------
# Three-column layout: image uploads | prompts + stitch controls | video outputs.
with gr.Blocks(title="Image Stitch to Video", css="""
/* Rounded tiles like the sketch */
.rounded { border-radius: 24px; }
.tile { background: #f7f7ff; padding: 12px; }
.tile-blue { background: #e8f0ff; }
.tile-yellow { background: #fff7d6; }
.small-btn button { padding: 6px 10px; border-radius: 999px; }
.label-center label { text-align:center; width: 100%; }
""") as demo:
    gr.Markdown("### Image → Video (Stitch Adjacent Pairs)\nUpload 3 images, enter prompts for each stitch, then click the stitch buttons.")

    with gr.Row(equal_height=True):
        # Left column: images + add image
        with gr.Column(scale=1):
            gr.Markdown("**Images**")
            img1 = gr.Image(type="filepath", label="Image 1", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img2 = gr.Image(type="filepath", label="Image 2", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img3 = gr.Image(type="filepath", label="Image 3", height=220, elem_classes=["rounded", "tile", "tile-blue"])

            # Optional extra slots (hidden until added); slots 4..8.
            extra_imgs = []
            for i in range(4, 9):
                comp = gr.Image(type="filepath", label=f"Image {i}", height=220, visible=False, elem_classes=["rounded", "tile", "tile-blue"])
                extra_imgs.append(comp)

            add_btn = gr.Button("Add Image", variant="secondary")

        # Middle column: prompts + stitch buttons
        with gr.Column(scale=1):
            gr.Markdown("**Prompts**")
            prompt12 = gr.Textbox(label="Prompt for Stitch 1 & 2", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed12 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch12 = gr.Button("Stitch 1 & 2", elem_classes=["small-btn"])

            prompt23 = gr.Textbox(label="Prompt for Stitch 2 & 3", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed23 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch23 = gr.Button("Stitch 2 & 3", elem_classes=["small-btn"])

            # Backend URL and crossfade duration, tucked away by default.
            with gr.Accordion("Advanced (API & Stitch)", open=False):
                api_url = gr.Textbox(label="Backend API URL", value=DEFAULT_API_URL)
                crossfade = gr.Slider(0.0, 1.5, value=0.4, step=0.1, label="Crossfade seconds")
                clear_btn = gr.Button("Clear All")

        # Right column: video outputs
        with gr.Column(scale=1):
            gr.Markdown("**Outputs**")
            vid12 = gr.Video(label="Video (image 1 + 2) output", elem_classes=["rounded", "tile", "tile-yellow"])
            vid23 = gr.Video(label="Video (image 2 + 3) output", elem_classes=["rounded", "tile", "tile-yellow"])

    # Wire up actions
    def _on_add(*imgs):
        """Reveal the next hidden extra-image uploader and re-emit visibility updates."""
        # Reveal the next hidden uploader
        # NOTE(review): this mutates the component objects server-side, which
        # is process-global state shared by ALL sessions, and `comp.visible`
        # is never read back from the browser — presumably a per-session
        # gr.State counter was intended; confirm.
        for comp in extra_imgs:
            if comp.visible is False:
                comp.visible = True
                break
        return [gr.update(visible=comp.visible) for comp in extra_imgs]

    add_btn.click(
        _on_add,
        inputs=extra_imgs,
        outputs=extra_imgs,
    )

    # Stitch images 1+2 into the first output video.
    # NOTE(review): stitch_pair's error message is routed into a hidden,
    # throwaway Textbox, so failures are invisible to the user — consider a
    # visible status component instead.
    stitch12.click(
        stitch_pair,
        inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
        outputs=[vid12, gr.Textbox(visible=False)],
    )

    # Stitch images 2+3 into the second output video (error hidden, as above).
    stitch23.click(
        stitch_pair,
        inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
        outputs=[vid23, gr.Textbox(visible=False)],
    )

    def _on_clear():
        """Reset every input/output to its initial state (extra slots re-hidden)."""
        updates = []
        # First three uploaders stay visible but are emptied; extras are hidden.
        for comp in [img1, img2, img3, *extra_imgs]:
            updates.append(gr.update(value=None, visible=True if comp in [img1, img2, img3] else False))
        # Order must match clear_btn's outputs list:
        # [images..., vid12, vid23, prompt12, prompt23, api_url, crossfade].
        return updates + [None, None, "", "", gr.update(value=DEFAULT_API_URL), 0.4]

    clear_btn.click(
        _on_clear,
        inputs=None,
        outputs=[img1, img2, img3, *extra_imgs, vid12, vid23, prompt12, prompt23, api_url, crossfade],
    )

# Launch the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()