# NOTE(review): removed Hugging Face Hub page chrome ("rmz92002's picture",
# "Upload 95 files", "5484092 verified") that was scraped along with the file
# and is not valid Python source.
"""
Gradio Space entrypoint for Time-to-Move (Wan 2.2 backbone).
The UI allows users to run the provided cut-and-drag examples or upload their own
`first_frame.png`, `mask.mp4`, and `motion_signal.mp4` triplet. All preprocessing
matches the `run_wan.py` script in the main repository.
"""
import base64
import os
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple
import gradio as gr
import torch
from diffusers.utils import export_to_video, load_image
from pipelines.utils import compute_hw_from_area
from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
# Hugging Face model repo for the Wan 2.2 image-to-video backbone (overridable via env var).
WAN_MODEL_ID = os.getenv("WAN_MODEL_ID", "Wan-AI/Wan2.2-I2V-A14B-Diffusers")
# Directory scanned for packaged example folders (first_frame.png / mask.mp4 / motion_signal.mp4).
EXAMPLES_DIR = Path(os.getenv("TTM_EXAMPLES_DIR", "examples"))
# Default negative prompt (Chinese). NOTE(review): presumably the stock Wan negative
# prompt listing common failure modes (overexposure, static frames, extra fingers, ...);
# kept verbatim since it is a runtime string the model was likely tuned against — confirm upstream.
DEFAULT_NEGATIVE_PROMPT = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
    "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
    "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)
# Run on GPU when available; bfloat16 on CUDA, float32 on CPU fallback.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# Lazily-initialized singletons, populated by _ensure_pipeline() on first use.
_PIPELINE: Optional[WanImageToVideoTTMPipeline] = None
_MOD_VALUE: Optional[int] = None
def _build_example_index() -> Dict[str, Dict[str, str]]:
    """Discover example folders under ``EXAMPLES_DIR`` that ship all required files.

    A folder qualifies when it contains ``first_frame.png``, ``mask.mp4`` and
    ``motion_signal.mp4``; an optional ``prompt.txt`` supplies a default prompt.

    Returns:
        Mapping of folder name -> {"folder": absolute path, "prompt": default prompt}.
    """
    catalog: Dict[str, Dict[str, str]] = {}
    if not EXAMPLES_DIR.exists():
        return catalog
    required = ("first_frame.png", "mask.mp4", "motion_signal.mp4")
    for entry in sorted(EXAMPLES_DIR.iterdir()):
        if not entry.is_dir():
            continue
        if any(not (entry / fname).exists() for fname in required):
            continue
        prompt_path = entry / "prompt.txt"
        default_prompt = ""
        if prompt_path.exists():
            default_prompt = prompt_path.read_text(encoding="utf-8").strip()
        catalog[entry.name] = {
            "folder": str(entry.resolve()),
            "prompt": default_prompt,
        }
    return catalog
# Build the example catalog once at import time so the UI dropdown can be populated.
EXAMPLE_INDEX = _build_example_index()
def _ensure_pipeline() -> WanImageToVideoTTMPipeline:
    """Load the Wan Time-to-Move pipeline on first use and cache it in module globals."""
    global _PIPELINE, _MOD_VALUE
    if _PIPELINE is not None:
        return _PIPELINE
    pipeline = WanImageToVideoTTMPipeline.from_pretrained(WAN_MODEL_ID, torch_dtype=DTYPE)
    # Tiled + sliced VAE decoding keeps peak VRAM low during video decode.
    pipeline.vae.enable_tiling()
    pipeline.vae.enable_slicing()
    pipeline.to(DEVICE)
    _PIPELINE = pipeline
    # Output height/width must be multiples of vae_scale_factor * patch_size.
    _MOD_VALUE = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
    return _PIPELINE
