Shalmoni committed on
Commit
6259109
·
verified ·
1 Parent(s): 5a6bbaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -153
app.py CHANGED
@@ -1,15 +1,16 @@
1
- import os, io, time, random, base64, zipfile
2
- from typing import List, Tuple, Optional
 
3
 
4
  import requests
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- # ========= Config =========
9
- MAX_FRAMES = 8 # how many upload slots & rows to render
10
  MODAL_BASE = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
- # ========= Helpers =========
13
  def _save_video_bytes(data: bytes, tag: str) -> str:
14
  os.makedirs("/mnt/data", exist_ok=True)
15
  path = f"/mnt/data/{tag}_{int(time.time())}.mp4"
@@ -27,126 +28,104 @@ def _download_to_bytes(url: str) -> bytes:
27
  r.raise_for_status()
28
  return r.content
29
 
30
- def call_modal_i2v(start_img: Image.Image, prompt: str, seed: Optional[int]) -> Tuple[Optional[str], str]:
31
  """
32
- POST to Modal with multipart 'image_bytes' and query args prompt & seed.
33
- Returns (mp4_path_or_None, debug_log).
 
 
 
34
  """
35
- dbg = []
36
  if seed in (None, 0, -1):
37
  seed = random.randint(1, 2**31 - 1)
38
 
39
- # Build URL (encode prompt)
40
- from urllib.parse import quote
41
  url = f"{MODAL_BASE}?prompt={quote(prompt)}&seed={seed}"
42
-
43
  files = {"image_bytes": ("start.png", _png_bytes_from_pil(start_img), "image/png")}
44
  headers = {"accept": "application/json"}
45
 
46
  try:
47
  resp = requests.post(url, files=files, headers=headers, timeout=600)
48
  ctype = (resp.headers.get("content-type") or "").lower()
49
- dbg.append(f"HTTP {resp.status_code}; content-type={ctype}")
50
 
51
- # Case A: raw bytes (not JSON)
52
  if "application/json" not in ctype:
53
  resp.raise_for_status()
54
- path = _save_video_bytes(resp.content, "pair")
55
- dbg.append(f"Saved raw video to {path}")
56
- return path, "\n".join(dbg)
57
 
58
- # Case B: JSON containing url or base64
59
  data = resp.json()
60
  video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
61
- video_b64 = data.get("video_b64") or data.get("videoBase64")
62
-
63
- if video_url and isinstance(video_url, str):
64
  b = _download_to_bytes(video_url)
65
- path = _save_video_bytes(b, "pair")
66
- dbg.append(f"Downloaded video from {video_url} -> {path}")
67
- return path, "\n".join(dbg)
68
 
69
- if video_b64 and isinstance(video_b64, str):
 
70
  pad = (-len(video_b64)) % 4
71
  if pad: video_b64 += "=" * pad
72
  b = base64.b64decode(video_b64)
73
- path = _save_video_bytes(b, "pair")
74
- dbg.append("Decoded base64 video.")
75
- return path, "\n".join(dbg)
76
-
77
- # Nothing usable returned
78
- try:
79
- dbg.append(f"Backend JSON: {str(data)[:500]}")
80
- except Exception:
81
- pass
82
- return None, "\n".join(dbg)
83
 
84
- except Exception as e:
85
- dbg.append(f"Exception: {type(e).__name__}: {e}")
86
- return None, "\n".join(dbg)
87
 
88
- # ========= State handlers =========
89
- def add_images(files: List[str], images_state: List[Image.Image], names_state: List[str]):
90
  """
91
- Append uploads to state; return updated previews and row visibilities.
 
92
  """
93
- imgs, names = list(images_state), list(names_state)
94
  for f in files or []:
95
  try:
96
- img = Image.open(f).convert("RGB")
97
- imgs.append(img)
98
- names.append(os.path.basename(f))
99
  except Exception:
100
  continue
 
101
 
102
- # Outputs to update: image slots, labels, visibilities; pair rows visible up to len-1
103
- img_values, img_labels, img_vis = [], [], []
104
- pair_vis = []
105
  for i in range(MAX_FRAMES):
106
  if i < len(imgs):
107
- img_values.append(imgs[i])
108
- img_labels.append(f"Image {i+1}")
109
- img_vis.append(True)
110
  else:
