Shalmoni committed on
Commit
26b0caf
·
verified ·
1 Parent(s): 6259109

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -177
app.py CHANGED
@@ -1,197 +1,243 @@
1
- import os, io, time, random, base64
2
- from typing import List, Optional, Tuple
3
- from urllib.parse import quote
 
 
 
 
4
 
5
  import requests
6
- from PIL import Image
7
  import gradio as gr
8
 
9
- # ---------- CONFIG ----------
10
- MAX_FRAMES = 12 # how many visible "Image N" slots & rows to render
11
- MODAL_BASE = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
12
-
13
- # ---------- Helpers ----------
14
- def _save_video_bytes(data: bytes, tag: str) -> str:
15
- os.makedirs("/mnt/data", exist_ok=True)
16
- path = f"/mnt/data/{tag}_{int(time.time())}.mp4"
 
 
 
 
 
 
 
 
 
 
 
17
  with open(path, "wb") as f:
18
- f.write(data)
19
  return path
20
 
21
- def _png_bytes_from_pil(img: Image.Image) -> bytes:
22
- buf = io.BytesIO()
23
- img.save(buf, format="PNG")
24
- return buf.getvalue()
25
 
26
- def _download_to_bytes(url: str) -> bytes:
27
- r = requests.get(url, timeout=180)
28
  r.raise_for_status()
29
  return r.content
30
 
