Shalmoni commited on
Commit
c401d55
Β·
verified Β·
1 Parent(s): 5587ee0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -56
app.py CHANGED
@@ -1,11 +1,12 @@
1
- import os, io, time, base64, random, requests
2
  from typing import Optional
3
  from urllib.parse import quote
 
 
4
  from PIL import Image
5
  import gradio as gr
6
- from moviepy.editor import VideoFileClip, concatenate_videoclips
7
 
8
- # -------- Modal inference endpoint --------
9
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
10
 
11
  # -------- small helpers --------
@@ -45,17 +46,19 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
45
  resp = requests.post(url, files=files, headers=headers, timeout=600)
46
  ctype = (resp.headers.get("content-type") or "").lower()
47
 
 
48
  if "application/json" not in ctype:
49
  resp.raise_for_status()
50
  return _save_video_bytes(resp.content, "stitch")
51
 
 
52
  data = resp.json()
53
- video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
54
- if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
55
  b = _download_to_bytes(video_url)
56
  return _save_video_bytes(b, "stitch")
57
 
58
- video_b64 = data.get("video_b64") or data.get("videoBase64")
59
  if isinstance(video_b64, str):
60
  pad = (-len(video_b64)) % 4
61
  if pad:
@@ -64,78 +67,83 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
64
  return _save_video_bytes(b, "stitch")
65
 
66
  except Exception as e:
67
- print("stitch_call failed:", e)
68
 
69
  return None
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # -------- Gradio callbacks --------
72
  def stitch_12(prompt12, seed, img1, img2):
73
- if img1 is None or img2 is None:
74
- gr.Warning("Please upload Image 1 and Image 2.")
75
- return None
76
  path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
77
- if path is None:
78
- gr.Warning("Stitch 1β†’2 failed.")
79
  return path
80
 
81
  def stitch_23(prompt23, seed, img2, img3):
82
- if img2 is None or img3 is None:
83
- gr.Warning("Please upload Image 2 and Image 3.")
84
- return None
85
  path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
86
- if path is None:
87
- gr.Warning("Stitch 2β†’3 failed.")
88
  return path
89
 
90
- def stitch_all(video12, video23):
91
- if video12 is None or video23 is None:
92
- gr.Warning("Need both videos to stitch all.")
93
- return None
94
- try:
95
- clip1 = VideoFileClip(video12)
96
- clip2 = VideoFileClip(video23)
97
- final = concatenate_videoclips([clip1, clip2])
98
- out_path = f"/tmp/stitch_all_{int(time.time())}.mp4"
99
- final.write_videofile(out_path, codec="libx264", audio=False, verbose=False, logger=None)
100
- return out_path
101
- except Exception as e:
102
- print("stitch_all failed:", e)
103
- gr.Warning("Stitch All failed.")
104
  return None
 
105
 
106
  # -------- UI --------
107
  CSS = """
108
  .gradio-container { padding: 24px; }
109
- .pill button { border-radius: 999px !important; padding: 10px 18px; }
110
- .rounded textarea { border-radius: 16px !important; }
111
  """
112
 
113
- with gr.Blocks(css=CSS, title="Stitch Master") as demo:
114
- gr.Markdown("## Stitch β€” Upload 3 images, generate videos between 1β†’2 and 2β†’3, then merge them.")
115
 
116
- # --- Uploads row ---
117
  with gr.Row():
118
- img1 = gr.Image(label="Image 1 upload", type="pil")
119
- img2 = gr.Image(label="Image 2 upload", type="pil")
120
- img3 = gr.Image(label="Image 3 upload", type="pil")
121
-
122
- seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
123
 
124
- # --- First stitch ---
125
- prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)", elem_classes=["rounded"])
126
- btn12 = gr.Button("Stitch 1β†’2", elem_classes=["pill"])
127
- vid12 = gr.Video(label="Video (image 1+2) output", interactive=False)
128
-
129
- # --- Second stitch ---
130
- prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)", elem_classes=["rounded"])
131
- btn23 = gr.Button("Stitch 2β†’3", elem_classes=["pill"])
132
- vid23 = gr.Video(label="Video (image 2+3) output", interactive=False)
133
-
134
- # --- Final merge ---
135
- btn_all = gr.Button("Stitch All", elem_classes=["pill"])
136
- vid_all = gr.Video(label="Final concatenated video", interactive=False)
137
-
138
- # --- Wire buttons ---
 
