Shalmoni committed on
Commit
75c12be
·
verified ·
1 Parent(s): a667aee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -73
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os, io, time, base64, random, subprocess
2
- from typing import Optional
3
  from urllib.parse import quote
4
 
5
  import requests
@@ -9,6 +9,10 @@ import gradio as gr
9
  # -------- Modal inference endpoint (dev) --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
 
 
 
 
12
  # -------- small helpers --------
13
  def _save_video_bytes(data: bytes, tag: str) -> str:
14
  os.makedirs("/tmp", exist_ok=True)
@@ -35,7 +39,6 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
35
  seed = random.randint(1, 2**31 - 1)
36
 
37
  url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
38
-
39
  files = {
40
  "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
41
  "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
@@ -55,122 +58,217 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
55
  data = resp.json()
56
  video_url = data.get("video_url") or data.get("url") or data.get("result")
57
  if isinstance(video_url, str) and video_url.startswith("http"):
58
- b = _download_to_bytes(video_url)
59
- return _save_video_bytes(b, "stitch")
60
 
61
  video_b64 = data.get("video_b64")
62
  if isinstance(video_b64, str):
63
  pad = (-len(video_b64)) % 4
64
  if pad:
65
  video_b64 += "=" * pad
66
- b = base64.b64decode(video_b64)
67
- return _save_video_bytes(b, "stitch")
68
 
69
  except Exception as e:
70
  print("stitch_call error:", e)
71
 
72
  return None
73
 
74
- # -------- FFmpeg-based concatenation --------
75
- def concat_videos(vid1: str, vid2: str) -> Optional[str]:
76
- if not vid1 or not vid2:
 
77
  return None
78
  try:
79
  os.makedirs("/tmp", exist_ok=True)
80
  out_path = f"/tmp/final_{int(time.time())}.mp4"
81
-
82
- # Create a temporary file list for ffmpeg
83
  list_file = f"/tmp/list_{int(time.time())}.txt"
84
  with open(list_file, "w") as f:
85
- f.write(f"file '{vid1}'\n")
86
- f.write(f"file '{vid2}'\n")
87
-
88
- # Run ffmpeg concat
89
  subprocess.run(
90
  ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
91
- check=True,
92
- stdout=subprocess.PIPE,
93
- stderr=subprocess.PIPE,
94
  )
95
-
96
  return out_path
97
  except Exception as e:
98
- print("concat_videos error:", e)
99
  return None
100
 
101
- # -------- Gradio callbacks --------
102
- def stitch_12(prompt12, seed, img1, img2):
103
- path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
104
- return path
105
-
106
- def stitch_23(prompt23, seed, img2, img3):
107
- path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
108
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- def stitch_all(vid12, vid23):
111
- if vid12 is None or vid23 is None:
112
- gr.Warning("Generate both videos first before stitching all.")
 
 
 
 
113
  return None
114
- return concat_videos(vid12, vid23)
115
 
116
- # ---------- UI ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  CSS = """
118
  .gradio-container { padding: 24px; }
119
  .pill button { border-radius: 999px !important; padding: 10px 18px; }
120
  .rounded textarea { border-radius: 16px !important; }
 
 
 
121
  """
122
 
123
- with gr.Blocks(css=CSS, title="Stitch — vertical flow") as demo:
124
- gr.Markdown("## StitchMaster")
125
-
126
- # Top row: 1 - 2 - 3 (side-by-side)
127
- with gr.Row():
128
- with gr.Column(scale=1, min_width=280):
129
- img1 = gr.Image(label="Image 1 upload", type="pil")
130
- with gr.Column(scale=1, min_width=280):
131
- img2 = gr.Image(label="Image 2 upload", type="pil")
132
- with gr.Column(scale=1, min_width=280):
133
- img3 = gr.Image(label="Image 3 upload", type="pil")
134
-
135
- # Seed under the uploads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
137
 
138
- # Stitch 1→2: LEFT = prompt+button, RIGHT = video
139
- with gr.Row():
140
- with gr.Column(scale=1, min_width=420):
141
- prompt12 = gr.Textbox(
142
- placeholder="Prompt for stitching 1→2",
143
- lines=2, label="Prompt (1→2)", elem_classes=["rounded"]
144
- )
145
- btn12 = gr.Button("Generate 1→2", elem_classes=["pill"])
146
- with gr.Column(scale=1, min_width=420):
147
- vid12 = gr.Video(label="Video (1→2)", interactive=False)
148
-
149
- # Stitch 2→3: LEFT = prompt+button, RIGHT = video
150
  with gr.Row():
151
  with gr.Column(scale=1, min_width=420):
152
- prompt23 = gr.Textbox(
153
- placeholder="Prompt for stitching 2→3",
154
- lines=2, label="Prompt (2→3)", elem_classes=["rounded"]
 
 
155
  )
156
- btn23 = gr.Button("Generate 2→3", elem_classes=["pill"])
 
157
  with gr.Column(scale=1, min_width=420):
158
- vid23 = gr.Video(label="Video (2→3)", interactive=False)
 
 
 
 
 
 
 
 
159
 
160
- # Final merge: LEFT = button, RIGHT = final video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  with gr.Row():
162
  with gr.Column(scale=1, min_width=420):
163
- btn_all = gr.Button("Stitch Together", elem_classes=["pill"])
164
  with gr.Column(scale=1, min_width=420):
165
- vid_all = gr.Video(label="Final Combined Video", interactive=False)
166
-
167
- # keep your existing .click wiring below this block
168
-
169
 
170
- # Wire buttons
171
- btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
172
- btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
173
- btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
 
174
 
175
  if __name__ == "__main__":
176
  demo.queue().launch()
 
1
  import os, io, time, base64, random, subprocess
2
+ from typing import Optional, List
3
  from urllib.parse import quote
4
 
5
  import requests
 
9
  # -------- Modal inference endpoint (dev) --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
+ # -------- settings --------
13
+ MAX_SLOTS = 12 # max image slots user can reveal
14
+ MAX_TIMELINE = 20 # max clips in the timeline
15
+
16
  # -------- small helpers --------
17
  def _save_video_bytes(data: bytes, tag: str) -> str:
18
  os.makedirs("/tmp", exist_ok=True)
 
39
  seed = random.randint(1, 2**31 - 1)
40
 
41
  url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
 
42
  files = {
43
  "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
44
  "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
 
58
  data = resp.json()
59
  video_url = data.get("video_url") or data.get("url") or data.get("result")
60
  if isinstance(video_url, str) and video_url.startswith("http"):
61
+ return _save_video_bytes(_download_to_bytes(video_url), "stitch")
 
62
 
63
  video_b64 = data.get("video_b64")
64
  if isinstance(video_b64, str):
65
  pad = (-len(video_b64)) % 4
66
  if pad:
67
  video_b64 += "=" * pad
68
+ return _save_video_bytes(base64.b64decode(video_b64), "stitch")
 
69
 
70
  except Exception as e:
71
  print("stitch_call error:", e)
72
 
73
  return None
74
 
75
+ # -------- FFmpeg-based concatenation (N clips) --------
76
def concat_many(videos: List[str]) -> Optional[str]:
    """Concatenate N video clips into a single MP4 via ffmpeg's concat demuxer.

    Args:
        videos: Ordered clip file paths; falsy entries (None/"") are dropped.

    Returns:
        Path to the combined ``/tmp/final_<ts>.mp4``, or ``None`` when fewer
        than two usable clips remain or ffmpeg fails.
    """
    vids = [v for v in videos if v]
    if len(vids) < 2:
        return None
    list_file = None  # pre-bind so cleanup in `finally` is safe on early failure
    try:
        os.makedirs("/tmp", exist_ok=True)
        stamp = int(time.time())
        out_path = f"/tmp/final_{stamp}.mp4"
        list_file = f"/tmp/list_{stamp}.txt"
        with open(list_file, "w") as f:
            for v in vids:
                # concat-demuxer list entry. Escape embedded single quotes
                # (the '\'' idiom) so paths containing ' don't break the
                # quoted-filename syntax of the list file.
                f.write("file '{}'\n".format(v.replace("'", "'\\''")))
        # -c copy stream-copies without re-encoding; assumes the clips share
        # codec parameters (they are produced by the same backend here).
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return out_path
    except Exception as e:
        print("concat_many error:", e)
        return None
    finally:
        # Best-effort removal of the temporary list file.
        if list_file:
            try:
                os.remove(list_file)
            except OSError:
                pass
95
 
96
+ # =========================
97
+ # Gradio callbacks / state ops
98
+ # =========================
99
def add_image_slot(visible_slots: int):
    """Reveal one more upload slot, clamped at MAX_SLOTS."""
    return min(MAX_SLOTS, visible_slots + 1)
103
+
104
def collect_choices(*imgs):
    """Recompute dropdown choices from whichever slots currently hold images.

    Each non-empty slot contributes its 1-based index as a string label;
    the identical choice list is pushed to both the start and end dropdowns.
    """
    filled = [str(idx) for idx, image in enumerate(imgs, start=1) if image is not None]
    return gr.update(choices=filled), gr.update(choices=filled)
112
+
113
def stitch_selected(prompt, seed, start_idx_str, end_idx_str, *imgs):
    """Generate a transition clip between the chosen start/end slot images.

    ``start_idx_str``/``end_idx_str`` are 1-based slot labels coming from the
    dropdowns. Returns the local file path of the generated clip, or ``None``
    (after emitting a UI warning) on any validation or generation failure.
    """
    # Guard: both endpoints must be chosen.
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None

    # Convert the 1-based labels to 0-based slot positions.
    try:
        start_pos = int(start_idx_str) - 1
        end_pos = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None

    # Both positions must address an existing slot.
    if not (0 <= start_pos < len(imgs) and 0 <= end_pos < len(imgs)):
        gr.Warning("Start/End out of range.")
        return None

    start_img, end_img = imgs[start_pos], imgs[end_pos]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None

    clip_path = stitch_call(start_img, end_img, prompt or "", int(seed or 0))
    if not clip_path:
        gr.Warning("Generation failed.")
        return None
    return clip_path  # path for preview
139
+
140
def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append the current preview clip to the timeline.

    Args:
        preview_path: Path of the clip currently shown in the preview player.
        timeline_paths: Current timeline state (list of clip paths).

    Returns:
        ``(new_state, *slot_updates)`` — one update per timeline video slot.
        On warning paths the slots receive no-op ``gr.update()`` objects
        (instead of ``value=None``), so clips already shown in the timeline
        are not wiped from the UI when the user clicks at the wrong moment.
    """
    tl = list(timeline_paths or [])

    if not preview_path:
        gr.Warning("Generate a clip first.")
        # No-op updates: leave every timeline slot exactly as displayed.
        return tl, *([gr.update()] * MAX_TIMELINE)

    if len(tl) >= MAX_TIMELINE:
        gr.Warning("Timeline full.")
        return tl, *([gr.update()] * MAX_TIMELINE)

    tl.append(preview_path)

    # Fill slots left-to-right with the timeline clips; blank the remainder.
    outputs = [gr.update(value=tl[i] if i < len(tl) else None) for i in range(MAX_TIMELINE)]
    return tl, *outputs
159
+
160
def stitch_all_from_timeline(timeline_paths: List[str]):
    """Concatenate every clip currently in the timeline into one video.

    Returns the combined file path, or ``None`` (after a UI warning) when
    fewer than two clips exist or ffmpeg concatenation fails.
    """
    clips = list(timeline_paths or [])
    if len(clips) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    combined = concat_many(clips)
    if not combined:
        gr.Warning("Failed to concatenate clips.")
    return combined
169
+
170
# =========================
# UI
# =========================

# Global stylesheet: pill buttons, rounded prompt box, and horizontally
# scrolling flex rows for the image gallery and the clip timeline.
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.timeline-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
"""

with gr.Blocks(css=CSS, title="StitchMaster") as demo:
    gr.Markdown("## StitchMaster — Upload images, stitch between frames, build a timeline, and export a single video.")

    # --- State ---
    visible_slots = gr.State(value=3)  # how many image slots are visible
    timeline_state = gr.State(value=[])  # list[str] of video file paths (timeline)

    # --- Image gallery (growing) ---
    # All MAX_SLOTS image components are created up front; only the first 3
    # start visible, and the add button reveals more one at a time.
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)
        add_btn = gr.Button("+ Add image")

    # clicking add → reveal one more slot
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )

    # reflect visibility changes whenever visible_slots changes
    # (we re-render all image components with correct visibility)
    def _reveal_slots(n, *imgs):
        # NOTE(review): *imgs is received but never used — the image
        # components are wired as inputs below; confirm whether they can
        # be dropped from the inputs list.
        updates = []
        for i in range(MAX_SLOTS):
            updates.append(gr.update(visible=(i < int(n))))
        return updates

    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps
    )

    # --- Stitch controls ---
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

    # Left column: frame pickers + prompt + action buttons; right: preview.
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3, label="Prompt", elem_classes=["rounded"]
            )
            run_btn = gr.Button("Stitch", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # keep start/end dropdowns up to date based on which slots actually have images
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd]
        )

    # stitch action preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, seed, start_dd, end_dd] + img_comps,
        outputs=[preview]
    )

    # add to timeline action → update state and visible clips
    # Prepare timeline video components (scroll row)
    with gr.Row(elem_classes=["timeline-row"]):
        timeline_videos = [gr.Video(label=f"Clip {i+1}", interactive=False) for i in range(MAX_TIMELINE)]

    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state] + timeline_videos
    )

    # final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)

    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid]
    )

if __name__ == "__main__":
    demo.queue().launch()