def _save_video_payload(payload, tmpdir: Path, filename: str) -> str:
    """Persist a video uploaded through Gradio to disk and return its path.

    Hugging Face Spaces deliver uploads either as a temp-file path (str), as a
    dict whose ``name`` key points at a temp file, or as a dict whose ``data``
    key holds raw bytes or a base64 data URI.

    Args:
        payload: Raw upload value from the Gradio component.
        tmpdir: Directory in which to persist the file.
        filename: Target file name inside ``tmpdir`` (also used in error messages).

    Returns:
        Path (str) of the persisted file inside ``tmpdir``.

    Raises:
        gr.Error: If the payload is missing or in an unrecognized format.
    """
    target = tmpdir / filename
    if payload is None:
        # BUG FIX: these f-strings previously contained no placeholder and
        # literally printed "(unknown)"; report the actual file name instead.
        raise gr.Error(f"Missing upload for {filename}.")
    if isinstance(payload, str) and Path(payload).exists():
        target.write_bytes(Path(payload).read_bytes())
        return str(target)
    if isinstance(payload, dict):
        # payload["name"] may already point to a temp file on disk.
        potential_path = payload.get("name")
        if potential_path and Path(potential_path).exists():
            target.write_bytes(Path(potential_path).read_bytes())
            return str(target)
        raw_data = payload.get("data")
        if raw_data is None:
            raise gr.Error(f"Could not read data for {filename}.")
        if isinstance(raw_data, str):
            # Format: data:video/mp4;base64,AAA...
            if raw_data.startswith("data:"):
                raw_data = raw_data.split(",", 1)[1]
            file_bytes = base64.b64decode(raw_data)
        else:
            # Assumed to already be raw bytes-like content.
            file_bytes = raw_data
        target.write_bytes(file_bytes)
        return str(target)
    raise gr.Error(f"Unsupported upload format for {filename}.")
def _prepare_inputs(
    example_name: str,
    prompt: str,
    negative_prompt: str,
    custom_image,
    custom_mask,
    custom_motion,
) -> Tuple:
    """Resolve the effective pipeline inputs for a preset example or a custom run.

    Returns:
        (prompt, negative_prompt, image, mask_path, motion_path).

    Raises:
        gr.Error: When the example name is unknown, the prompt resolves empty,
            or a required custom upload is missing.
    """
    # An empty/whitespace negative prompt falls back to the curated default.
    negative_prompt = (negative_prompt or "").strip() or DEFAULT_NEGATIVE_PROMPT
    if example_name != "custom":
        meta = EXAMPLE_INDEX.get(example_name)
        if not meta:
            raise gr.Error(f"Example '{example_name}' not found in {EXAMPLES_DIR}.")
        example_dir = Path(meta["folder"])
        frame = load_image(example_dir / "first_frame.png")
        mask_clip = str((example_dir / "mask.mp4").resolve())
        motion_clip = str((example_dir / "motion_signal.mp4").resolve())
        # User-supplied prompt wins; otherwise fall back to the packaged prompt.txt.
        effective_prompt = (prompt or "").strip() or meta["prompt"]
        if not effective_prompt:
            raise gr.Error("Prompt cannot be empty for example runs.")
        return effective_prompt, negative_prompt, frame, mask_clip, motion_clip
    # Custom upload path: the caller supplies all three artifacts themselves.
    if custom_image is None:
        raise gr.Error("Upload a first frame (PNG/JPG) for custom mode.")
    effective_prompt = (prompt or "").strip()
    if not effective_prompt:
        raise gr.Error("Prompt cannot be empty for custom runs.")
    staging = Path(tempfile.mkdtemp(prefix="ttm_space_inputs_"))
    mask_clip = _save_video_payload(custom_mask, staging, "mask.mp4")
    motion_clip = _save_video_payload(custom_motion, staging, "motion_signal.mp4")
    return effective_prompt, negative_prompt, custom_image, mask_clip, motion_clip
def generate_video(
    example_name: str,
    prompt: str,
    negative_prompt: str,
    tweak_index: int,
    tstrong_index: int,
    num_inference_steps: int,
    guidance_scale: float,
    num_frames: int,
    max_area: int,
    seed: int,
    custom_image,
    custom_mask,
    custom_motion,
):
    """Main callback used by Gradio.

    Resolves the effective inputs (example preset or custom uploads), runs the
    Wan Time-to-Move pipeline, and exports the generated frames to an mp4.

    Returns:
        (path to the generated mp4, status markdown string).

    Raises:
        gr.Error: On invalid index ordering, or propagated from _prepare_inputs.
    """
    pipe = _ensure_pipeline()
    prompt, negative_prompt, image, mask_path, motion_path = _prepare_inputs(
        example_name, prompt, negative_prompt, custom_image, custom_mask, custom_motion
    )
    # Gradio numeric widgets may deliver floats; coerce before validating.
    tweak_index = int(tweak_index)
    tstrong_index = int(tstrong_index)
    num_inference_steps = int(num_inference_steps)
    num_frames = int(num_frames)
    guidance_scale = float(guidance_scale)
    seed = int(seed)
    # TTM indices must be ordered and fall within the denoising schedule.
    if not (0 <= tweak_index <= tstrong_index <= num_inference_steps):
        raise gr.Error("Require 0 ≤ tweak-index ≤ tstrong-index ≤ num_inference_steps.")
    max_area = int(max_area)
    # _MOD_VALUE is set by _ensure_pipeline; recompute defensively if unset.
    mod_value = _MOD_VALUE or (pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
    # Choose an output H/W under max_area whose sides are multiples of mod_value.
    height, width = compute_hw_from_area(image.height, image.width, max_area, mod_value)
    if hasattr(image, "mode") and image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((width, height))
    # DEVICE is "cuda" or "cpu" here; seed the generator on the matching device.
    generator_device = DEVICE if DEVICE.startswith("cuda") else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(seed)
    with torch.inference_mode():
        result = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            motion_signal_video_path=motion_path,
            motion_signal_mask_path=mask_path,
            tweak_index=tweak_index,
            tstrong_index=tstrong_index,
        )
    # result.frames is batched; this app always generates a single video.
    frames = result.frames[0]
    output_dir = Path(tempfile.mkdtemp(prefix="ttm_space_output_"))
    output_path = output_dir / "ttm.mp4"
    # NOTE(review): fps=16 presumably matches Wan 2.2's native frame rate — confirm.
    export_to_video(frames, str(output_path), fps=16)
    status = (
        f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}\n"
        f"Resolution: {height}x{width}, frames: {num_frames}, guidance: {guidance_scale}"
    )
    return str(output_path), status