139
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
140
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
141
  btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
 
1
+ import os, io, time, base64, random, subprocess
2
  from typing import Optional
3
  from urllib.parse import quote
4
+
5
+ import requests
6
  from PIL import Image
7
  import gradio as gr
 
8
 
9
+ # -------- Modal inference endpoint (dev) --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
  # -------- small helpers --------
 
46
  resp = requests.post(url, files=files, headers=headers, timeout=600)
47
  ctype = (resp.headers.get("content-type") or "").lower()
48
 
49
+ # Raw video bytes
50
  if "application/json" not in ctype:
51
  resp.raise_for_status()
52
  return _save_video_bytes(resp.content, "stitch")
53
 
54
+ # JSON with url or base64
55
  data = resp.json()
56
+ video_url = data.get("video_url") or data.get("url") or data.get("result")
57
+ if isinstance(video_url, str) and video_url.startswith("http"):
58
  b = _download_to_bytes(video_url)
59
  return _save_video_bytes(b, "stitch")
60
 
61
+ video_b64 = data.get("video_b64")
62
  if isinstance(video_b64, str):
63
  pad = (-len(video_b64)) % 4
64
  if pad:
 
67
  return _save_video_bytes(b, "stitch")
68
 
69
  except Exception as e:
70
+ print("stitch_call error:", e)
71
 
72
  return None
73
 
74
+ # -------- FFmpeg-based concatenation --------
75
+ def concat_videos(vid1: str, vid2: str) -> Optional[str]:
76
+ if not vid1 or not vid2:
77
+ return None
78
+ try:
79
+ os.makedirs("/tmp", exist_ok=True)
80
+ out_path = f"/tmp/final_{int(time.time())}.mp4"
81
+
82
+ # Create a temporary file list for ffmpeg
83
+ list_file = f"/tmp/list_{int(time.time())}.txt"
84
+ with open(list_file, "w") as f:
85
+ f.write(f"file '{vid1}'\n")
86
+ f.write(f"file '{vid2}'\n")
87
+
88
+ # Run ffmpeg concat
89
+ subprocess.run(
90
+ ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
91
+ check=True,
92
+ stdout=subprocess.PIPE,
93
+ stderr=subprocess.PIPE,
94
+ )
95
+
96
+ return out_path
97
+ except Exception as e:
98
+ print("concat_videos error:", e)
99
+ return None
100
+
101
  # -------- Gradio callbacks --------
102
  def stitch_12(prompt12, seed, img1, img2):
 
 
 
103
  path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
 
 
104
  return path
105
 
106
  def stitch_23(prompt23, seed, img2, img3):
 
 
 
107
  path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
 
 
108
  return path
109
 
110
+ def stitch_all(vid12, vid23):
111
+ if vid12 is None or vid23 is None:
112
+ gr.Warning("Generate both videos first before stitching all.")
 
 
 
 
 
 
 
 
 
 
 
113
  return None
114
+ return concat_videos(vid12, vid23)
115
 
116
  # -------- UI --------
117
  CSS = """
118
  .gradio-container { padding: 24px; }
 
 
119
  """
120
 
121
+ with gr.Blocks(css=CSS, title="Stitch β€” 3 uploads, 2 stitches") as demo:
122
+ gr.Markdown("## Stitch β€” Upload 3 images β†’ Generate 1β†’2, 2β†’3, then combine.")
123
 
 
124
  with gr.Row():
125
+ # Top: images
126
+ with gr.Column():
127
+ img1 = gr.Image(label="Image 1 upload", type="pil")
128
+ img2 = gr.Image(label="Image 2 upload", type="pil")
129
+ img3 = gr.Image(label="Image 3 upload", type="pil")
130
 
131
+ with gr.Row():
132
+ # Prompts + buttons
133
+ with gr.Column():
134
+ seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
135
+ prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)")
136
+ btn12 = gr.Button("Stitch 1&2")
137
+ prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)")
138
+ btn23 = gr.Button("Stitch 2&3")
139
+ btn_all = gr.Button("Stitch All (combine 1β†’2 and 2β†’3)")
140
+
141
+ with gr.Column():
142
+ vid12 = gr.Video(label="Video (1β†’2)")
143
+ vid23 = gr.Video(label="Video (2β†’3)")
144
+ vid_all = gr.Video(label="Final Combined Video")
145
+
146
+ # Wire
147
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
148
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
149
  btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])