StitchTool / app.py
Shalmoni's picture
Update app.py
0d7b5a8 verified
raw
history blame
5.65 kB
import io, uuid, base64, requests, random
from PIL import Image
import gradio as gr
# ====== CONFIG ======
# Modal inference endpoint. Requests are POSTed with an optional multipart
# body; the prompt and seed travel in the query string.
MODAL_URL = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
# Timeout, in seconds, handed to the backend request (10 minutes).
# Raise it if the backend needs longer.
REQUEST_TIMEOUT_SEC = 10 * 60
# ====== BACKEND CALLER ======
def _save_mp4(blob: bytes) -> str:
    """Write *blob* to a uniquely named ``out_<hex>.mp4`` file and return its path.

    Raises:
        gr.Error: If *blob* is empty, so a zero-byte file is never returned
            as a video.
    """
    if not blob:
        raise gr.Error("Backend returned an empty video.")
    mp4_path = f"out_{uuid.uuid4().hex[:8]}.mp4"
    with open(mp4_path, "wb") as f:
        f.write(blob)
    return mp4_path


def _first_text(payload: dict, top_level_keys: tuple, nested_key: str):
    """Return the first non-empty string among the candidate fields, or None.

    Looks at ``payload[k]`` for each of *top_level_keys* in order, then at
    ``payload["result"][nested_key]``. A missing, null or non-dict ``result``
    is simply skipped instead of raising.
    """
    candidates = [payload.get(key) for key in top_level_keys]
    nested = payload.get("result")
    if isinstance(nested, dict):
        candidates.append(nested.get(nested_key))
    for value in candidates:
        if isinstance(value, str) and value.strip():
            return value.strip()
    return None


def _video_from_json(res) -> str:
    """Extract a video URL, or decode and save a base64 video, from a JSON response."""
    try:
        data = res.json()
    except ValueError as exc:  # JSON decode errors are ValueError subclasses
        raise gr.Error("Backend declared a JSON response but the body could not be parsed.") from exc
    if not isinstance(data, dict):
        raise gr.Error(f"Backend JSON was not an object (got {type(data).__name__}).")

    url = _first_text(data, ("video_url", "url"), "video_url")
    if url:
        return url  # gr.Video can stream a URL

    b64 = _first_text(data, ("video_b64", "video_bytes"), "video_b64")
    if b64 is None:
        raise gr.Error(f"Backend JSON did not contain a video field. Keys: {list(data.keys())}")

    # Drop a leading "data:...;base64," header if present; a comma is never
    # part of the base64 payload itself.
    b64 = b64.split(",", 1)[-1]
    try:
        blob = base64.b64decode(b64)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise gr.Error("Backend returned a video that is not valid base64.") from exc
    return _save_mp4(blob)


def call_modal_backend(prompt: str, image: Image.Image | None, seed: int | None):
    """Send a prompt and/or image to the Modal backend and return the video.

    The image, when given, is uploaded as a PNG in a multipart body under the
    field name ``image_bytes``; the prompt and seed go in the query string.

    The response may be:
      - raw MP4 bytes (``video/mp4`` or ``application/octet-stream``),
      - JSON carrying a video URL (``video_url``, ``url`` or
        ``result.video_url``) or a base64 video (``video_b64``,
        ``video_bytes`` or ``result.video_b64``),
      - any other content type, whose body is saved as an MP4 as a fallback.

    Args:
        prompt: Text prompt; may be empty when an image is supplied.
        image: Optional PIL image to upload.
        seed: Optional seed; sent as a string when not None.

    Returns:
        A local ``.mp4`` path (written to the current working directory) or a
        URL, either of which can be given to ``gr.Video``.

    Raises:
        gr.Error: If neither prompt nor image is provided, or the response
            cannot be turned into a video.
        requests.RequestException: On connection failures, timeouts or
            non-2xx responses (propagated unchanged).
    """
    # Treat a whitespace-only prompt the same as no prompt at all.
    prompt = (prompt or "").strip()
    if not prompt and image is None:
        raise gr.Error("Please provide a prompt or upload an image.")

    # Build multipart body if image provided
    files = None
    if image is not None:
        buf = io.BytesIO()
        image.save(buf, format="PNG")  # change to JPEG if your backend expects it
        buf.seek(0)  # rewind so the upload reads from the start of the buffer
        files = {"image_bytes": ("input.png", buf, "image/png")}

    # Query string params
    params = {}
    if prompt:
        params["prompt"] = prompt
    if seed is not None:
        params["seed"] = str(seed)

    res = requests.post(
        MODAL_URL,
        params=params,
        files=files,
        headers={"accept": "application/json"},
        timeout=REQUEST_TIMEOUT_SEC,
    )
    res.raise_for_status()

    ctype = (res.headers.get("content-type") or "").lower()
    is_raw_video = "video/mp4" in ctype or ctype.startswith("application/octet-stream")
    if not is_raw_video and "application/json" in ctype:
        return _video_from_json(res)

    # Raw MP4 bytes, or an unrecognised content type: save the body as-is.
    return _save_mp4(res.content)
