Shalmoni committed on
Commit
f1cd1b3
Β·
verified Β·
1 Parent(s): 72df41e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -225
app.py CHANGED
@@ -1,243 +1,135 @@
1
- import os
2
- import io
3
- import time
4
- import random
5
- import base64
6
- from urllib.parse import quote_plus
7
- from typing import Optional, Tuple
8
 
9
  import requests
 
10
  import gradio as gr
11
 
12
- # -----------------------------
13
- # Config
14
- # -----------------------------
15
- # You can set your Modal endpoint via env var MM_I2V_URL
16
- DEFAULT_API_URL = os.getenv(
17
- "MM_I2V_URL",
18
- "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run",
19
- )
20
-
21
- SAVE_DIR = "outputs"
22
- os.makedirs(SAVE_DIR, exist_ok=True)
23
-
24
-
25
- # -----------------------------
26
- # Helpers
27
- # -----------------------------
28
- def _save_bytes_to_mp4(buf: bytes, name_prefix: str) -> str:
29
- ts = int(time.time() * 1000)
30
- path = os.path.join(SAVE_DIR, f"{name_prefix}-{ts}.mp4")
31
  with open(path, "wb") as f:
32
- f.write(buf)
33
  return path
34
 
 
 
 
 
35
 
36
- def _download(url: str) -> bytes:
37
- r = requests.get(url, timeout=600)
38
  r.raise_for_status()
