| """ |
| Gradio Space entrypoint for Time-to-Move (Wan 2.2 backbone). |
| |
| The UI allows users to run the provided cut-and-drag examples or upload their own |
| `first_frame.png`, `mask.mp4`, and `motion_signal.mp4` triplet. All preprocessing |
| matches the `run_wan.py` script in the main repository. |
| """ |
|
|
import base64
import os
import shutil
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
from diffusers.utils import export_to_video, load_image

from pipelines.utils import compute_hw_from_area
from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
|
|
|
|
# Model identifier handed to `from_pretrained`; the WAN_MODEL_ID environment variable overrides it.
WAN_MODEL_ID = os.environ.get("WAN_MODEL_ID", "Wan-AI/Wan2.2-I2V-A14B-Diffusers")

# Root folder scanned for packaged examples; the TTM_EXAMPLES_DIR environment variable overrides it.
EXAMPLES_DIR = Path(os.environ.get("TTM_EXAMPLES_DIR", "examples"))

# Negative prompt substituted whenever the user leaves the negative-prompt field empty.
DEFAULT_NEGATIVE_PROMPT = "".join(
    (
        "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,",
        "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,",
        "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    )
)

# Use CUDA with bfloat16 when a GPU is visible, otherwise CPU with float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32
|
|
# Module-wide cache for the loaded pipeline; stays None until `_ensure_pipeline()` fills it.
_PIPELINE: Optional[WanImageToVideoTTMPipeline] = None
# Cached `vae_scale_factor_spatial * patch_size[1]`, stored by `_ensure_pipeline()` alongside the pipeline.
_MOD_VALUE: Optional[int] = None
|
|
|
|
def _build_example_index(examples_dir: Optional[Path] = None) -> Dict[str, Dict[str, str]]:
    """Scan an examples root for Wan-compatible folders that contain the required files.

    A sub-folder qualifies when it holds `first_frame.png`, `mask.mp4` and
    `motion_signal.mp4` as regular files. An optional `prompt.txt` supplies the
    example's default prompt.

    Args:
        examples_dir: Root folder to scan. Defaults to the module-level `EXAMPLES_DIR`.

    Returns:
        Mapping of folder name to `{"folder": <resolved path>, "prompt": <stripped text or "">}`,
        inserted in sorted folder order. Empty when the root is missing or is not a directory.
    """
    root = EXAMPLES_DIR if examples_dir is None else Path(examples_dir)
    index: Dict[str, Dict[str, str]] = {}

    # is_dir() rather than exists(): a root that is a regular file must yield an
    # empty index instead of raising NotADirectoryError from iterdir().
    if not root.is_dir():
        return index

    required = ("first_frame.png", "mask.mp4", "motion_signal.mp4")
    for folder in sorted(root.iterdir()):
        if not folder.is_dir():
            continue
        # is_file() so that a directory that merely carries a required name does not qualify.
        if not all((folder / name).is_file() for name in required):
            continue
        prompt_file = folder / "prompt.txt"
        prompt = prompt_file.read_text(encoding="utf-8").strip() if prompt_file.is_file() else ""
        index[folder.name] = {
            "folder": str(folder.resolve()),
            "prompt": prompt,
        }
    return index
|
|
|
|
# Built once, when this module is imported.
EXAMPLE_INDEX = _build_example_index()
|
|
|
|
def _ensure_pipeline() -> WanImageToVideoTTMPipeline:
    """Return the shared Wan Time-to-Move pipeline, loading it on first use.

    The first call loads `WAN_MODEL_ID` with dtype `DTYPE`, enables VAE tiling
    and slicing, moves the pipeline to `DEVICE`, and caches both the pipeline
    and the spatial modulus (`vae_scale_factor_spatial * patch_size[1]`) in
    module globals. Later calls return the cached pipeline unchanged.
    """
    global _PIPELINE, _MOD_VALUE

    # Fast path: already loaded.
    if _PIPELINE is not None:
        return _PIPELINE

    loaded = WanImageToVideoTTMPipeline.from_pretrained(WAN_MODEL_ID, torch_dtype=DTYPE)
    loaded.vae.enable_tiling()
    loaded.vae.enable_slicing()
    loaded.to(DEVICE)
    _PIPELINE = loaded

    # NOTE(review): source indentation was lost; this assignment is taken to belong to
    # the first-load branch because the freshly loaded pipeline only exists here.
    spatial_scale = loaded.vae_scale_factor_spatial
    patch_width = loaded.transformer.config.patch_size[1]
    _MOD_VALUE = spatial_scale * patch_width
    return _PIPELINE
|
|
|
|
def _copy_upload(source: Path, target: Path) -> str:
    """Copy `source` to `target` without loading it into memory; return `target` as a string."""
    # shutil.copyfile refuses to copy a file onto itself, so skip the copy in that case.
    if source.resolve() != target.resolve():
        shutil.copyfile(source, target)
    return str(target)


def _save_video_payload(payload, tmpdir: Path, filename: str) -> str:
    """
    Persist a video uploaded through Gradio to disk and return its path.

    Hugging Face Spaces provide uploads either as temp file paths or as base64 data URIs.

    Args:
        payload: One of: a filesystem path (`str` or `os.PathLike`) to an existing file;
            a dict carrying such a path under `"name"` or `"path"`; or a dict carrying
            the content under `"data"` as raw bytes, a base64 string, or a
            `data:...;base64,...` URI.
        tmpdir: Existing directory the file is written into.
        filename: File name to use inside `tmpdir`.

    Returns:
        String path of the written file (`tmpdir / filename`).

    Raises:
        gr.Error: If the payload is missing, has an unsupported shape, or its
            embedded data cannot be decoded.
    """
    if payload is None:
        raise gr.Error(f"Missing upload for {filename}.")

    target = tmpdir / filename

    # Plain path on disk.
    if isinstance(payload, (str, os.PathLike)):
        source = Path(payload)
        if source.is_file():
            return _copy_upload(source, target)

    if isinstance(payload, dict):
        # "name" is the key this function has always read; "path" is accepted as well
        # (NOTE(review): believed to be the key used by newer Gradio file dicts — confirm
        # against the installed Gradio version).
        for key in ("name", "path"):
            candidate = payload.get(key)
            if isinstance(candidate, (str, os.PathLike)) and candidate:
                candidate_path = Path(candidate)
                if candidate_path.is_file():
                    return _copy_upload(candidate_path, target)

        raw_data = payload.get("data")
        if raw_data is None:
            raise gr.Error(f"Could not read data for {filename}.")

        if isinstance(raw_data, str):
            encoded = raw_data
            if raw_data.startswith("data:"):
                # Drop the "data:<mime>;base64," header; a URI without a comma has no body.
                _, separator, encoded = raw_data.partition(",")
                if not separator:
                    raise gr.Error(f"Could not read data for {filename}.")
            try:
                file_bytes = base64.b64decode(encoded)
            except ValueError as err:  # binascii.Error is a ValueError subclass
                raise gr.Error(f"Could not read data for {filename}.") from err
        elif isinstance(raw_data, (bytes, bytearray, memoryview)):
            file_bytes = raw_data
        else:
            raise gr.Error(f"Unsupported upload format for {filename}.")

        target.write_bytes(file_bytes)
        return str(target)

    raise gr.Error(f"Unsupported upload format for {filename}.")
|
|
|
|
def _prepare_inputs(
    example_name: str,
    prompt: str,
    negative_prompt: str,
    custom_image,
    custom_mask,
    custom_motion,
) -> Tuple:
    """
    Determine which inputs to feed into the pipeline.

    For a packaged example the image, mask and motion signal come from the
    example folder, and an empty prompt falls back to the example's own prompt.
    For `example_name == "custom"` the uploaded image is used as-is and the two
    uploaded videos are written into a fresh temporary directory.

    Args:
        example_name: Key in `EXAMPLE_INDEX`, or `"custom"`.
        prompt: User prompt; may be empty for examples that ship a prompt.
        negative_prompt: User negative prompt; empty falls back to `DEFAULT_NEGATIVE_PROMPT`.
        custom_image: Uploaded first frame (custom mode only).
        custom_mask: Uploaded mask video payload (custom mode only).
        custom_motion: Uploaded motion-signal video payload (custom mode only).

    Returns (prompt, negative_prompt, image, mask_path, motion_path).

    Raises:
        gr.Error: If the example is unknown, the prompt is empty, or a custom upload
            is missing or unreadable.
    """
    negative_prompt = (negative_prompt or "").strip() or DEFAULT_NEGATIVE_PROMPT

    if example_name != "custom":
        meta = EXAMPLE_INDEX.get(example_name)
        if not meta:
            raise gr.Error(f"Example '{example_name}' not found in {EXAMPLES_DIR}.")
        folder = Path(meta["folder"])
        # load_image takes a URL / local path string or a PIL image, so hand it a str, not a Path.
        image = load_image(str(folder / "first_frame.png"))
        mask_path = str((folder / "mask.mp4").resolve())
        motion_path = str((folder / "motion_signal.mp4").resolve())
        resolved_prompt = (prompt or "").strip() or meta["prompt"]
        if not resolved_prompt:
            raise gr.Error("Prompt cannot be empty for example runs.")
        return resolved_prompt, negative_prompt, image, mask_path, motion_path

    # Custom mode: validate the cheap things before touching the filesystem.
    if custom_image is None:
        raise gr.Error("Upload a first frame (PNG/JPG) for custom mode.")
    resolved_prompt = (prompt or "").strip()
    if not resolved_prompt:
        raise gr.Error("Prompt cannot be empty for custom runs.")

    tmpdir = Path(tempfile.mkdtemp(prefix="ttm_space_inputs_"))
    try:
        mask_path = _save_video_payload(custom_mask, tmpdir, "mask.mp4")
        motion_path = _save_video_payload(custom_motion, tmpdir, "motion_signal.mp4")
    except Exception:
        # mkdtemp leaves cleanup to the caller; do not leave a directory behind for a rejected request.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    return resolved_prompt, negative_prompt, custom_image, mask_path, motion_path
|
|
|
|
def generate_video(
    example_name: str,
    prompt: str,
    negative_prompt: str,
    tweak_index: int,
    tstrong_index: int,
    num_inference_steps: int,
    guidance_scale: float,
    num_frames: int,
    max_area: int,
    seed: int,
    custom_image,
    custom_mask,
    custom_motion,
):
    """Main callback used by Gradio: validate the request, run the pipeline, export an mp4.

    Cheap validation (numeric coercion, index ordering, input resolution) runs
    before the pipeline is loaded, so a malformed request is rejected without
    triggering the model load.

    Args:
        example_name: Packaged example name, or `"custom"`.
        prompt: Text prompt (may be empty for examples that ship one).
        negative_prompt: Negative prompt; empty selects the default.
        tweak_index: Forwarded to the pipeline as `tweak_index`.
        tstrong_index: Forwarded to the pipeline as `tstrong_index`.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_frames: Number of frames requested from the pipeline.
        max_area: Maximum pixel area used to derive the output height/width.
        seed: Seed for the `torch.Generator`.
        custom_image: Uploaded first frame (custom mode only).
        custom_mask: Uploaded mask video (custom mode only).
        custom_motion: Uploaded motion-signal video (custom mode only).

    Returns:
        Tuple of (path to the exported `ttm.mp4`, short status text).

    Raises:
        gr.Error: If `0 <= tweak_index <= tstrong_index <= num_inference_steps` does
            not hold, or if input preparation rejects the request.
    """
    # Gradio sliders/number boxes may deliver floats; normalise before validating.
    tweak_index = int(tweak_index)
    tstrong_index = int(tstrong_index)
    num_inference_steps = int(num_inference_steps)
    num_frames = int(num_frames)
    guidance_scale = float(guidance_scale)
    seed = int(seed)
    max_area = int(max_area)
    if not (0 <= tweak_index <= tstrong_index <= num_inference_steps):
        raise gr.Error("Require 0 ≤ tweak-index ≤ tstrong-index ≤ num_inference_steps.")

    prompt, negative_prompt, image, mask_path, motion_path = _prepare_inputs(
        example_name, prompt, negative_prompt, custom_image, custom_mask, custom_motion
    )

    # Only now pay for loading (or fetch the cached) pipeline.
    pipe = _ensure_pipeline()

    # Fall back to recomputing the modulus if the cached value is unset.
    mod_value = _MOD_VALUE or (pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
    height, width = compute_hw_from_area(image.height, image.width, max_area, mod_value)
    if hasattr(image, "mode") and image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((width, height))

    generator_device = DEVICE if DEVICE.startswith("cuda") else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(seed)

    with torch.inference_mode():
        result = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            motion_signal_video_path=motion_path,
            motion_signal_mask_path=mask_path,
            tweak_index=tweak_index,
            tstrong_index=tstrong_index,
        )

    frames = result.frames[0]
    output_dir = Path(tempfile.mkdtemp(prefix="ttm_space_output_"))
    output_path = output_dir / "ttm.mp4"
    export_to_video(frames, str(output_path), fps=16)

    status = (
        f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}\n"
        f"Resolution: {height}x{width}, frames: {num_frames}, guidance: {guidance_scale}"
    )
    return str(output_path), status