31
- def call_modal_i2v(start_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
32
- """
33
- Your Modal call, exactly like your JS snippet:
34
- POST …?prompt={user_prompt}&seed={seed}
35
- multipart field: image_bytes
36
- Accepts raw mp4 bytes OR JSON { video_url | url | video_b64 }.
37
- Returns path to saved mp4 or None on failure.
38
- """
39
- if seed in (None, 0, -1):
40
- seed = random.randint(1, 2**31 - 1)
41
-
42
- url = f"{MODAL_BASE}?prompt={quote(prompt)}&seed={seed}"
43
- files = {"image_bytes": ("start.png", _png_bytes_from_pil(start_img), "image/png")}
44
- headers = {"accept": "application/json"}
 
 
 
 
 
 
 
 
 
45
 
46
  try:
47
- resp = requests.post(url, files=files, headers=headers, timeout=600)
48
- ctype = (resp.headers.get("content-type") or "").lower()
49
-
50
- # Case A: raw video bytes
51
- if "application/json" not in ctype:
52
- resp.raise_for_status()
53
- return _save_video_bytes(resp.content, "pair")
54
-
55
- # Case B: JSON with URL or base64
56
- data = resp.json()
57
- video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
58
- if isinstance(video_url, str) and (video_url.startswith("http://") or video_url.startswith("https://")):
59
- b = _download_to_bytes(video_url)
60
- return _save_video_bytes(b, "pair")
61
-
62
- video_b64 = data.get("video_b64") or data.get("videoBase64")
63
- if isinstance(video_b64, str):
64
- pad = (-len(video_b64)) % 4
65
- if pad: video_b64 += "=" * pad
66
- b = base64.b64decode(video_b64)
67
- return _save_video_bytes(b, "pair")
68
- except Exception:
69
- pass
70
-
71
- return None
72
-
73
- # ---------- State & wiring helpers ----------
74
- def handle_upload(files: List[str], images_state: List[Image.Image]):
75
- """
76
- Append newly uploaded images to state (keeps first MAX_FRAMES).
77
- Returns updates for all image slots, prompt rows, stitch buttons, and video boxes.
78
- """
79
- imgs = list(images_state)
80
- for f in files or []:
81
- try:
82
- imgs.append(Image.open(f).convert("RGB"))
83
- except Exception:
84
- continue
85
- imgs = imgs[:MAX_FRAMES] # cap
86
-
87
- # Build updates
88
- image_updates = []
89
- for i in range(MAX_FRAMES):
90
- if i < len(imgs):
91
- image_updates.append(gr.Image.update(value=imgs[i], visible=True, label=f"Image {i+1}"))
92
  else:
93
- image_updates.append(gr.Image.update(value=None, visible=False, label=f"Image {i+1}"))
94
-
95
- row_visible = [i < len(imgs) - 1 for i in range(MAX_FRAMES - 1)]
96
- prompt_updates = [gr.Textbox.update(visible=row_visible[i], value="") for i in range(MAX_FRAMES - 1)]
97
- button_updates = [gr.Button.update(visible=row_visible[i]) for i in range(MAX_FRAMES - 1)]
98
- video_updates = [gr.Video.update(visible=row_visible[i], value=None) for i in range(MAX_FRAMES - 1)]
99
-
100
- return imgs, image_updates, prompt_updates, button_updates, video_updates
101
-
102
- def clear_all():
103
- imgs = []
104
- image_updates = [gr.Image.update(value=None, visible=False) for _ in range(MAX_FRAMES)]
105
- prompt_updates = [gr.Textbox.update(visible=False, value="") for _ in range(MAX_FRAMES - 1)]
106
- button_updates = [gr.Button.update(visible=False) for _ in range(MAX_FRAMES - 1)]
107
- video_updates = [gr.Video.update(visible=False, value=None) for _ in range(MAX_FRAMES - 1)]
108
- return imgs, image_updates, prompt_updates, button_updates, video_updates
109
-
110
- def stitch_pair(idx: int, images: List[Image.Image], prompt: str, seed: int):
111
- """
112
- idx is 0-based pair: 0 => stitch 1&2, 1 => stitch 2&3, ...
113
- We send Image(idx) as the start frame, plus the user prompt (you can add your own template here).
114
- """
115
- if not images or len(images) < idx + 2:
116
- gr.Warning("Please upload enough images first.")
117
- return None
118
-
119
- user = (prompt or "").strip()
120
- # Optional light context; tweak or remove as you wish:
121
- final_prompt = f"{user} (Transition between frame {idx+1} → {idx+2}.)".strip()
122
-
123
- path = call_modal_i2v(images[idx], final_prompt, int(seed or 0))
124
- if not path:
125
- gr.Warning("Stitch failed. Try again or adjust your prompt.")
126
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- # ---------- UI ----------
129
- CSS = """
130
- .gradio-container { padding: 24px; }
131
- .pill button { border-radius: 999px !important; padding: 10px 18px; }
132
- .rounded textarea { border-radius: 16px !important; }
133
- """
134
-
135
- with gr.Blocks(css=CSS, title="Stitch — Upload & Stitch Adjacent Pairs") as demo:
136
- gr.Markdown("## Stitch — Upload stills, then generate between-frames videos")
137
- gr.Markdown("Upload images in order. For each adjacent pair (1&2, 2&3, …), write a short transition prompt and click **Stitch**.")
138
-
139
- images_state = gr.State([]) # List[PIL.Image]
140
-
141
- with gr.Row():
142
- # Left: Image slots
143
- with gr.Column(scale=1, min_width=340):
144
- uploader = gr.Files(label="Add images (in order)", file_types=["image"], file_count="multiple")
145
- clear_btn = gr.Button("Clear all", elem_classes=["pill"])
146
- image_slots = [gr.Image(label=f"Image {i+1}", interactive=False, visible=False) for i in range(MAX_FRAMES)]
147
-
148
- # Middle: per-pair prompt + Stitch button
149
- with gr.Column(scale=1, min_width=360):
150
- seed_in = gr.Number(value=0, precision=0, label="Seed (0 = random)")
151
- prompt_boxes = [gr.Textbox(placeholder=f"Prompt for transition between Image {i+1} & {i+2}",
152
- lines=2, label="Prompt", elem_classes=["rounded"], visible=False)
153
- for i in range(MAX_FRAMES - 1)]
154
- stitch_buttons = [gr.Button(f"Stitch {i+1}&{i+2}", elem_classes=["pill"], visible=False)
155
- for i in range(MAX_FRAMES - 1)]
156
-
157
- # Right: per-pair video outputs
158
- with gr.Column(scale=1, min_width=360):
159
- video_outputs = [gr.Video(label=f"Video (image {i+1}+{i+2}) output", visible=False)
160
- for i in range(MAX_FRAMES - 1)]
161
-
162
- # Upload wiring
163
- uploader.upload(
164
- fn=handle_upload,
165
- inputs=[uploader, images_state],
166
- outputs=[
167
- images_state,
168
- *image_slots, # image updates
169
- *prompt_boxes, # prompt visibility + reset
170
- *stitch_buttons, # button visibility
171
- *video_outputs # video visibility + reset
172
- ]
173
  )
174
 
175
- # Clear wiring
176
- clear_btn.click(
177
- fn=clear_all,
178
- inputs=[],
179
- outputs=[
180
- images_state,
181
- *image_slots,
182
- *prompt_boxes,
183
- *stitch_buttons,
184
- *video_outputs
185
- ]
186
  )
187
 
188
- # Per-pair stitch wiring
189
- for i in range(MAX_FRAMES - 1):
190
- stitch_buttons[i].click(
191
- fn=lambda p, s, imgs, idx=i: stitch_pair(idx, imgs, p, s),
192
- inputs=[prompt_boxes[i], seed_in, images_state],
193
- outputs=[video_outputs[i]]
194
- )
 
 
 
 
195
 
196
- if __name__ == "__main__":
197
- demo.queue().launch()
 
1
+ import os
2
+ import io
3
+ import time
4
+ import random
5
+ import base64
6
+ from urllib.parse import quote_plus
7
+ from typing import Optional, Tuple
8
 
9
  import requests
 
10
  import gradio as gr
11
 
12
+ # -----------------------------
13
+ # Config
14
+ # -----------------------------
15
+ # You can set your Modal endpoint via env var MM_I2V_URL
16
+ DEFAULT_API_URL = os.getenv(
17
+ "MM_I2V_URL",
18
+ "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run",
19
+ )
20
+
21
+ SAVE_DIR = "outputs"
22
+ os.makedirs(SAVE_DIR, exist_ok=True)
23
+
24
+
25
+ # -----------------------------
26
+ # Helpers
27
+ # -----------------------------
28
def _save_bytes_to_mp4(buf: bytes, name_prefix: str) -> str:
    """Persist raw mp4 bytes under SAVE_DIR and return the new file's path.

    The filename is made unique with a millisecond timestamp suffix so
    successive calls never overwrite each other.
    """
    stamp_ms = int(time.time() * 1000)
    out_path = os.path.join(SAVE_DIR, f"{name_prefix}-{stamp_ms}.mp4")
    with open(out_path, "wb") as fh:
        fh.write(buf)
    return out_path
34
 
 
 
 
 
35
 
36
def _download(url: str) -> bytes:
    """Fetch *url* and return the response body, raising on HTTP errors.

    The generous 600 s timeout accommodates large video payloads.
    """
    response = requests.get(url, timeout=600)
    response.raise_for_status()
    return response.content
40
 
41
+
42
def call_i2v(
    image_path: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Call the image->video backend and return (video_path, error_message).

    Exactly one tuple slot is meaningful: on success the first element is a
    local mp4 path and the second is None; on failure the first is None and
    the second describes the problem.

    Tries to handle several common response types:
    1) raw mp4 bytes
    2) JSON with {"video": "<base64>"} (mp4 base64)
    3) JSON with {"video_url": "https://..."} (or "result_url" / "url")
    """
    api = (api_url or DEFAULT_API_URL).strip().rstrip("/")
    # Blank/missing seed -> pick a random one so repeated runs vary.
    used_seed = seed if (seed is not None and str(seed).strip() != "") else random.randint(0, 2**31 - 1)
    url = f"{api}?prompt={quote_plus(prompt)}&seed={used_seed}"

    # Read the upload eagerly so no file handle is left open after the call
    # (passing an open handle to `requests` previously leaked it).
    try:
        with open(image_path, "rb") as fh:
            image_bytes = fh.read()
    except OSError as e:
        return None, f"Could not read image file: {e}"

    files = {
        "image_bytes": (os.path.basename(image_path), image_bytes, "application/octet-stream")
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, headers=headers, files=files, timeout=1200)
        # Fail fast on HTTP errors; without this an error page could be
        # saved to disk as if it were a video.
        resp.raise_for_status()

        # Try to accommodate various backends
        ctype = resp.headers.get("Content-Type", "")
        if "application/json" in ctype:
            data = resp.json()
            # base64 payload
            if "video" in data and isinstance(data["video"], str) and len(data["video"]) > 50:
                try:
                    raw = base64.b64decode(data["video"], validate=True)
                    return _save_bytes_to_mp4(raw, "clip"), None
                except Exception as e:
                    return None, f"Could not decode base64 video: {e}"
            # url payload
            for key in ("video_url", "result_url", "url"):
                if key in data and isinstance(data[key], str) and data[key].startswith("http"):
                    raw = _download(data[key])
                    return _save_bytes_to_mp4(raw, "clip"), None
            return None, 'JSON response did not include "video" (base64) or a known url key.'
        # Raw bytes (ideally mp4)
        elif "video" in ctype or "octet-stream" in ctype:
            return _save_bytes_to_mp4(resp.content, "clip"), None
        else:
            # Some backends still reply bytes with missing/odd content-type;
            # accept anything plausibly larger than an error snippet.
            if resp.content and len(resp.content) > 1024:
                return _save_bytes_to_mp4(resp.content, "clip"), None
            return None, f"Unexpected content type: {ctype}"
    except requests.RequestException as e:
        return None, f"Request failed: {e}"
94
+
95
+
96
def stitch_pair(
    image_a: str,
    image_b: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str],
    crossfade: float,
) -> Tuple[Optional[str], str]:
    """Generate clips from two stills and join them; return (path, error).

    Strategy:
    - Generate a short clip from image A
    - Generate a short clip from image B (same prompt/seed unless user changes)
    - Concatenate with a short crossfade in Python (moviepy)

    On success the error string is empty; on failure the path is None and the
    string explains what went wrong. If you already have a backend endpoint
    that does stitching directly, replace this function body with a single
    backend call.
    """
    if not image_a or not image_b:
        return None, "Please upload both images."

    # First generate both clips
    clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
    if err1:
        return None, f"Clip 1 failed: {err1}"
    clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
    if err2:
        return None, f"Clip 2 failed: {err2}"

    # Import lazily so the app still starts when moviepy is absent.
    try:
        from moviepy.editor import VideoFileClip, concatenate_videoclips
    except Exception as e:
        return None, f"MoviePy import failed. Add moviepy & imageio-ffmpeg to requirements.txt. Error: {e}"

    c1 = c2 = merged = None
    try:
        c1 = VideoFileClip(clip1_path)
        c2 = VideoFileClip(clip2_path)

        # "compose" handles size/fps mismatches between the two clips.
        if crossfade and crossfade > 0:
            # Overlap the clips: the second fades in while the first fades out.
            c2 = c2.crossfadein(crossfade)
            c1 = c1.crossfadeout(crossfade)
            merged = concatenate_videoclips([c1, c2], method="compose", padding=-crossfade)
        else:
            merged = concatenate_videoclips([c1, c2], method="compose")

        out_path = os.path.join(SAVE_DIR, f"stitched-{int(time.time()*1000)}.mp4")
        merged.write_videofile(out_path, codec="libx264", audio_codec="aac", verbose=False, logger=None)
        return out_path, ""
    except Exception as e:
        return None, f"Stitching failed: {e}"
    finally:
        # Always release the ffmpeg readers, even when stitching fails
        # (previously they leaked on any exception).
        for clip in (c1, c2, merged):
            if clip is not None:
                try:
                    clip.close()
                except Exception:
                    pass
149
+
150
+
151
+ # -----------------------------
152
+ # UI
153
+ # -----------------------------
154
# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="Image Stitch to Video", css="""
/* Rounded tiles like the sketch */
.rounded { border-radius: 24px; }
.tile { background: #f7f7ff; padding: 12px; }
.tile-blue { background: #e8f0ff; }
.tile-yellow { background: #fff7d6; }
.small-btn button { padding: 6px 10px; border-radius: 999px; }
.label-center label { text-align:center; width: 100%; }
""") as demo:
    gr.Markdown("### Image → Video (Stitch Adjacent Pairs)\nUpload 3 images, enter prompts for each stitch, then click the stitch buttons.")

    with gr.Row(equal_height=True):
        # Left column: images + add image
        with gr.Column(scale=1):
            gr.Markdown("**Images**")
            img1 = gr.Image(type="filepath", label="Image 1", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img2 = gr.Image(type="filepath", label="Image 2", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img3 = gr.Image(type="filepath", label="Image 3", height=220, elem_classes=["rounded", "tile", "tile-blue"])

            # Optional extra slots (hidden until added)
            extra_imgs = []
            for i in range(4, 9):
                comp = gr.Image(type="filepath", label=f"Image {i}", height=220, visible=False, elem_classes=["rounded", "tile", "tile-blue"])
                extra_imgs.append(comp)

            add_btn = gr.Button("Add Image", variant="secondary")

        # Middle column: prompts + stitch buttons
        with gr.Column(scale=1):
            gr.Markdown("**Prompts**")
            prompt12 = gr.Textbox(label="Prompt for Stitch 1 & 2", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed12 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch12 = gr.Button("Stitch 1 & 2", elem_classes=["small-btn"])

            prompt23 = gr.Textbox(label="Prompt for Stitch 2 & 3", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed23 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch23 = gr.Button("Stitch 2 & 3", elem_classes=["small-btn"])

            with gr.Accordion("Advanced (API & Stitch)", open=False):
                api_url = gr.Textbox(label="Backend API URL", value=DEFAULT_API_URL)
                crossfade = gr.Slider(0.0, 1.5, value=0.4, step=0.1, label="Crossfade seconds")
                clear_btn = gr.Button("Clear All")

        # Right column: video outputs
        with gr.Column(scale=1):
            gr.Markdown("**Outputs**")
            vid12 = gr.Video(label="Video (image 1 + 2) output", elem_classes=["rounded", "tile", "tile-yellow"])
            vid23 = gr.Video(label="Video (image 2 + 3) output", elem_classes=["rounded", "tile", "tile-yellow"])
            # Hidden sinks for stitch_pair's error-message return value
            # (created once here instead of inline in each event wiring).
            status12 = gr.Textbox(visible=False)
            status23 = gr.Textbox(visible=False)

    # Per-session count of revealed extra slots. Component attributes are
    # shared across all sessions in Gradio, so mutating `comp.visible`
    # server-side (the previous approach) leaked state between users;
    # per-session UI state must live in gr.State.
    extra_count = gr.State(0)

    # Wire up actions
    def _on_add(count: int):
        """Reveal the next hidden extra image slot for this session."""
        count = min(count + 1, len(extra_imgs))
        visibility = [gr.update(visible=(i < count)) for i in range(len(extra_imgs))]
        return [count] + visibility

    add_btn.click(
        _on_add,
        inputs=[extra_count],
        outputs=[extra_count, *extra_imgs],
    )

    stitch12.click(
        stitch_pair,
        inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
        outputs=[vid12, status12],
    )

    stitch23.click(
        stitch_pair,
        inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
        outputs=[vid23, status23],
    )

    def _on_clear():
        """Reset every input/output (and the extra-slot count) to initial state."""
        image_updates = [gr.update(value=None, visible=True) for _ in (img1, img2, img3)]
        image_updates += [gr.update(value=None, visible=False) for _ in extra_imgs]
        return image_updates + [None, None, "", "", gr.update(value=DEFAULT_API_URL), 0.4, 0]

    clear_btn.click(
        _on_clear,
        inputs=None,
        outputs=[img1, img2, img3, *extra_imgs, vid12, vid23, prompt12, prompt23, api_url, crossfade, extra_count],
    )

if __name__ == "__main__":
    demo.launch()