# ====== UI CALLBACK ======
# Largest seed generated when the user does not supply a usable one.
_MAX_RANDOM_SEED = 2**31 - 1

# Phrases appended to the prompt when the "lock camera" option is enabled.
_LONGSHOT_CONSTRAINTS = (
    "single continuous long shot",
    "no cuts, no new shot, no angle switch",
    "smooth camera motion (pan/tilt/zoom only)",
    "unbroken continuity",
)


def _resolve_seed(seed) -> int:
    """Return *seed* as an int, or a random seed if it is blank or not convertible."""
    if seed is None or str(seed).strip() == "":
        return random.randint(0, _MAX_RANDOM_SEED)
    try:
        # Gradio Number returns float; cast to int
        return int(seed)
    except (TypeError, ValueError, OverflowError):
        # e.g. NaN, infinity or a non-numeric string
        return random.randint(0, _MAX_RANDOM_SEED)


def on_generate(prompt, image, seed, lock_longshot):
    """Gradio click handler: resolve the seed, build the prompt, call the backend.

    Args:
        prompt: Text from the prompt box; may be empty or None.
        image: Optional uploaded image, forwarded unchanged.
        seed: Value of the seed field; blank or unusable values are replaced
            by a random seed in ``[0, 2**31 - 1]``.
        lock_longshot: When truthy and the prompt is non-empty, long-shot
            camera constraints are appended to the prompt.

    Returns:
        A ``(video_path_or_url, info)`` tuple, where ``info`` is the string
        ``"Seed: <seed>"`` reporting the seed actually used.
    """
    seed_val = _resolve_seed(seed)

    # Normalise first so a whitespace-only prompt is not treated as real text
    # (it would otherwise produce a prompt made only of the constraints).
    prompt = (prompt or "").strip()

    if lock_longshot and prompt:
        # Avoid a doubled period when the prompt already ends a sentence.
        separator = " " if prompt.endswith((".", "!", "?")) else ". "
        prompt = prompt + separator + "; ".join(_LONGSHOT_CONSTRAINTS)

    video_path_or_url = call_modal_backend(prompt, image, seed_val)
    return video_path_or_url, f"Seed: {seed_val}"
# ====== STYLE ======
# Stylesheet passed to gr.Blocks(css=...). Its selectors target the elem_id
# values assigned to the UI components: "prompt-box" (prompt textbox),
# "add-image" (image input) and "gen-btn" (generate button).
CUSTOM_CSS = """
.gradio-container { padding: 24px; }
/* Big rounded prompt box */
#prompt-box textarea {
border-radius: 28px !important;
min-height: 180px;
font-size: 18px;
line-height: 1.45;
padding: 18px 22px;
}
/* Rounded square image card */
#add-image .wrap,
#add-image .input-image,
#add-image .empty {
border-radius: 28px !important;
min-width: 240px;
min-height: 240px;
}
/* Pill generate button */
#gen-btn button {
border-radius: 999px !important;
padding: 12px 24px;
font-size: 18px;
}
"""
# ====== APP ======
# Build the UI. Components created inside a `with` block are placed in that
# container, so the nesting below defines the page layout.
# NOTE(review): the nesting was reconstructed from the layout comments
# ("Row 1", "Row 2: Left image card, right controls"); confirm that info_out
# is meant to sit below the video row rather than inside it.
with gr.Blocks(css=CUSTOM_CSS, title="Stitch UI – Modal Hook") as demo:
    gr.Markdown("### Stitch – turn prompt/image into a generated video (Modal backend)")
    # Row 1: Big rounded prompt input (styled via the "prompt-box" elem_id)
    prompt_tb = gr.Textbox(
        label=None,
        placeholder="Prompt input",
        lines=8,
        elem_id="prompt-box"
    )
    # Row 2: Left image card, right controls (seed + generate)
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            # type="pil" makes the callback receive a PIL image
            img_in = gr.Image(label="Add Image", type="pil", elem_id="add-image")
        with gr.Column(scale=3, min_width=300):
            with gr.Row():
                # Left empty by default; on_generate picks a random seed then
                seed_in = gr.Number(value=None, label="Seed (optional)")
                lock_long = gr.Checkbox(value=True, label="Lock camera (long shot, no cuts)")
            gen_btn = gr.Button("Generate", elem_id="gen-btn")
    # Output
    with gr.Row():
        video_out = gr.Video(label="Output Video", interactive=False, autoplay=True)
    # Shows the "Seed: ..." string returned by on_generate
    info_out = gr.Markdown("")
    # Clicking Generate runs on_generate; inputs are passed positionally in
    # this order and its two return values fill the outputs in this order.
    gen_btn.click(
        fn=on_generate,
        inputs=[prompt_tb, img_in, seed_in, lock_long],
        outputs=[video_out, info_out]
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()