111
- img_values.append(None)
112
- img_labels.append(f"Image {i+1}")
113
- img_vis.append(False)
114
 
115
- for i in range(MAX_FRAMES - 1):
116
- pair_vis.append(i < len(imgs) - 1)
 
 
117
 
118
- return imgs, names, img_values, img_labels, img_vis, pair_vis
119
 
120
  def clear_all():
121
- img_values = [None]*MAX_FRAMES
122
- img_labels = [f"Image {i+1}" for i in range(MAX_FRAMES)]
123
- img_vis = [False]*MAX_FRAMES
124
- pair_vis = [False]*(MAX_FRAMES-1)
125
- return [], [], img_values, img_labels, img_vis, pair_vis
126
-
127
- def stitch_pair(index: int,
128
- images: List[Image.Image],
129
- prompt: str,
130
- seed: int):
131
  """
132
- index is 0-based pair (0 => 1&2, 1 => 2&3...)
133
- We call Modal using the *first* image of the pair as the init image.
134
  """
135
- if not images or len(images) < index+2:
136
- gr.Warning("Upload more images first.")
137
- return None, "Not enough images."
138
 
139
- # Compose a minimal helpful prompt for continuity
140
  user = (prompt or "").strip()
141
- extra = f"(Transition between frame {index+1} {index+2} of the same shot.)"
142
- final_prompt = f"{user} {extra}".strip()
143
 
144
- path, dbg = call_modal_i2v(images[index], final_prompt, seed)
145
- if path is None:
146
- gr.Warning("Stitch failed. See debug log.")
147
- return path, dbg
148
 
