# Hugging Face Space app: image -> video stitching demo (Gradio front end).
| import os | |
| import io | |
| import time | |
| import random | |
| import base64 | |
| from urllib.parse import quote_plus | |
| from typing import Optional, Tuple | |
| import requests | |
| import gradio as gr | |
# -----------------------------
# Config
# -----------------------------
# Backend endpoint; override via the MM_I2V_URL environment variable.
DEFAULT_API_URL = os.environ.get(
    "MM_I2V_URL",
    "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run",
)

# Local directory where generated clips are written; created up front.
SAVE_DIR = "outputs"
os.makedirs(SAVE_DIR, exist_ok=True)
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
def _save_bytes_to_mp4(buf: bytes, name_prefix: str) -> str:
    """Write *buf* to a millisecond-timestamped .mp4 under SAVE_DIR and return its path."""
    stamp_ms = int(time.time() * 1000)
    out_path = os.path.join(SAVE_DIR, f"{name_prefix}-{stamp_ms}.mp4")
    with open(out_path, "wb") as out_file:
        out_file.write(buf)
    return out_path
def _download(url: str) -> bytes:
    """Fetch *url* (10-minute timeout) and return the body, raising on HTTP errors."""
    response = requests.get(url, timeout=600)
    response.raise_for_status()
    return response.content
