Shalmoni committed
Commit c576104 · verified · 1 Parent(s): cb2acd5

Update app.py

Files changed (1)
  1. app.py +162 -342
app.py CHANGED
@@ -1,371 +1,191 @@
- import os, io, time, base64, random, subprocess
  from typing import Optional, List
- from urllib.parse import urlencode

- import requests
- from PIL import Image
- import gradio as gr

- # -------- Modal inference endpoint (dev) --------
- INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

- # -------- settings --------
- MAX_SLOTS = 12 # max image slots user can reveal

- # -------- small helpers --------
- def _save_video_bytes(data: bytes, tag: str) -> str:
-     os.makedirs("/tmp", exist_ok=True)
-     path = f"/tmp/{tag}_{int(time.time())}.mp4"
-     with open(path, "wb") as f:
-         f.write(data)
-     return path

- def _png_bytes(img: Image.Image) -> bytes:
-     buf = io.BytesIO()
-     img.save(buf, format="PNG")
-     return buf.getvalue()

- def _download_to_bytes(url: str) -> bytes:
-     r = requests.get(url, timeout=180)
-     r.raise_for_status()
-     return r.content

- def stitch_call(
-     start_img: Image.Image,
-     end_img: Image.Image,
      prompt: str,
-     seed: Optional[int],
-     negative_prompt: Optional[str] = None,
-     frames_per_second: int = 24,
-     video_length: int = 4,
-     num_inference_steps: Optional[int] = None,
- ) -> Optional[str]:
-     """
-     Required (in body): image_bytes (+ image_bytes_end)
-     In URL query: prompt, negative_prompt, frames_per_second, video_length, seed, num_inference_steps
-     """
-     if start_img is None or end_img is None:
-         return None
-
-     # default seed behavior
-     if seed in (None, 0, -1):
-         seed = random.randint(1, 2**31 - 1)
-
-     # Build query string
-     q = {
-         "prompt": prompt or "",
-         "seed": int(seed),
-         "frames_per_second": int(frames_per_second),
-         "video_length": int(video_length),
      }
-     if negative_prompt:
-         q["negative_prompt"] = negative_prompt
-     if num_inference_steps is not None:
-         q["num_inference_steps"] = int(num_inference_steps)
-
-     url = f"{INFERENCE_URL}?{urlencode(q)}"
-
-     # Images go in the body
-     files = {
-         "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
-         "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
-     }
-     headers = {"accept": "application/json"}

      try:
-         resp = requests.post(url, files=files, headers=headers, timeout=600)
-         ctype = (resp.headers.get("content-type") or "").lower()
-
-         # Raw video bytes
-         if "application/json" not in ctype:
-             resp.raise_for_status()
-             return _save_video_bytes(resp.content, "stitch")

-         # JSON with url or base64
-         data = resp.json()
-         video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
-         if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
-             return _save_video_bytes(_download_to_bytes(video_url), "stitch")

-         video_b64 = data.get("video_b64") or data.get("videoBase64")
-         if isinstance(video_b64, str):
-             pad = (-len(video_b64)) % 4
-             if pad:
-                 video_b64 += "=" * pad
-             return _save_video_bytes(base64.b64decode(video_b64), "stitch")

-     except Exception as e:
-         print("stitch_call error:", e)
-
-     return None
-
- # -------- FFmpeg-based concatenation (N clips) --------
- def concat_many(videos: List[str]) -> Optional[str]:
-     vids = [v for v in videos if v]
-     if len(vids) < 2:
-         return None
      try:
-         os.makedirs("/tmp", exist_ok=True)
-         out_path = f"/tmp/final_{int(time.time())}.mp4"
-         list_file = f"/tmp/list_{int(time.time())}.txt"
-         with open(list_file, "w") as f:
-             for v in vids:
-                 f.write(f"file '{v}'\n")
-         subprocess.run(
-             ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
-             check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-         )
-         return out_path
-     except Exception as e:
-         print("concat_many error:", e)
-         return None
-
- # -------- Timeline HTML renderer --------
- def render_timeline_html(paths: List[str]):
-     vids = [p for p in (paths or []) if p]
-     if not vids:
-         return "<div class='tl-grid tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
-     items = []
-     for i, p in enumerate(vids, 1):
-         items.append(
-             f"""
-             <div class="tl-item">
-               <video src="{p}" controls playsinline></video>
-               <div class="tl-label">Clip {i}</div>
-             </div>
-             """
-         )
-     return f"<div class='tl-grid'>{''.join(items)}</div>"
-
- # =========================
- # Gradio callbacks / state ops
- # =========================
- def add_image_slot(visible_slots: int):
-     """Reveal one more upload slot (up to MAX_SLOTS)."""
-     return min(MAX_SLOTS, int(visible_slots) + 1)
-
- def _reveal_slots(n, *imgs):
-     """Update visibility of image upload components based on visible_slots state."""
-     n = int(n)
-     updates = []
-     for i in range(MAX_SLOTS):
-         updates.append(gr.update(visible=(i < n)))
-     return updates
-
- def collect_choices(*imgs):
-     """Build dropdown choices of available indices (1-based labels) based on non-empty slots."""
-     choices = []
-     for i, img in enumerate(imgs, start=1):
-         if img is not None:
-             choices.append(str(i))
-     return gr.update(choices=choices), gr.update(choices=choices)
-
- def stitch_selected(
-     prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
- ):
-     """Run inference for selected start/end indices (1-based strings) + options."""
-     if not start_idx_str or not end_idx_str:
-         gr.Warning("Please select Start and End frames.")
-         return None
-     try:
-         s = int(start_idx_str) - 1
-         e = int(end_idx_str) - 1
      except Exception:
-         gr.Warning("Invalid Start/End selection.")
-         return None
-
-     if s < 0 or e < 0 or s >= len(imgs) or e >= len(imgs):
-         gr.Warning("Start/End out of range.")
-         return None

-     start_img = imgs[s]
-     end_img = imgs[e]
-     if start_img is None or end_img is None:
-         gr.Warning("Selected slots are empty.")
-         return None
-
-     fps_val = int(str(fps)) if fps else 24
-     len_val = int(str(length_sec)) if length_sec else 4
-
-     vid = stitch_call(
-         start_img=start_img,
-         end_img=end_img,
-         prompt=prompt or "",
-         seed=int(seed or 0),
-         negative_prompt=(negative_prompt or "").strip() or None,
-         frames_per_second=fps_val,
-         video_length=len_val,
-         num_inference_steps=None,
-     )
-     if not vid:
-         gr.Warning("Generation failed.")
-         return None
-     return vid # path for preview
-
- def add_to_timeline(preview_path, timeline_paths: List[str]):
-     """Append preview to timeline; return updated state and HTML."""
-     tl = list(timeline_paths or [])
-     if not preview_path:
-         gr.Warning("Generate a clip first.")
-         return tl, gr.update(value=render_timeline_html(tl))
-     tl.append(preview_path)
-     return tl, gr.update(value=render_timeline_html(tl))
-
- def stitch_all_from_timeline(timeline_paths: List[str]):
-     vids = list(timeline_paths or [])
-     if len(vids) < 2:
-         gr.Warning("Add at least two clips to the timeline first.")
-         return None
-     out = concat_many(vids)
-     if not out:
-         gr.Warning("Failed to concatenate clips.")
-     return out

- # =========================
- # UI
- # =========================
- CSS = """
- .gradio-container { padding: 24px; }
- .pill button { border-radius: 999px !important; padding: 10px 18px; }
- .rounded textarea { border-radius: 16px !important; }
- .gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
- .gallery-row .gradio-image { min-width: 220px; }
- .tl-grid {
-   display: grid;
-   grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
-   gap: 12px;
- }
- .stitch-box {
-   background-color: #f0f4ff; /* pick any color you like */
-   border-radius: 12px;
-   padding: 16px;
- }
- .tl-grid video {
-   width: 100%;
-   height: 120px;
-   object-fit: cover;
-   border-radius: 12px;
-   display: block;
- }
- .tl-label {
-   font-size: 12px;
-   color: #9aa0a6;
-   margin-top: 4px;
-   text-align: center;
- }
- .tl-empty { color: #9aa0a6; padding: 8px 4px; }
- """
-
- with gr.Blocks(css=CSS, title="StitchTool") as demo:
-     gr.Markdown("## StitchTool")
-
-     # --- State ---
-     visible_slots = gr.State(value=3) # number of visible image slots
-     timeline_state = gr.State(value=[]) # list[str] of video file paths (timeline)
-
-     # --- Image gallery (horizontal, grows on demand) ---
-     with gr.Row(elem_classes=["gallery-row"]):
-         img_comps = []
-         for i in range(MAX_SLOTS):
-             comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
-             img_comps.append(comp)
-         add_btn = gr.Button("+ Add image")
-
-     # clicking add → reveal one more slot
-     add_btn.click(
-         fn=add_image_slot,
-         inputs=[visible_slots],
-         outputs=[visible_slots],
-     )
-
-     # reflect visibility changes whenever visible_slots changes
-     visible_slots.change(
-         fn=_reveal_slots,
-         inputs=[visible_slots] + img_comps,
-         outputs=img_comps
-     )
-
-     # Seed + Start/End selection + Prompt + options + Stitch + Preview
-     seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

      with gr.Row():
-         # Left column: controls (with colored background via .stitch-box)
-         with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
-             start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
-             end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
-
-             prompt = gr.Textbox(
-                 placeholder="Describe the transition between the selected start and end frames…",
-                 lines=3,
-                 label="Prompt",
-                 elem_classes=["rounded"]
-             )

-             negative = gr.Textbox(
-                 placeholder="Optional: things to avoid (e.g., 'bad quality, extra fingers, etc.')",
-                 lines=2,
-                 label="Negative prompt",
-                 elem_classes=["rounded"]
              )
-
-             with gr.Row():
-                 fps = gr.Dropdown(
-                     label="Frame rate",
-                     choices=["16", "24", "32"],
-                     value="24",
-                     interactive=True,
-                 )
-                 length_sec = gr.Dropdown(
-                     label="Video length (sec)",
-                     choices=["2", "4"],
-                     value="4",
-                     interactive=True,
-                 )
-
-             run_btn = gr.Button("Generate", elem_classes=["pill"])
-             add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
-
-         # Right column: preview video
-         with gr.Column(scale=1, min_width=420):
-             preview = gr.Video(label="Video output", interactive=False)
-
-     # keep start/end dropdowns up to date based on which slots have images
-     for comp in img_comps:
-         comp.change(
-             fn=collect_choices,
-             inputs=img_comps,
-             outputs=[start_dd, end_dd]
-         )
-
-     # stitch action → preview
-     run_btn.click(
-         fn=stitch_selected,
-         inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
-         outputs=[preview]
      )

-     # --- Dynamic timeline (no placeholders) ---
-     with gr.Row():
-         timeline_html = gr.HTML(value=render_timeline_html([]))
-
-     add_tl_btn.click(
-         fn=add_to_timeline,
-         inputs=[preview, timeline_state],
-         outputs=[timeline_state, timeline_html]
      )

-     # final stitch all (concatenate in order)
-     with gr.Row():
-         with gr.Column(scale=1, min_width=420):
-             stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
-         with gr.Column(scale=1, min_width=420):
-             final_vid = gr.Video(label="Stitched Video Output", interactive=False)
-
-     stitch_all_btn.click(
-         fn=stitch_all_from_timeline,
-         inputs=[timeline_state],
-         outputs=[final_vid]
-     )

  if __name__ == "__main__":
-     demo.queue().launch()

+ import os, uuid, time, json, shutil, mimetypes, subprocess, requests, gradio as gr
+ from datetime import datetime
  from typing import Optional, List

+ # -------- config --------
+ ENDPOINT = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
+ FFMPEG = "ffmpeg"
+ OUT, TMP = "outputs", "tmp"
+ os.makedirs(OUT, exist_ok=True); os.makedirs(TMP, exist_ok=True)

+ ts = lambda: datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+ fname = lambda p,e: f"{p}_{ts()}_{uuid.uuid4().hex[:6]}.{e}"
+ abspath= lambda p: os.path.abspath(p)

+ def run_ffmpeg(args: List[str]):
+     p = subprocess.run([FFMPEG]+args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+     if p.returncode:
+         raise RuntimeError(p.stderr.decode(errors="ignore"))

+ def extract_last_frame(video_path: str) -> str:
+     out = os.path.join(TMP, fname("last", "png"))
+     run_ffmpeg(["-sseof","-1","-i",video_path,"-frames:v","1","-q:v","2",out])
+     return out

+ def concat_videos(paths: List[str]) -> str:
+     if not paths: raise ValueError("No videos selected.")
+     if len(paths)==1:
+         dst = os.path.join(OUT, fname("continuous","mp4")); shutil.copy(paths[0], dst); return dst
+     listfile = os.path.join(TMP, f"concat_{uuid.uuid4().hex}.txt")
+     with open(listfile,"w") as f:
+         for p in paths: f.write(f"file '{abspath(p)}'\n")
+     out = os.path.join(OUT, fname("continuous","mp4"))
+     run_ffmpeg(["-f","concat","-safe","0","-i",listfile,"-c","copy",out])
+     os.remove(listfile)
+     return out

+ def zip_used(paths: List[str]) -> str:
+     pack = os.path.join(TMP, f"pack_{uuid.uuid4().hex[:6]}")
+     os.makedirs(pack, exist_ok=True)
+     for p in paths: shutil.copy(p, pack)
+     base = os.path.join(OUT, f"used_{ts()}")
+     shutil.make_archive(base, "zip", pack)
+     shutil.rmtree(pack, ignore_errors=True)
+     return base + ".zip"
+
+ def save_video_bytes(content: bytes, content_type: str) -> str:
+     ext = (mimetypes.guess_extension(content_type) or ".mp4").lstrip(".")
+     path = os.path.join(OUT, fname("gen", ext))
+     with open(path,"wb") as f: f.write(content)
+     return path

+ def call_backend(
      prompt: str,
+     image_bytes_path: str,
+     negative_prompt: Optional[str],
+     fps: int,
+     vlen: int,
+     steps: Optional[int],
+     seed: Optional[int]
+ ) -> str:
+     params = {
+         "prompt": prompt,
+         "frames_per_second": str(fps),
+         "video_length": str(vlen),
      }
+     if negative_prompt: params["negative_prompt"] = negative_prompt
+     if steps is not None: params["num_inference_steps"] = str(steps)
+     if seed is None: seed = int(time.time())
+     params["seed"] = str(seed)

+     files = {"image_bytes": (os.path.basename(image_bytes_path), open(image_bytes_path,"rb"),
+                              "application/octet-stream")}
      try:
+         r = requests.post(ENDPOINT, params=params, files=files, headers={"accept":"application/json"}, timeout=600)
+     finally:
+         try: files["image_bytes"][1].close()
+         except: pass

+     if r.status_code != 200:
+         raise RuntimeError(f"Backend {r.status_code}: {r.text[:500]}")

+     ctype = r.headers.get("Content-Type","")
+     if ctype.startswith("video/"): # raw video bytes
+         return save_video_bytes(r.content, ctype)

+     # expect JSON with { "video_url": ... } or direct mp4 URL
      try:
+         payload = r.json()
+         url = payload.get("video_url")
+         if not url: raise ValueError("no video_url in response")
+         r2 = requests.get(url, stream=True, timeout=600)
+         if r2.status_code != 200: raise RuntimeError(f"fetch video {r2.status_code}")
+         path = os.path.join(OUT, fname("gen","mp4"))
+         with open(path,"wb") as f:
+             for chunk in r2.iter_content(1<<20):
+                 if chunk: f.write(chunk)
+         return path
      except Exception:
+         # if backend returns direct bytes but mislabeled JSON, fallback
+         return save_video_bytes(r.content, ctype or "video/mp4")

+ with gr.Blocks() as demo:
+     gr.Markdown("## Continuous Video — chain last frame → next first frame")

+     # state
+     used_videos = gr.State([]) # list[str]
+     last_seed_img = gr.State(None) # str path (PNG) to send as image_bytes
+     current_video = gr.State(None) # str path

      with gr.Row():
+         prompt = gr.Textbox(label="Prompt", placeholder="Describe your shot…", lines=2)
+     with gr.Row():
+         start_file = gr.File(label="Initial start image or video (only needed for the very first clip)", file_types=["image","video"])
+     with gr.Row():
+         negative = gr.Textbox(label="Negative prompt (optional)", placeholder="What to avoid…")
+     with gr.Row():
+         fps = gr.Slider(1,60, value=24, step=1, label="Frames per second")
+         vlen = gr.Slider(1,12, value=4, step=1, label="Video length (seconds)")
+         steps = gr.Slider(1,100, value=30, step=1, label="Num inference steps (optional)", interactive=True)
+         seed = gr.Number(label="Seed (optional)", precision=0, value=None)

+     video_out = gr.Video(label="Output")
+     with gr.Row():
+         gen_btn = gr.Button("Generate", variant="primary")
+         use_btn = gr.Button("Chain")
+         dl_btn = gr.Button("Download")
+
+     files_out = gr.Files(label="Downloads")
+
+     def on_generate(prompt_txt, start_file_obj, neg, fps_val, vlen_val, steps_val, seed_val, seed_img):
+         if not prompt_txt or not prompt_txt.strip():
+             raise gr.Error("Prompt is required.")
+         # choose image_bytes source: last stolen frame > uploaded start file
+         image_path = seed_img
+         if not image_path:
+             if not start_file_obj: raise gr.Error("First generation requires an initial image OR video.")
+             image_path = start_file_obj.name
+         try:
+             out = call_backend(
+                 prompt=prompt_txt.strip(),
+                 image_bytes_path=image_path,
+                 negative_prompt=(neg.strip() if (neg and neg.strip()) else None),
+                 fps=int(fps_val),
+                 vlen=int(vlen_val),
+                 steps=int(steps_val) if steps_val else None,
+                 seed=int(seed_val) if (seed_val not in [None,""]) else None
              )
+         except Exception as e:
+             raise gr.Error(str(e))
+         return out, out # preview path, current_video state
+
+     gen_btn.click(
+         on_generate,
+         inputs=[prompt, start_file, negative, fps, vlen, steps, seed, last_seed_img],
+         outputs=[video_out, current_video]
      )

+     def on_use(curr, used):
+         if not curr or not os.path.exists(curr):
+             raise gr.Error("Generate a video first.")
+         if curr not in used: used = used + [curr]
+         try:
+             seed_img = extract_last_frame(curr)
+         except Exception as e:
+             raise gr.Error(f"Could not extract last frame: {e}")
+         return used, seed_img
+
+     use_btn.click(
+         on_use,
+         inputs=[current_video, used_videos],
+         outputs=[used_videos, last_seed_img]
      )

+     def on_download(used):
+         if not used: raise gr.Error("Nothing to download. Click Use after generating clips.")
+         files = []
+         try:
+             merged = concat_videos(used); files.append(merged)
+         except Exception as e:
+             # still allow ZIP even if concat fails
+             print("concat error:", e)
+         try:
+             zipped = zip_used(used); files.append(zipped)
+         except Exception as e:
+             print("zip error:", e)
+         return files
+
+     dl_btn.click(on_download, inputs=[used_videos], outputs=[files_out])

  if __name__ == "__main__":
+     demo.launch()
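For reference, here is a minimal standalone sketch of the request contract that the new call_backend() relies on: generation parameters travel in the query string, the start frame is posted as a multipart "image_bytes" field, and the response is either raw video bytes or JSON carrying a "video_url". The endpoint URL and parameter names come from the diff above; the prompt text and the local file names ("start.png", "clip.mp4") are illustrative placeholders only, not part of the commit.

# request_contract_demo.py — illustrative sketch only; prompt and file names are examples
import requests

ENDPOINT = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"

params = {
    "prompt": "a slow dolly shot over a moonlit sea",  # example prompt
    "frames_per_second": "24",
    "video_length": "4",
    "seed": "42",
    # optional extras used by app.py: "negative_prompt", "num_inference_steps"
}

with open("start.png", "rb") as f:  # any start frame works here
    files = {"image_bytes": ("start.png", f, "application/octet-stream")}
    r = requests.post(ENDPOINT, params=params, files=files,
                      headers={"accept": "application/json"}, timeout=600)
r.raise_for_status()

# The backend may answer with the clip itself or with JSON pointing at it.
if r.headers.get("Content-Type", "").startswith("video/"):
    data = r.content
else:
    data = requests.get(r.json()["video_url"], timeout=600).content

with open("clip.mp4", "wb") as out:
    out.write(data)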