149
- # ========= UI =========
150
  CSS = """
151
  .gradio-container { padding: 24px; }
152
  .pill button { border-radius: 999px !important; padding: 10px 18px; }
@@ -154,100 +133,64 @@ CSS = """
154
  """
155
 
156
  with gr.Blocks(css=CSS, title="Stitch — Upload & Stitch Adjacent Pairs") as demo:
157
- gr.Markdown("## Stitch — Upload stills, then generate between-frames videos\n"
158
- "Upload images in order. For each adjacent pair (1&2, 2&3, …), write a short transition prompt and click **Stitch**.")
159
 
160
- images_state = gr.State([]) # List[PIL.Image]
161
- names_state = gr.State([]) # List[str]
162
 
163
  with gr.Row():
164
- # Left column: image slots
165
  with gr.Column(scale=1, min_width=340):
166
  uploader = gr.Files(label="Add images (in order)", file_types=["image"], file_count="multiple")
167
  clear_btn = gr.Button("Clear all", elem_classes=["pill"])
168
- image_slots = []
169
- for i in range(MAX_FRAMES):
170
- image_slots.append(
171
- gr.Image(label=f"Image {i+1}", interactive=False, visible=False)
172
- )
173
 
174
- # Middle column: per-pair prompt + button
175
- with gr.Column(scale=1, min_width=340):
176
- seed_in = gr.Number(value=0, precision=0, label="Seed (0 = random)")
177
- prompt_boxes = []
178
- stitch_buttons = []
179
- for i in range(MAX_FRAMES - 1):
180
- prompt_boxes.append(
181
- gr.Textbox(
182
- placeholder=f"Prompt for transition between Image {i+1} & {i+2}",
183
- lines=2, label="Prompt", elem_classes=["rounded"], visible=False
184
- )
185
- )
186
- stitch_buttons.append(
187
- gr.Button(f"Stitch {i+1}&{i+2}", elem_classes=["pill"], visible=False)
188
- )
189
-
190
- # Right column: per-pair video outputs + shared debug
191
  with gr.Column(scale=1, min_width=360):
192
- video_outputs = []
193
- for i in range(MAX_FRAMES - 1):
194
- video_outputs.append(
195
- gr.Video(label=f"Video (image {i+1}+{i+2}) output", visible=False)
196
- )
197
- debug_box = gr.Code(label="Debug log", interactive=False)
198
-
199
- # ---- Wiring: upload & clear ----
200
- uploader.upload(
201
- fn=add_images,
202
- inputs=[uploader, images_state, names_state],
203
- outputs=[
204
- images_state, names_state,
205
- # image values, labels, visibilities
206
- *image_slots, # values (Image components accept PIL Image)
207
- *[s for s in image_slots], # labels: set via .label below (we'll hack via .update)
208
- *[s for s in image_slots], # visibility
209
- *[b for b in stitch_buttons] # visibility for rows (we’ll mirror to prompt/video too)
210
- ],
211
- queue=False
212
- )
213
-
214
- # NOTE: Gradio can't directly set multiple attributes with one function return to each component slot,
215
- # so we will do a lightweight post-upload JS update using .update. Simpler: tie visibility of prompt/video
216
- # to the corresponding button's visibility in another handler:
217
 
218
- def reflect_row_visibility(images: List[Image.Image]):
219
- n = len(images)
220
- vis = [i < n-1 for i in range(MAX_FRAMES-1)]
221
- # return prompt visibilities, button visibilities, video visibilities
222
- return [gr.Textbox(visible=vis[i]) for i in range(MAX_FRAMES-1)] + \
223
- [gr.Button(visible=vis[i]) for i in range(MAX_FRAMES-1)] + \
224
- [gr.Video(visible=vis[i]) for i in range(MAX_FRAMES-1)]
225
 
 
226
  uploader.upload(
227
- fn=reflect_row_visibility,
228
- inputs=[images_state],
229
- outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
230
- queue=False
 
 
 
 
 
231
  )
232
 
 
233
  clear_btn.click(
234
  fn=clear_all,
235
  inputs=[],
236
- outputs=[images_state, names_state, *image_slots, *image_slots, *image_slots, *stitch_buttons],
237
- queue=False
238
- ).then(
239
- fn=lambda imgs: reflect_row_visibility(imgs),
240
- inputs=[images_state],
241
- outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
242
- queue=False
243
  )
244
 
245
- # ---- Wiring: per-pair stitchers ----
246
  for i in range(MAX_FRAMES - 1):
247
  stitch_buttons[i].click(
248
- fn=lambda prompt, seed, imgs, idx=i: stitch_pair(idx, imgs, prompt, int(seed or 0)),
249
  inputs=[prompt_boxes[i], seed_in, images_state],
250
- outputs=[video_outputs[i], debug_box]
251
  )
252
 
253
  if __name__ == "__main__":
 
1
+ import os, io, time, random, base64
2
+ from typing import List, Optional, Tuple
3
+ from urllib.parse import quote
4
 
5
  import requests
6
  from PIL import Image
7
  import gradio as gr
8
 
9
+ # ---------- CONFIG ----------
10
+ MAX_FRAMES = 12 # how many visible "Image N" slots & rows to render
11
  MODAL_BASE = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
12
 
13
+ # ---------- Helpers ----------
14
  def _save_video_bytes(data: bytes, tag: str) -> str:
15
  os.makedirs("/mnt/data", exist_ok=True)
16
  path = f"/mnt/data/{tag}_{int(time.time())}.mp4"
 
28
  r.raise_for_status()
29
  return r.content
30
 
def call_modal_i2v(start_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """
    Request a clip from the Modal image-to-video backend and save it locally.

    Sends ``POST {MODAL_BASE}?prompt=...&seed=...`` with the start frame as the
    multipart field ``image_bytes``. The response is accepted in three shapes:

    * a non-JSON body, treated as the raw video bytes;
    * JSON with an http(s) URL under ``video_url`` / ``url`` / ``result`` /
      ``output``, which is downloaded;
    * JSON with base64 data under ``video_b64`` / ``videoBase64``.

    Args:
        start_img: Start frame, encoded by ``_png_bytes_from_pil``.
        prompt: Prompt text; percent-encoded into the query string.
        seed: Generation seed. ``None``, ``0`` and ``-1`` mean "pick one at random".

    Returns:
        Path of the saved mp4, or ``None`` when the request fails or the
        response carries no usable video. Failures are logged, not raised.
    """
    import logging  # local import so this block does not depend on file-level edits

    log = logging.getLogger(__name__)

    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # `prompt or ""` keeps a missing prompt from raising inside quote().
    url = f"{MODAL_BASE}?prompt={quote(prompt or '')}&seed={seed}"
    files = {"image_bytes": ("start.png", _png_bytes_from_pil(start_img), "image/png")}
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()

        # Case A: raw video bytes
        if "application/json" not in ctype:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "pair")

        # Case B: JSON with URL or base64
        data = resp.json()
        if not isinstance(data, dict):
            log.warning("Modal backend returned JSON of type %s, expected an object",
                        type(data).__name__)
            return None

        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "pair")

        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            # Drop whitespace/newlines first so the padding is computed on the
            # real payload length; an empty payload is not a video.
            video_b64 = "".join(video_b64.split())
            if video_b64:
                video_b64 += "=" * ((-len(video_b64)) % 4)
                return _save_video_bytes(base64.b64decode(video_b64), "pair")
    except Exception:
        # Best-effort boundary for the UI handler: report the failure in the
        # logs instead of discarding it, and let the caller see None.
        log.exception("Modal i2v request failed")
        return None

    log.warning("Modal backend response (HTTP %s) contained no usable video", resp.status_code)
    return None
 
 
72
 
73
+ # ---------- State & wiring helpers ----------
def handle_upload(files: List[str], images_state: List[Image.Image]):
    """
    Append newly uploaded images to the stored list and refresh the UI.

    Args:
        files: Uploaded files as delivered by the uploader component; entries
            that cannot be opened as images are skipped.
        images_state: Images already held in state (not mutated).

    Returns:
        A flat list with one entry per wired output component, in this order:
        the new image list (at most ``MAX_FRAMES`` items), ``MAX_FRAMES`` image
        slot updates, then ``MAX_FRAMES - 1`` updates each for the prompt
        boxes, the stitch buttons and the video outputs. Prompts and videos
        are reset to empty on every upload.
    """
    imgs = list(images_state or [])
    for f in files or []:
        if len(imgs) >= MAX_FRAMES:
            break  # cap reached: do not decode images that would be dropped
        try:
            # convert() returns an independent image, so the source handle
            # can be closed as soon as the block exits.
            with Image.open(f) as src:
                imgs.append(src.convert("RGB"))
        except Exception:
            # Deliberate best-effort: an unreadable upload is skipped.
            continue
    imgs = imgs[:MAX_FRAMES]  # also caps a state that was already too long

    # gr.update() is used instead of the per-component `.update` class
    # methods, which are not available in newer Gradio releases.
    image_updates = [
        gr.update(value=imgs[i], visible=True, label=f"Image {i+1}")
        if i < len(imgs)
        else gr.update(value=None, visible=False, label=f"Image {i+1}")
        for i in range(MAX_FRAMES)
    ]

    # Row i (pair i / i+1) is shown only when both of its images exist.
    row_visible = [i < len(imgs) - 1 for i in range(MAX_FRAMES - 1)]
    prompt_updates = [gr.update(visible=shown, value="") for shown in row_visible]
    button_updates = [gr.update(visible=shown) for shown in row_visible]
    video_updates = [gr.update(visible=shown, value=None) for shown in row_visible]

    # The event is wired to a flat list of components, so the return value
    # must be flat as well: one value per output, not one list per group.
    return [imgs, *image_updates, *prompt_updates, *button_updates, *video_updates]
101
 
def clear_all():
    """
    Reset the app to its empty state.

    Returns:
        A flat list with one entry per wired output component, in this order:
        an empty image list, ``MAX_FRAMES`` image slot updates, then
        ``MAX_FRAMES - 1`` updates each for the prompt boxes, the stitch
        buttons and the video outputs. Everything is hidden and emptied.
    """
    pair_count = MAX_FRAMES - 1

    # gr.update() is used instead of the per-component `.update` class
    # methods, which are not available in newer Gradio releases.
    image_updates = [gr.update(value=None, visible=False) for _ in range(MAX_FRAMES)]
    prompt_updates = [gr.update(visible=False, value="") for _ in range(pair_count)]
    button_updates = [gr.update(visible=False) for _ in range(pair_count)]
    video_updates = [gr.update(visible=False, value=None) for _ in range(pair_count)]

    # One value per output component: the groups must be flattened, not
    # returned as five nested lists.
    return [[], *image_updates, *prompt_updates, *button_updates, *video_updates]
109
+
def stitch_pair(idx: int, images: List[Image.Image], prompt: str, seed: int):
    """
    Generate the clip for the adjacent image pair ``idx`` / ``idx + 1``.

    ``idx`` counts pairs from zero: 0 is images 1&2, 1 is images 2&3, and so on.
    Only ``images[idx]`` is sent to the backend, as the start frame, together
    with the user's prompt followed by a short transition hint.

    Returns whatever ``call_modal_i2v`` returns, or ``None`` right away when
    the requested pair has not been uploaded. A Gradio warning is raised in
    the UI for both the "not enough images" and the "no video" outcomes.
    """
    pair_available = bool(images) and len(images) >= idx + 2
    if not pair_available:
        gr.Warning("Please upload enough images first.")
        return None

    user = (prompt or "").strip()
    # The trailing hint is optional context for the model; edit or drop it freely.
    final_prompt = f"{user} (Transition between frame {idx+1} → {idx+2}.)".strip()

    seed_value = int(seed or 0)
    start_frame = images[idx]
    path = call_modal_i2v(start_frame, final_prompt, seed_value)

    if path:
        return path
    gr.Warning("Stitch failed. Try again or adjust your prompt.")
    return path
127
 
128
+ # ---------- UI ----------
129
  CSS = """