39
  return r.content
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- def call_i2v(
43
- image_path: str,
44
- prompt: str,
45
- seed: Optional[int],
46
- api_url: Optional[str] = None,
47
- ) -> Tuple[Optional[str], Optional[str]]:
48
- \"\"\"
49
- Call the image->video backend and return (video_path, error_message).
50
-
51
- Tries to handle several common response types:
52
- 1) raw mp4 bytes
53
- 2) JSON with {\"video\": \"<base64>\"} (mp4 base64)
54
- 3) JSON with {\"video_url\": \"https://...\"} (or \"result_url\")
55
- \"\"\"
56
- api = (api_url or DEFAULT_API_URL).strip().rstrip(\"/\")
57
- used_seed = seed if (seed is not None and str(seed).strip() != \"\") else random.randint(0, 2**31 - 1)
58
- url = f\"{api}?prompt={quote_plus(prompt)}&seed={used_seed}\"
59
 
60
  files = {
61
- \"image_bytes\": (os.path.basename(image_path), open(image_path, \"rb\"), \"application/octet-stream\")
 
62
  }
63
- headers = {\"accept\": \"application/json\"}
64
 
65
  try:
66
- resp = requests.post(url, headers=headers, files=files, timeout=1200)
67
- # Try to accommodate various backends
68
- ctype = resp.headers.get(\"Content-Type\", \"\")
69
- if \"application/json\" in ctype:
70
- data = resp.json()
71
- # base64 payload
72
- if \"video\" in data and isinstance(data[\"video\"], str) and len(data[\"video\"]) > 50:
73
- try:
74
- raw = base64.b64decode(data[\"video\"], validate=True)
75
- return _save_bytes_to_mp4(raw, \"clip\"), None
76
- except Exception as e:
77
- return None, f\"Could not decode base64 video: {e}\"
78
- # url payload
79
- for key in (\"video_url\", \"result_url\", \"url\"):
80
- if key in data and isinstance(data[key], str) and data[key].startswith(\"http\"):
81
- raw = _download(data[key])
82
- return _save_bytes_to_mp4(raw, \"clip\"), None
83
- return None, 'JSON response did not include \"video\" (base64) or a known url key.'
84
- # Raw bytes (ideally mp4)
85
- elif \"video\" in ctype or \"octet-stream\" in ctype:
86
- return _save_bytes_to_mp4(resp.content, \"clip\"), None
87
- else:
88
- # Some backends still reply bytes with missing/odd content-type
89
- if resp.content and len(resp.content) > 1024:
90
- return _save_bytes_to_mp4(resp.content, \"clip\"), None
91
- return None, f\"Unexpected content type: {ctype}\"
92
- except requests.RequestException as e:
93
- return None, f\"Request failed: {e}\"
94
-
95
-
96
- def stitch_pair(
97
- image_a: str,
98
- image_b: str,
99
- prompt: str,
100
- seed: Optional[int],
101
- api_url: Optional[str],
102
- crossfade: float,
103
- ) -> Tuple[Optional[str], str]:
104
- \"\"\"
105
- Strategy:
106
- - Generate a short clip from image A
107
- - Generate a short clip from image B (same prompt/seed unless user changes)
108
- - Concatenate with a short crossfade in Python (moviepy)
109
-
110
- If you already have a backend endpoint that does stitching directly,
111
- replace this function body with a single backend call.
112
- \"\"\"
113
- if not image_a or not image_b:
114
- return None, \"Please upload both images.\"
115
-
116
- # First generate both clips
117
- clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
118
- if err1:
119
- return None, f\"Clip 1 failed: {err1}\"
120
- clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
121
- if err2:
122
- return None, f\"Clip 2 failed: {err2}\"
123
-
124
- # If crossfade is 0, just concatenate directly
125
- try:
126
- from moviepy.editor import VideoFileClip, concatenate_videoclips
127
- except Exception as e:
128
- return None, f\"MoviePy import failed. Add moviepy & imageio-ffmpeg to requirements.txt. Error: {e}\"
129
 
130
- try:
131
- c1 = VideoFileClip(clip1_path)
132
- c2 = VideoFileClip(clip2_path)
133
-
134
- # Enforce same size/fps (compose handles mismatches)
135
- if crossfade and crossfade > 0:
136
- # Apply crossfade (second clip fades in)
137
- c2 = c2.crossfadein(crossfade)
138
- c1 = c1.crossfadeout(crossfade)
139
- merged = concatenate_videoclips([c1, c2], method=\"compose\", padding=-crossfade)
140
- else:
141
- merged = concatenate_videoclips([c1, c2], method=\"compose\")
142
-
143
- out_path = os.path.join(SAVE_DIR, f\"stitched-{int(time.time()*1000)}.mp4\")
144
- merged.write_videofile(out_path, codec=\"libx264\", audio_codec=\"aac\", verbose=False, logger=None)
145
- c1.close(); c2.close(); merged.close()
146
- return out_path, \"\"
147
- except Exception as e:
148
- return None, f\"Stitching failed: {e}\"
149
-
150
-
151
- # -----------------------------
152
- # UI
153
- # -----------------------------
154
- with gr.Blocks(title=\"Image Stitch to Video\", css=\"\"\"
155
- /* Rounded tiles like the sketch */
156
- .rounded { border-radius: 24px; }
157
- .tile { background: #f7f7ff; padding: 12px; }
158
- .tile-blue { background: #e8f0ff; }
159
- .tile-yellow { background: #fff7d6; }
160
- .small-btn button { padding: 6px 10px; border-radius: 999px; }
161
- .label-center label { text-align:center; width: 100%; }
162
- \"\"\") as demo:
163
- gr.Markdown(\"### Image β†’ Video (Stitch Adjacent Pairs)\\nUpload 3 images, enter prompts for each stitch, then click the stitch buttons.\")
164
-
165
- with gr.Row(equal_height=True):
166
- # Left column: images + add image
167
- with gr.Column(scale=1):
168
- gr.Markdown(\"**Images**\")
169
- img1 = gr.Image(type=\"filepath\", label=\"Image 1\", height=220, elem_classes=[\"rounded\", \"tile\", \"tile-blue\"])
170
- img2 = gr.Image(type=\"filepath\", label=\"Image 2\", height=220, elem_classes=[\"rounded\", \"tile\", \"tile-blue\"])
171
- img3 = gr.Image(type=\"filepath\", label=\"Image 3\", height=220, elem_classes=[\"rounded\", \"tile\", \"tile-blue\"])
172
-
173
- # Optional extra slots (hidden until added)
174
- extra_imgs = []
175
- for i in range(4, 9):
176
- comp = gr.Image(type=\"filepath\", label=f\"Image {i}\", height=220, visible=False, elem_classes=[\"rounded\", \"tile\", \"tile-blue\"])
177
- extra_imgs.append(comp)
178
-
179
- add_btn = gr.Button(\"Add Image\", variant=\"secondary\")
180
-
181
- # Middle column: prompts + stitch buttons
182
- with gr.Column(scale=1):
183
- gr.Markdown(\"**Prompts**\")
184
- prompt12 = gr.Textbox(label=\"Prompt for Stitch 1 & 2\", lines=3, placeholder=\"Describe motion/style/etc.\", elem_classes=[\"rounded\", \"tile\"])
185
- seed12 = gr.Number(label=\"Seed (optional)\", value=None, precision=0)
186
- stitch12 = gr.Button(\"Stitch 1 & 2\", elem_classes=[\"small-btn\"])
187
-
188
- prompt23 = gr.Textbox(label=\"Prompt for Stitch 2 & 3\", lines=3, placeholder=\"Describe motion/style/etc.\", elem_classes=[\"rounded\", \"tile\"])
189
- seed23 = gr.Number(label=\"Seed (optional)\", value=None, precision=0)
190
- stitch23 = gr.Button(\"Stitch 2 & 3\", elem_classes=[\"small-btn\"])
191
-
192
- with gr.Accordion(\"Advanced (API & Stitch)\", open=False):
193
- api_url = gr.Textbox(label=\"Backend API URL\", value=DEFAULT_API_URL)
194
- crossfade = gr.Slider(0.0, 1.5, value=0.4, step=0.1, label=\"Crossfade seconds\")
195
- clear_btn = gr.Button(\"Clear All\")
196
-
197
- # Right column: video outputs
198
- with gr.Column(scale=1):
199
- gr.Markdown(\"**Outputs**\")
200
- vid12 = gr.Video(label=\"Video (image 1 + 2) output\", elem_classes=[\"rounded\", \"tile\", \"tile-yellow\"])
201
- vid23 = gr.Video(label=\"Video (image 2 + 3) output\", elem_classes=[\"rounded\", \"tile\", \"tile-yellow\"])
202
-
203
- # Wire up actions
204
- def _on_add(*imgs):
205
- # Reveal the next hidden uploader
206
- for comp in extra_imgs:
207
- if comp.visible is False:
208
- comp.visible = True
209
- break
210
- return [gr.update(visible=comp.visible) for comp in extra_imgs]
211
-
212
- add_btn.click(
213
- _on_add,
214
- inputs=extra_imgs,
215
- outputs=extra_imgs,
216
- )
217
-
218
- stitch12.click(
219
- stitch_pair,
220
- inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
221
- outputs=[vid12, gr.Textbox(visible=False)],
222
- )
223
-
224
- stitch23.click(
225
- stitch_pair,
226
- inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
227
- outputs=[vid23, gr.Textbox(visible=False)],
228
- )
229
-
230
- def _on_clear():
231
- updates = []
232
- for comp in [img1, img2, img3, *extra_imgs]:
233
- updates.append(gr.update(value=None, visible=True if comp in [img1, img2, img3] else False))
234
- return updates + [None, None, \"\", \"\", gr.update(value=DEFAULT_API_URL), 0.4]
235
-
236
- clear_btn.click(
237
- _on_clear,
238
- inputs=None,
239
- outputs=[img1, img2, img3, *extra_imgs, vid12, vid23, prompt12, prompt23, api_url, crossfade],
240
- )
241
-
242
- if __name__ == \"__main__\":
243
- demo.launch()
 
1
+ import os, io, time, base64, random
2
+ from typing import Optional
3
+ from urllib.parse import quote
 
 
 
 
4
 
5
  import requests
6
+ from PIL import Image
7
  import gradio as gr
8
 
9
+ # -------- Modal inference endpoint (dev) --------
10
+ INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
+
12
+ # -------- small helpers --------
13
+ def _save_video_bytes(data: bytes, tag: str) -> str:
14
+ os.makedirs("/mnt/data", exist_ok=True)
15
+ path = f"/mnt/data/{tag}_{int(time.time())}.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
16
  with open(path, "wb") as f:
17
+ f.write(data)
18
  return path
19
 
20
+ def _png_bytes(img: Image.Image) -> bytes:
21
+ buf = io.BytesIO()
22
+ img.save(buf, format="PNG")
23
+ return buf.getvalue()
24
 
25
def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* over HTTP and return the response body as bytes.

    Raises:
        requests.HTTPError: on a non-2xx status code.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    return response.content
29
 
30
def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """Call the Modal image->video endpoint and return a local mp4 path.

    Sends the two frames as multipart form fields ("image_bytes" /
    "image_bytes_end") with prompt and seed as query parameters.
    JS equivalent:
        const fd = new FormData();
        fd.append("image_bytes", start);
        fd.append("image_bytes_end", end);
        fetch(`${INFERENCE_URL}?prompt=${prompt}&seed=${seed}`, { method:"POST", body: fd })

    Handles three response shapes: raw mp4 bytes, JSON carrying a video
    URL, or JSON carrying base64-encoded video. Returns None on any
    failure so the Gradio callbacks can surface a warning.
    """
    if start_img is None or end_img is None:
        return None

    # 0 / -1 / None are "pick one for me" sentinels (the UI labels 0 as random).
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"

    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}

    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()

        # Case 1: raw video bytes (anything that isn't JSON).
        if "application/json" not in ctype:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")

        # Case 2: JSON with a downloadable URL.
        data = resp.json()
        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")

        # Case 3: JSON with base64 bytes. Also accept the legacy "video"
        # key used by the previous backend contract.
        video_b64 = data.get("video_b64") or data.get("videoBase64") or data.get("video")
        if isinstance(video_b64, str):
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")

        print(f"stitch_call: unrecognized JSON payload keys: {sorted(data)}")
    except Exception as e:
        # Never crash the Gradio callback, but don't silently swallow the
        # failure either -- surface it in the server log for debugging.
        print(f"stitch_call failed: {e}")

    return None
79
+
80
+ # -------- Gradio callbacks (exactly two stitches) --------
81
def stitch_12(prompt12, seed, img1, img2):
    """Gradio click handler: stitch Image 1 -> Image 2, return a video path.

    Shows a gr.Warning and returns None when inputs are missing or the
    backend call fails.
    """
    if img1 is None or img2 is None:
        gr.Warning("Please upload Image 1 and Image 2.")
        return None
    clip_path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
    if clip_path is None:
        gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
    return clip_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
def stitch_23(prompt23, seed, img2, img3):
    """Gradio click handler: stitch Image 2 -> Image 3, return a video path.

    Shows a gr.Warning and returns None when inputs are missing or the
    backend call fails.
    """
    if img2 is None or img3 is None:
        gr.Warning("Please upload Image 2 and Image 3.")
        return None
    clip_path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
    if clip_path is None:
        gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
    return clip_path
98
+
99
# -------- UI --------
# Three-column layout: image uploads | seed + prompts + buttons | video outputs.
# The single seed Number is shared by both stitch buttons.
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
"""

with gr.Blocks(css=CSS, title="Stitch — 3 uploads, 2 stitches") as demo:
    gr.Markdown("## Stitch — Upload 3 images, then generate videos between 1→2 and 2→3")

    with gr.Row():
        # Left: exactly 3 image inputs (type="pil" so handlers get PIL Images)
        with gr.Column(scale=1, min_width=360):
            img1 = gr.Image(label="Image 1 upload", type="pil")
            img2 = gr.Image(label="Image 2 upload", type="pil")
            img3 = gr.Image(label="Image 3 upload", type="pil")

        # Middle: prompts + buttons
        with gr.Column(scale=1, min_width=360):
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt12 = gr.Textbox(placeholder="Prompt for stitching 1→2", lines=2, label="Prompt (1→2)", elem_classes=["rounded"])
            btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
            gr.Markdown("---")
            prompt23 = gr.Textbox(placeholder="Prompt for stitching 2→3", lines=2, label="Prompt (2→3)", elem_classes=["rounded"])
            btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])

        # Right: exactly 2 video outputs
        with gr.Column(scale=1, min_width=360):
            vid12 = gr.Video(label="Video (image 1+2) output")
            vid23 = gr.Video(label="Video (image 2+3) output")

    # Wire buttons: each click runs one backend stitch and fills one output.
    btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
    btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])

if __name__ == "__main__":
    # queue() serializes long-running backend calls instead of timing out.
    demo.queue().launch()