"""Gradio app "StitchTool".

Generates transition clips between pairs of uploaded images through a remote
inference endpoint, collects the clips on a timeline, and concatenates the
timeline into one video with FFmpeg.
"""

import os, io, time, base64, random, subprocess
from typing import Optional, List
from urllib.parse import urlencode

import requests
from PIL import Image
import gradio as gr

# -------- Modal inference endpoint (dev) --------
# Base URL that stitch_call() POSTs to. Generation options are appended as a
# URL query string; the start/end images travel in the multipart body.
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

# -------- settings --------
# Number of image upload components the UI creates up front. Only some are
# visible initially; add_image_slot() never lets the visible count exceed this.
MAX_SLOTS = 12          # max image slots user can reveal

# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    os.makedirs("/tmp", exist_ok=True)
    path = f"/tmp/{tag}_{int(time.time())}.mp4"
    with open(path, "wb") as f:
        f.write(data)
    return path

def _png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` as PNG entirely in memory and return the encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()

def _download_to_bytes(url: str) -> bytes:
    """Fetch ``url`` with a GET request and return the response body.

    The request uses a 180-second timeout; an HTTP error status raises via
    ``raise_for_status()``.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    body = response.content
    return body

def stitch_call(
    start_img: Image.Image,
    end_img: Image.Image,
    prompt: str,
    seed: Optional[int],
    negative_prompt: Optional[str] = None,
    frames_per_second: int = 24,
    video_length: int = 4,
    num_inference_steps: Optional[int] = None,
) -> Optional[str]:
    """
    Ask the inference endpoint for a clip that transitions between two images.

    Request layout:
      * multipart body: image_bytes (start frame) and image_bytes_end (end frame),
        both PNG-encoded
      * URL query: prompt, seed, frames_per_second, video_length, plus
        negative_prompt and num_inference_steps when they are provided

    A seed of None, 0 or -1 is replaced by a random positive 31-bit integer.

    Accepted responses:
      * a non-JSON body, treated as the raw video bytes
      * JSON carrying an http(s) link under video_url / url / result / output
      * JSON carrying base64 video data under video_b64 / videoBase64

    Returns the path of the locally saved clip, or None when an image is
    missing, the request/decoding fails (the error is printed), or the JSON
    response holds no usable video.
    """
    for frame in (start_img, end_img):
        if frame is None:
            return None

    # Sentinel seeds mean "pick one for me".
    if seed is None or seed == 0 or seed == -1:
        seed = random.randint(1, 2**31 - 1)

    # Query parameters; insertion order is the order they appear in the URL.
    params = dict(
        prompt=prompt or "",
        seed=int(seed),
        frames_per_second=int(frames_per_second),
        video_length=int(video_length),
    )
    if negative_prompt:
        params["negative_prompt"] = negative_prompt
    if num_inference_steps is not None:
        params["num_inference_steps"] = int(num_inference_steps)

    endpoint = f"{INFERENCE_URL}?{urlencode(params)}"

    # The two frames are uploaded as multipart file fields.
    upload = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    request_headers = {"accept": "application/json"}

    try:
        response = requests.post(endpoint, files=upload, headers=request_headers, timeout=600)
        content_type = (response.headers.get("content-type") or "").lower()

        if "application/json" in content_type:
            payload = response.json()

            # First truthy value among the known link keys.
            link = None
            for key in ("video_url", "url", "result", "output"):
                link = payload.get(key)
                if link:
                    break
            if isinstance(link, str) and link.startswith(("http://", "https://")):
                return _save_video_bytes(_download_to_bytes(link), "stitch")

            # Otherwise look for inline base64 video data.
            encoded = None
            for key in ("video_b64", "videoBase64"):
                encoded = payload.get(key)
                if encoded:
                    break
            if isinstance(encoded, str):
                # Restore any '=' padding the server may have stripped.
                encoded = encoded + "=" * ((-len(encoded)) % 4)
                return _save_video_bytes(base64.b64decode(encoded), "stitch")
        else:
            # Anything that is not JSON is taken to be the video itself.
            response.raise_for_status()
            return _save_video_bytes(response.content, "stitch")

    except Exception as e:
        print("stitch_call error:", e)

    return None

# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    """Concatenate clips in order with FFmpeg's concat demuxer (stream copy).

    Falsy entries in ``videos`` are ignored. The clip list is written to a
    temporary text file under ``/tmp`` that is removed again before returning.

    Args:
        videos: Paths of the clips to join, in playback order.

    Returns:
        Path of the concatenated ``.mp4``, or ``None`` when fewer than two
        usable clips were given or FFmpeg could not be run successfully (the
        error is printed).
    """
    vids = [v for v in (videos or []) if v]
    if len(vids) < 2:
        return None

    # Random suffix: a timestamp alone collides when two runs start within the
    # same second, which would make them share list/output files.
    stamp = f"{int(time.time())}_{os.urandom(4).hex()}"
    out_path = f"/tmp/final_{stamp}.mp4"
    list_file = f"/tmp/list_{stamp}.txt"
    try:
        os.makedirs("/tmp", exist_ok=True)
        with open(list_file, "w", encoding="utf-8") as f:
            for v in vids:
                # Inside a single-quoted concat entry a literal quote must be
                # written as '\'' or the list file no longer parses.
                escaped = str(v).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        return out_path
    except subprocess.CalledProcessError as e:
        print("concat_many error:", e)
        # FFmpeg explains the failure on stderr; surface it instead of dropping it.
        detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        if detail:
            print(detail)
        return None
    except Exception as e:
        print("concat_many error:", e)
        return None
    finally:
        # Best-effort cleanup; the list file may never have been created.
        try:
            os.remove(list_file)
        except OSError:
            pass

# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]):
    """Render the timeline as an HTML grid with one ``<video>`` per clip.

    Falsy entries in ``paths`` are skipped and the remaining clips are labelled
    "Clip 1", "Clip 2", ... in order. Each path is HTML-escaped before being
    placed in the ``src`` attribute so that quotes or angle brackets in a file
    name cannot break out of the attribute.

    Args:
        paths: Video file paths in timeline order; may be ``None`` or empty.

    Returns:
        An HTML string: a placeholder message when there are no clips,
        otherwise the grid markup.
    """
    # Local import keeps this block self-contained.
    from html import escape

    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-grid tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = [
        f"""
            <div class="tl-item">
              <video src="{escape(str(p), quote=True)}" controls playsinline></video>
              <div class="tl-label">Clip {i}</div>
            </div>
            """
        for i, p in enumerate(vids, 1)
    ]
    return f"<div class='tl-grid'>{''.join(items)}</div>"

# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int):
    """Return the visible-slot count after revealing one more slot, capped at MAX_SLOTS."""
    wanted = int(visible_slots) + 1
    return wanted if wanted < MAX_SLOTS else MAX_SLOTS

def _reveal_slots(n, *imgs):
    """Return one visibility update per image slot: the first ``n`` shown, the rest hidden.

    ``imgs`` is accepted because the event wiring passes the image values
    along, but it is not used.
    """
    shown = int(n)
    return [gr.update(visible=(slot < shown)) for slot in range(MAX_SLOTS)]

def collect_choices(*imgs):
    """Return two dropdown updates listing the 1-based positions of the non-empty image slots."""
    labels = [
        str(position)
        for position, picture in enumerate(imgs, start=1)
        if picture is not None
    ]
    start_update = gr.update(choices=labels)
    end_update = gr.update(choices=labels)
    return start_update, end_update

def stitch_selected(
    prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
):
    """Generate a clip between the chosen start and end image slots.

    ``start_idx_str`` / ``end_idx_str`` are 1-based slot numbers given as
    strings; ``imgs`` holds the current value of every image slot. Each
    validation failure raises a Gradio warning and yields None. On success the
    path of the generated clip is returned for the preview player.
    """
    if not (start_idx_str and end_idx_str):
        gr.Warning("Please select Start and End frames.")
        return None

    # Dropdown labels are 1-based; convert to 0-based positions in ``imgs``.
    try:
        first, last = int(start_idx_str) - 1, int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None

    slot_count = len(imgs)
    if not (0 <= first < slot_count and 0 <= last < slot_count):
        gr.Warning("Start/End out of range.")
        return None

    start_img, end_img = imgs[first], imgs[last]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None

    # Empty selections fall back to 24 fps and a 4 second clip.
    if fps:
        fps_val = int(str(fps))
    else:
        fps_val = 24
    if length_sec:
        len_val = int(str(length_sec))
    else:
        len_val = 4

    cleaned_negative = (negative_prompt or "").strip() or None
    clip_path = stitch_call(
        start_img=start_img,
        end_img=end_img,
        prompt=prompt or "",
        seed=int(seed or 0),
        negative_prompt=cleaned_negative,
        frames_per_second=fps_val,
        video_length=len_val,
        num_inference_steps=None,
    )
    if clip_path:
        return clip_path  # path for preview

    gr.Warning("Generation failed.")
    return None

def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Return a copy of the timeline with the preview clip appended, plus refreshed timeline HTML.

    When there is no preview clip a warning is shown and the timeline copy is
    returned unchanged.
    """
    clips = list(timeline_paths or [])
    if preview_path:
        clips.append(preview_path)
    else:
        gr.Warning("Generate a clip first.")
    return clips, gr.update(value=render_timeline_html(clips))

def stitch_all_from_timeline(timeline_paths: List[str]):
    """Concatenate every timeline clip in order and return the resulting video path.

    Shows a warning and returns None when the timeline holds fewer than two
    clips; shows a warning and returns the falsy result when concatenation fails.
    """
    clips = list(timeline_paths or [])
    if len(clips) >= 2:
        merged = concat_many(clips)
        if not merged:
            gr.Warning("Failed to concatenate clips.")
        return merged

    gr.Warning("Add at least two clips to the timeline first.")
    return None

# =========================
# UI
# =========================
# Stylesheet passed to gr.Blocks(css=...). The class names here are the ones
# attached through elem_classes in the UI below (pill, rounded, gallery-row,
# stitch-box) and the ones emitted by render_timeline_html (tl-grid, tl-item,
# tl-label, tl-empty).
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.tl-grid {
  display: grid;
  grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
  gap: 12px;
}
.stitch-box {
  background-color: #f0f4ff;   /* pick any color you like */
  border-radius: 12px;
  padding: 16px;
}
.tl-grid video {
  width: 100%;
  height: 120px;
  object-fit: cover;
  border-radius: 12px;
  display: block;
}
.tl-label {
  font-size: 12px;
  color: #9aa0a6;
  margin-top: 4px;
  text-align: center;
}
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""

# Build the whole UI and wire its events. Components are created in display
# order; event handlers are attached after the components they reference exist.
with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")

    # --- State ---
    visible_slots = gr.State(value=3)      # number of visible image slots
    timeline_state = gr.State(value=[])    # list[str] of video file paths (timeline)

    # --- Image gallery (horizontal, grows on demand) ---
    # All MAX_SLOTS upload components are created up front; only their
    # visibility changes later.
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            # The literal 3 must stay in step with the initial value of
            # visible_slots above.
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)
        add_btn = gr.Button("+ Add image")

    # Clicking "+ Add image" bumps the visible-slot counter (capped at MAX_SLOTS).
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )

    # Reflect visibility changes whenever visible_slots changes. The image
    # values are passed as inputs too, although _reveal_slots ignores them.
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps
    )

    # Seed + Start/End selection + Prompt + options + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            # Choices start empty and are filled by collect_choices once
            # images are uploaded.
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)

            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3,
                label="Prompt",
                elem_classes=["rounded"]
            )

            negative = gr.Textbox(
                placeholder="Optional: things to avoid (e.g., 'bad quality, extra fingers, etc.')",
                lines=2,
                label="Negative prompt",
                elem_classes=["rounded"]
            )

            # Dropdown values are strings; stitch_selected converts them to int.
            with gr.Row():
                fps = gr.Dropdown(
                    label="Frame rate",
                    choices=["16", "24", "32"],
                    value="24",
                    interactive=True,
                )
                length_sec = gr.Dropdown(
                    label="Video length (sec)",
                    choices=["2", "4"],
                    value="4",
                    interactive=True,
                )

            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])

        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # Keep start/end dropdowns up to date based on which slots have images.
    # Every image component triggers the same handler, which re-reads all slots.
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd]
        )

    # Generate: run inference for the selected pair and show the clip in the
    # preview. The input order must match stitch_selected's parameter order,
    # with the image slots last (they arrive as *imgs).
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
        outputs=[preview]
    )

    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))

    # Append the current preview clip to the timeline state and re-render the
    # timeline HTML.
    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html]
    )

    # Final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)

    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid]
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.queue().launch()