def build_ui() -> gr.Blocks:
    """Construct the Gradio Blocks UI for the Time-to-Move demo.

    Returns:
        The assembled (but not yet launched) gr.Blocks application.
    """
    example_choices = sorted(EXAMPLE_INDEX.keys())
    # Fall back to "custom" when no packaged examples were found on disk.
    default_example = example_choices[0] if example_choices else "custom"
    with gr.Blocks(title="Time-to-Move (Wan 2.2)") as demo:
        gr.Markdown(
            "### Time-to-Move (Wan 2.2)\n"
            "Generate motion-controlled videos by combining a still frame with a cut-and-drag motion signal. "
            "Select one of the bundled examples or upload your own trio of files."
        )
        # NOTE(review): original indentation was lost; the Row nesting below is
        # reconstructed (each widget placed in the preceding Row) — confirm layout.
        with gr.Row():
            example_dropdown = gr.Dropdown(
                choices=example_choices + ["custom"],
                value=default_example,
                label="Example preset",
                info="Choose a prepackaged example or switch to 'custom' to upload your own inputs.",
            )
            prompt_box = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="Enter the text prompt (auto-filled for examples).",
            )
            negative_prompt_box = gr.Textbox(
                label="Negative prompt",
                lines=4,
                value=DEFAULT_NEGATIVE_PROMPT,
            )
        # Custom-mode uploads; ignored when an example preset is selected.
        with gr.Row():
            image_input = gr.Image(label="first_frame (custom only)", type="pil")
            mask_input = gr.Video(label="mask.mp4 (custom only)")
            motion_input = gr.Video(label="motion_signal.mp4 (custom only)")
        # Sampling hyper-parameters.
        with gr.Row():
            tweak_slider = gr.Slider(0, 20, value=3, step=1, label="tweak-index")
            tstrong_slider = gr.Slider(0, 50, value=7, step=1, label="tstrong-index")
            steps_slider = gr.Slider(10, 50, value=50, step=1, label="num_inference_steps")
            guidance_slider = gr.Slider(1.0, 8.0, value=3.5, step=0.1, label="guidance_scale")
        with gr.Row():
            frames_slider = gr.Slider(21, 81, value=81, step=1, label="num_frames")
            area_slider = gr.Slider(
                256 * 256,
                640 * 1152,
                value=480 * 832,
                step=64,
                label="max pixel area (height*width)",
            )
            seed_box = gr.Number(label="Seed", value=0, precision=0)
        generate_button = gr.Button("Generate video", variant="primary")
        output_video = gr.Video(label="Generated video", autoplay=True, height=512)
        status_box = gr.Markdown()
        # Input order here must match generate_video's positional signature.
        generate_button.click(
            fn=generate_video,
            inputs=[
                example_dropdown,
                prompt_box,
                negative_prompt_box,
                tweak_slider,
                tstrong_slider,
                steps_slider,
                guidance_slider,
                frames_slider,
                area_slider,
                seed_box,
                image_input,
                mask_input,
                motion_input,
            ],
            outputs=[output_video, status_box],
        )
        info = "\n".join(f"- **{name}**: `{meta['folder']}`" for name, meta in EXAMPLE_INDEX.items())
        with gr.Accordion("Available packaged examples", open=False):
            gr.Markdown(info or "No example folders detected. Upload custom inputs instead.")
    return demo
# Build the UI at import time so `app` is importable by the Spaces runtime.
app = build_ui()

if __name__ == "__main__":
    # Enable queuing to support concurrent users on Spaces.
    app.queue(max_size=2).launch(server_name="0.0.0.0", server_port=7860)