Shalmoni commited on
Commit
92be6ab
Β·
verified Β·
1 Parent(s): 6a8b4f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -27
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, io, time, base64, random
2
  from typing import Optional
3
  from urllib.parse import quote
4
 
@@ -6,14 +6,13 @@ import requests
6
  from PIL import Image
7
  import gradio as gr
8
 
9
- # -------- Modal inference endpoint (dev) --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
- # -------- small helpers --------
13
  def _save_video_bytes(data: bytes, tag: str) -> str:
14
- import os, time
15
  os.makedirs("/tmp", exist_ok=True)
16
- path = f"/tmp/{tag}_{int(time.time())}.mp4" # <- /tmp
17
  with open(path, "wb") as f:
18
  f.write(data)
19
  return path
@@ -29,13 +28,6 @@ def _download_to_bytes(url: str) -> bytes:
29
  return r.content
30
 
31
  def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
32
- """
33
- JS equivalent:
34
- const fd = new FormData();
35
- fd.append("image_bytes", start);
36
- fd.append("image_bytes_end", end);
37
- fetch(`${INFERENCE_URL}?prompt=${prompt}&seed=${seed}`, { method:"POST", body: fd })
38
- """
39
  if start_img is None or end_img is None:
40
  return None
41
 
@@ -69,16 +61,17 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
69
  video_b64 = data.get("video_b64") or data.get("videoBase64")
70
  if isinstance(video_b64, str):
71
  pad = (-len(video_b64)) % 4
72
- if pad: video_b64 += "=" * pad
 
73
  b = base64.b64decode(video_b64)
74
  return _save_video_bytes(b, "stitch")
75
 
76
- except Exception:
77
- pass
78
 
79
  return None
80
 
81
- # -------- Gradio callbacks (exactly two stitches) --------
82
  def stitch_12(prompt12, seed, img1, img2):
83
  if img1 is None or img2 is None:
84
  gr.Warning("Please upload Image 1 and Image 2.")
@@ -97,6 +90,30 @@ def stitch_23(prompt23, seed, img2, img3):
97
  gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
98
  return path
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # -------- UI --------
101
  CSS = """
102
  .gradio-container { padding: 24px; }
@@ -104,33 +121,32 @@ CSS = """
104
  .rounded textarea { border-radius: 16px !important; }
105
  """
106
 
107
- with gr.Blocks(css=CSS, title="Stitch β€” 3 uploads, 2 stitches") as demo:
108
- gr.Markdown("## Stitch β€” Upload 3 images, then generate videos between 1β†’2 and 2β†’3")
109
 
110
  with gr.Row():
111
- # Left: exactly 3 image inputs
112
- with gr.Column(scale=1, min_width=360):
113
  img1 = gr.Image(label="Image 1 upload", type="pil")
114
  img2 = gr.Image(label="Image 2 upload", type="pil")
115
  img3 = gr.Image(label="Image 3 upload", type="pil")
116
 
117
- # Middle: prompts + buttons
118
- with gr.Column(scale=1, min_width=360):
119
  seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
120
  prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)", elem_classes=["rounded"])
121
  btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
122
- gr.Markdown("---")
 
123
  prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)", elem_classes=["rounded"])
124
  btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
125
-
126
- # Right: exactly 2 video outputs
127
- with gr.Column(scale=1, min_width=360):
128
- vid12 = gr.Video(label="Video (image 1+2) output")
129
  vid23 = gr.Video(label="Video (image 2+3) output")
130
 
 
 
 
131
  # Wire buttons
132
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
133
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
 
134
 
135
  if __name__ == "__main__":
136
  demo.queue().launch()
 
1
+ import os, io, time, base64, random, subprocess
2
  from typing import Optional
3
  from urllib.parse import quote
4
 
 
6
  from PIL import Image
7
  import gradio as gr
8
 
9
+ # -------- Modal inference endpoint --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
+ # -------- Helpers --------
13
  def _save_video_bytes(data: bytes, tag: str) -> str:
 
14
  os.makedirs("/tmp", exist_ok=True)
15
+ path = f"/tmp/{tag}_{int(time.time())}.mp4"
16
  with open(path, "wb") as f:
17
  f.write(data)
18
  return path
 
28
  return r.content
29
 
30
  def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
 
 
 
 
 
 
 
31
  if start_img is None or end_img is None:
32
  return None
33
 
 
61
  video_b64 = data.get("video_b64") or data.get("videoBase64")
62
  if isinstance(video_b64, str):
63
  pad = (-len(video_b64)) % 4
64
+ if pad:
65
+ video_b64 += "=" * pad
66
  b = base64.b64decode(video_b64)
67
  return _save_video_bytes(b, "stitch")
68
 
69
+ except Exception as e:
70
+ print("Stitch call failed:", e)
71
 
72
  return None
73
 
74
+ # -------- Gradio callbacks --------
75
  def stitch_12(prompt12, seed, img1, img2):
76
  if img1 is None or img2 is None:
77
  gr.Warning("Please upload Image 1 and Image 2.")
 
90
  gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
91
  return path
92
 
93
+ def stitch_all(video12, video23):
94
+ if not video12 or not video23:
95
+ gr.Warning("Please generate both stitched videos first.")
96
+ return None
97
+
98
+ try:
99
+ # Final output path
100
+ out_path = f"/tmp/stitch_all_{int(time.time())}.mp4"
101
+
102
+ # Concatenate with ffmpeg
103
+ txt_file = f"/tmp/concat_{int(time.time())}.txt"
104
+ with open(txt_file, "w") as f:
105
+ f.write(f"file '{video12}'\n")
106
+ f.write(f"file '{video23}'\n")
107
+
108
+ cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", txt_file, "-c", "copy", out_path]
109
+ subprocess.run(cmd, check=True)
110
+
111
+ return out_path
112
+ except Exception as e:
113
+ print("Stitch all failed:", e)
114
+ gr.Warning("Failed to stitch all videos together.")
115
+ return None
116
+
117
  # -------- UI --------
118
  CSS = """
119
  .gradio-container { padding: 24px; }
 
121
  .rounded textarea { border-radius: 16px !important; }
122
  """
123
 
124
+ with gr.Blocks(css=CSS, title="Stitch β€” 3 uploads, 2 stitches, concat") as demo:
125
+ gr.Markdown("## Stitch β€” Upload 3 images, generate videos between 1β†’2 and 2β†’3, then merge them.")
126
 
127
  with gr.Row():
128
+ with gr.Column(scale=1, min_width=320):
 
129
  img1 = gr.Image(label="Image 1 upload", type="pil")
130
  img2 = gr.Image(label="Image 2 upload", type="pil")
131
  img3 = gr.Image(label="Image 3 upload", type="pil")
132
 
133
+ with gr.Column(scale=1, min_width=320):
 
134
  seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
135
  prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)", elem_classes=["rounded"])
136
  btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
137
+ vid12 = gr.Video(label="Video (image 1+2) output")
138
+
139
  prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)", elem_classes=["rounded"])
140
  btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
 
 
 
 
141
  vid23 = gr.Video(label="Video (image 2+3) output")
142
 
143
+ btn_all = gr.Button("Stitch All", elem_classes=["pill"])
144
+ vid_all = gr.Video(label="Final concatenated video")
145
+
146
  # Wire buttons
147
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
148
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
149
+ btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
150
 
151
  if __name__ == "__main__":
152
  demo.queue().launch()