|
|
|
|
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks interface and wire its events.

    The layout offers an example dropdown (packaged examples plus `"custom"`),
    prompt fields, upload widgets for custom runs, generation sliders, and a
    button that calls `generate_video`. The prompt box is pre-filled with the
    default example's prompt and refreshed whenever another packaged example is
    selected, matching the "auto-filled for examples" placeholder text.

    Returns:
        The constructed `gr.Blocks` app (not yet launched).
    """
    example_choices = sorted(EXAMPLE_INDEX.keys())
    default_example = example_choices[0] if example_choices else "custom"
    default_prompt = EXAMPLE_INDEX.get(default_example, {}).get("prompt", "")

    def _prompt_for_example(example_name, current_prompt):
        """Return the packaged prompt for an example; keep the typed text for anything else."""
        meta = EXAMPLE_INDEX.get(example_name)
        if meta is None:
            # "custom" (or an unknown choice): leave whatever the user typed untouched.
            return current_prompt
        return meta["prompt"]

    with gr.Blocks(title="Time-to-Move (Wan 2.2)") as demo:
        gr.Markdown(
            "### Time-to-Move (Wan 2.2)\n"
            "Generate motion-controlled videos by combining a still frame with a cut-and-drag motion signal. "
            "Select one of the bundled examples or upload your own trio of files."
        )

        with gr.Row():
            example_dropdown = gr.Dropdown(
                choices=example_choices + ["custom"],
                value=default_example,
                label="Example preset",
                info="Choose a prepackaged example or switch to 'custom' to upload your own inputs.",
            )
            prompt_box = gr.Textbox(
                label="Prompt",
                lines=4,
                value=default_prompt,
                placeholder="Enter the text prompt (auto-filled for examples).",
            )
            negative_prompt_box = gr.Textbox(
                label="Negative prompt",
                lines=4,
                value=DEFAULT_NEGATIVE_PROMPT,
            )

        with gr.Row():
            image_input = gr.Image(label="first_frame (custom only)", type="pil")
            mask_input = gr.Video(label="mask.mp4 (custom only)")
            motion_input = gr.Video(label="motion_signal.mp4 (custom only)")

        with gr.Row():
            tweak_slider = gr.Slider(0, 20, value=3, step=1, label="tweak-index")
            tstrong_slider = gr.Slider(0, 50, value=7, step=1, label="tstrong-index")
            steps_slider = gr.Slider(10, 50, value=50, step=1, label="num_inference_steps")
            guidance_slider = gr.Slider(1.0, 8.0, value=3.5, step=0.1, label="guidance_scale")

        with gr.Row():
            frames_slider = gr.Slider(21, 81, value=81, step=1, label="num_frames")
            area_slider = gr.Slider(
                256 * 256,
                640 * 1152,
                value=480 * 832,
                step=64,
                label="max pixel area (height*width)",
            )
            seed_box = gr.Number(label="Seed", value=0, precision=0)

        generate_button = gr.Button("Generate video", variant="primary")
        output_video = gr.Video(label="Generated video", autoplay=True, height=512)
        status_box = gr.Markdown()

        # Keep the prompt box in sync with the selected packaged example.
        example_dropdown.change(
            fn=_prompt_for_example,
            inputs=[example_dropdown, prompt_box],
            outputs=prompt_box,
        )

        # Input order must match the positional parameters of `generate_video`.
        generate_button.click(
            fn=generate_video,
            inputs=[
                example_dropdown,
                prompt_box,
                negative_prompt_box,
                tweak_slider,
                tstrong_slider,
                steps_slider,
                guidance_slider,
                frames_slider,
                area_slider,
                seed_box,
                image_input,
                mask_input,
                motion_input,
            ],
            outputs=[output_video, status_box],
        )

        info = "\n".join(f"- **{name}**: `{meta['folder']}`" for name, meta in EXAMPLE_INDEX.items())
        with gr.Accordion("Available packaged examples", open=False):
            gr.Markdown(info or "No example folders detected. Upload custom inputs instead.")

    return demo
|
|
|
|
# Built at import time and exposed as the module-level `app` object.
app = build_ui()


if __name__ == "__main__":
    # Direct execution: enable the request queue (max_size=2) and serve on 0.0.0.0:7860.
    app.queue(max_size=2).launch(server_name="0.0.0.0", server_port=7860)
|
|