130
  .gradio-container { padding: 24px; }
131
  .pill button { border-radius: 999px !important; padding: 10px 18px; }
 
133
  """
134
 
with gr.Blocks(css=CSS, title="Stitch — Upload & Stitch Adjacent Pairs") as demo:
    # Page heading and a one-line usage hint.
    gr.Markdown("## Stitch — Upload stills, then generate between-frames videos")
    gr.Markdown("Upload images in order. For each adjacent pair (1&2, 2&3, …), write a short transition prompt and click **Stitch**.")

    # Holds the uploaded PIL images between events.
    images_state = gr.State([])

    with gr.Row():
        # Left column: uploader, clear button and the hidden image slots.
        with gr.Column(scale=1, min_width=340):
            uploader = gr.Files(label="Add images (in order)", file_types=["image"], file_count="multiple")
            clear_btn = gr.Button("Clear all", elem_classes=["pill"])
            image_slots = []
            for i in range(MAX_FRAMES):
                image_slots.append(gr.Image(label=f"Image {i+1}", interactive=False, visible=False))

        # Middle column: shared seed, then every prompt box, then every button.
        with gr.Column(scale=1, min_width=360):
            seed_in = gr.Number(value=0, precision=0, label="Seed (0 = random)")

            prompt_boxes = []
            for i in range(MAX_FRAMES - 1):
                prompt_boxes.append(
                    gr.Textbox(
                        placeholder=f"Prompt for transition between Image {i+1} & {i+2}",
                        lines=2,
                        label="Prompt",
                        elem_classes=["rounded"],
                        visible=False,
                    )
                )

            stitch_buttons = []
            for i in range(MAX_FRAMES - 1):
                stitch_buttons.append(gr.Button(f"Stitch {i+1}&{i+2}", elem_classes=["pill"], visible=False))

        # Right column: one video player per adjacent pair.
        with gr.Column(scale=1, min_width=360):
            video_outputs = []
            for i in range(MAX_FRAMES - 1):
                video_outputs.append(gr.Video(label=f"Video (image {i+1}+{i+2}) output", visible=False))

    # Upload and clear refresh the same components, in the same order:
    # state, image slots, prompt boxes, stitch buttons, video outputs.
    refresh_targets = [images_state] + image_slots + prompt_boxes + stitch_buttons + video_outputs

    uploader.upload(
        fn=handle_upload,
        inputs=[uploader, images_state],
        outputs=list(refresh_targets),
    )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=list(refresh_targets),
    )

    # One click handler per pair; the pair index is frozen through the
    # lambda's default argument so each button keeps its own value.
    for pair_idx, (box, button, player) in enumerate(zip(prompt_boxes, stitch_buttons, video_outputs)):
        button.click(
            fn=lambda p, s, imgs, idx=pair_idx: stitch_pair(idx, imgs, p, s),
            inputs=[box, seed_in, images_state],
            outputs=[player],
        )
195
 
196
  if __name__ == "__main__":