def call_i2v(
    image_path: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Call the image->video backend and return (video_path, error_message).

    Exactly one element is non-None: the saved .mp4 path on success, or a
    human-readable error string on failure.

    Tries to handle several common response types:
      1) raw mp4 bytes
      2) JSON with {"video": "<base64>"} (mp4 base64)
      3) JSON with {"video_url": "https://..."} (or "result_url" / "url")
    """
    api = (api_url or DEFAULT_API_URL).strip().rstrip("/")
    # Fall back to a random seed when none (or an empty value) was supplied.
    used_seed = (
        seed
        if (seed is not None and str(seed).strip() != "")
        else random.randint(0, 2**31 - 1)
    )
    url = f"{api}?prompt={quote_plus(prompt)}&seed={used_seed}"
    headers = {"accept": "application/json"}
    try:
        # BUGFIX: the upload handle was opened without ever being closed
        # (resource leak); the context manager guarantees it is released.
        with open(image_path, "rb") as image_file:
            files = {
                "image_bytes": (
                    os.path.basename(image_path),
                    image_file,
                    "application/octet-stream",
                )
            }
            resp = requests.post(url, headers=headers, files=files, timeout=1200)
        # BUGFIX: without this, a >1 KiB HTML error page would fall through to
        # the "raw bytes" branch below and be saved as an "mp4". HTTPError is
        # a subclass of RequestException, so the existing handler reports it.
        resp.raise_for_status()
        # Try to accommodate various backends.
        ctype = resp.headers.get("Content-Type", "")
        if "application/json" in ctype:
            data = resp.json()
            # base64 payload (length check skips trivial/placeholder values)
            if "video" in data and isinstance(data["video"], str) and len(data["video"]) > 50:
                try:
                    raw = base64.b64decode(data["video"], validate=True)
                    return _save_bytes_to_mp4(raw, "clip"), None
                except Exception as e:
                    return None, f"Could not decode base64 video: {e}"
            # url payload
            for key in ("video_url", "result_url", "url"):
                if key in data and isinstance(data[key], str) and data[key].startswith("http"):
                    raw = _download(data[key])
                    return _save_bytes_to_mp4(raw, "clip"), None
            return None, 'JSON response did not include "video" (base64) or a known url key.'
        # Raw bytes (ideally mp4)
        elif "video" in ctype or "octet-stream" in ctype:
            return _save_bytes_to_mp4(resp.content, "clip"), None
        else:
            # Some backends still reply bytes with a missing/odd content-type;
            # accept anything plausibly video-sized.
            if resp.content and len(resp.content) > 1024:
                return _save_bytes_to_mp4(resp.content, "clip"), None
            return None, f"Unexpected content type: {ctype}"
    except OSError as e:
        # BUGFIX: a missing/unreadable image previously raised out of this
        # function instead of following the (None, error) contract.
        return None, f"Could not read image file: {e}"
    except requests.RequestException as e:
        return None, f"Request failed: {e}"
def stitch_pair(
    image_a: str,
    image_b: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str],
    crossfade: float,
) -> Tuple[Optional[str], str]:
    """
    Generate one clip per image and join them into a single video.

    Strategy:
      - Generate a short clip from image A
      - Generate a short clip from image B (same prompt/seed unless user changes)
      - Concatenate with a short crossfade in Python (moviepy)

    Returns (output_path, error_message); on success the message is "".

    If you already have a backend endpoint that does stitching directly,
    replace this function body with a single backend call.
    """
    if not image_a or not image_b:
        return None, "Please upload both images."
    # First generate both clips; bail out on the first failure.
    clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
    if err1:
        return None, f"Clip 1 failed: {err1}"
    clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
    if err2:
        return None, f"Clip 2 failed: {err2}"
    # Lazy import so the rest of the app works even without moviepy installed.
    try:
        from moviepy.editor import VideoFileClip, concatenate_videoclips
    except Exception as e:
        return None, f"MoviePy import failed. Add moviepy & imageio-ffmpeg to requirements.txt. Error: {e}"
    c1 = c2 = merged = None
    try:
        c1 = VideoFileClip(clip1_path)
        c2 = VideoFileClip(clip2_path)
        # method="compose" tolerates size/fps mismatches between the clips.
        if crossfade and crossfade > 0:
            # Second clip fades in while the first fades out; the negative
            # padding overlaps them by `crossfade` seconds.
            # NOTE(review): crossfadein/crossfadeout are moviepy 1.x API;
            # moviepy 2.x renamed them — pin moviepy<2 or confirm the version.
            c2 = c2.crossfadein(crossfade)
            c1 = c1.crossfadeout(crossfade)
            merged = concatenate_videoclips([c1, c2], method="compose", padding=-crossfade)
        else:
            # If crossfade is 0, just concatenate directly.
            merged = concatenate_videoclips([c1, c2], method="compose")
        out_path = os.path.join(SAVE_DIR, f"stitched-{int(time.time() * 1000)}.mp4")
        merged.write_videofile(out_path, codec="libx264", audio_codec="aac", verbose=False, logger=None)
        return out_path, ""
    except Exception as e:
        return None, f"Stitching failed: {e}"
    finally:
        # BUGFIX: the original only closed the readers on the success path,
        # leaking them whenever stitching raised; always release, either way.
        for clip in (c1, c2, merged):
            if clip is not None:
                try:
                    clip.close()
                except Exception:
                    pass
# -----------------------------
# UI
# -----------------------------
# Three-column Gradio layout: image uploads | prompts & controls | video outputs.
with gr.Blocks(title="Image Stitch to Video", css="""
/* Rounded tiles like the sketch */
.rounded { border-radius: 24px; }
.tile { background: #f7f7ff; padding: 12px; }
.tile-blue { background: #e8f0ff; }
.tile-yellow { background: #fff7d6; }
.small-btn button { padding: 6px 10px; border-radius: 999px; }
.label-center label { text-align:center; width: 100%; }
""") as demo:
    gr.Markdown("### Image → Video (Stitch Adjacent Pairs)\nUpload 3 images, enter prompts for each stitch, then click the stitch buttons.")
    with gr.Row(equal_height=True):
        # Left column: images + add image
        with gr.Column(scale=1):
            gr.Markdown("**Images**")
            img1 = gr.Image(type="filepath", label="Image 1", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img2 = gr.Image(type="filepath", label="Image 2", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img3 = gr.Image(type="filepath", label="Image 3", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            # Optional extra slots (hidden until added via the "Add Image" button)
            extra_imgs = []
            for i in range(4, 9):
                comp = gr.Image(type="filepath", label=f"Image {i}", height=220, visible=False, elem_classes=["rounded", "tile", "tile-blue"])
                extra_imgs.append(comp)
            add_btn = gr.Button("Add Image", variant="secondary")
        # Middle column: prompts + stitch buttons
        with gr.Column(scale=1):
            gr.Markdown("**Prompts**")
            prompt12 = gr.Textbox(label="Prompt for Stitch 1 & 2", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed12 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch12 = gr.Button("Stitch 1 & 2", elem_classes=["small-btn"])
            prompt23 = gr.Textbox(label="Prompt for Stitch 2 & 3", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed23 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch23 = gr.Button("Stitch 2 & 3", elem_classes=["small-btn"])
            with gr.Accordion("Advanced (API & Stitch)", open=False):
                api_url = gr.Textbox(label="Backend API URL", value=DEFAULT_API_URL)
                crossfade = gr.Slider(0.0, 1.5, value=0.4, step=0.1, label="Crossfade seconds")
            clear_btn = gr.Button("Clear All")
        # Right column: video outputs
        with gr.Column(scale=1):
            gr.Markdown("**Outputs**")
            vid12 = gr.Video(label="Video (image 1 + 2) output", elem_classes=["rounded", "tile", "tile-yellow"])
            vid23 = gr.Video(label="Video (image 2 + 3) output", elem_classes=["rounded", "tile", "tile-yellow"])
    # Wire up actions
    def _on_add(*imgs):
        # Reveal the next hidden uploader.
        # NOTE(review): this mutates the components' `visible` attribute in the
        # server process, which is shared across all sessions and may not be
        # honored by newer Gradio versions — consider tracking the visible
        # count in a per-session gr.State instead; confirm against the Gradio
        # version in use.
        for comp in extra_imgs:
            if comp.visible is False:
                comp.visible = True
                break
        return [gr.update(visible=comp.visible) for comp in extra_imgs]
    add_btn.click(
        _on_add,
        inputs=extra_imgs,
        outputs=extra_imgs,
    )
    # NOTE(review): the second output is a fresh hidden Textbox created inline
    # to absorb stitch_pair's error-message return value; it is never shown,
    # so backend errors are silently discarded from the UI.
    stitch12.click(
        stitch_pair,
        inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
        outputs=[vid12, gr.Textbox(visible=False)],
    )
    stitch23.click(
        stitch_pair,
        inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
        outputs=[vid23, gr.Textbox(visible=False)],
    )
    def _on_clear():
        # Reset images 1-3 (kept visible), hide the extra slots, then clear
        # both videos and prompts and restore the default API URL / crossfade.
        # Note: the seed inputs are intentionally left in place by this list.
        updates = []
        for comp in [img1, img2, img3, *extra_imgs]:
            updates.append(gr.update(value=None, visible=True if comp in [img1, img2, img3] else False))
        return updates + [None, None, "", "", gr.update(value=DEFAULT_API_URL), 0.4]
    clear_btn.click(
        _on_clear,
        inputs=None,
        outputs=[img1, img2, img3, *extra_imgs, vid12, vid23, prompt12, prompt23, api_url, crossfade],
    )
if __name__ == "__main__":
    demo.launch()