# Hugging Face Space: Wan2.1 VACE 1.3B — stateless image-to-video demo.
import os

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download
from diffusers.pipelines.wan import WanVACEPipeline
from diffusers.utils import load_image, export_to_video

# Model repo and on-disk locations. The Space has no persistent storage,
# so the snapshot must be re-downloaded after every restart.
REPO_ID = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
LOCAL_DIR = "/root/.cache/wan21"
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# Lazily-initialized global pipeline; populated by init_pipe().
pipe = None
def download_model():
    """Download the Wan2.1 VACE model snapshot into LOCAL_DIR.

    Returns:
        dict: status payload for the UI with keys:
            - "downloaded": True if model_index.json exists after the download,
            - "local_dir": the target directory,
            - "contains": up to 30 entries of the directory listing (sorted),
            - "gpu": CUDA device name, or None when no GPU is available.
    """
    os.makedirs(LOCAL_DIR, exist_ok=True)
    # NOTE: `local_dir_use_symlinks` is deprecated in recent huggingface_hub
    # releases and is ignored; passing `local_dir` alone already materializes
    # real files in that directory.
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR)
    # model_index.json is the diffusers pipeline manifest — its presence is a
    # cheap proxy for "the snapshot looks complete".
    ok = os.path.exists(os.path.join(LOCAL_DIR, "model_index.json"))
    return {
        "downloaded": ok,
        "local_dir": LOCAL_DIR,
        "contains": sorted(os.listdir(LOCAL_DIR))[:30],
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }
def init_pipe():
    """Initialize the global WanVACEPipeline from the downloaded snapshot.

    Idempotent: returns early if the pipeline already exists.

    Returns:
        str: a human-readable status message for the UI.

    Raises:
        RuntimeError: if the model snapshot has not been downloaded yet.
    """
    global pipe
    if pipe is not None:
        return "Pipeline already initialized."
    if not os.path.exists(os.path.join(LOCAL_DIR, "model_index.json")):
        raise RuntimeError("Model not downloaded yet. Click '1) Download Model' first.")
    use_cuda = torch.cuda.is_available()
    # fp16 on GPU keeps the 1.3B model within small-VRAM budgets; fp32 on CPU.
    dtype = torch.float16 if use_cuda else torch.float32
    device = "cuda" if use_cuda else "cpu"
    pipe = WanVACEPipeline.from_pretrained(LOCAL_DIR, torch_dtype=dtype).to(device)
    # Best-effort memory savers: not every diffusers version / pipeline exposes
    # all of these, so each one is attempted independently and failures are
    # ignored (same behavior as the original three try/except blocks).
    for optimization in (
        "enable_attention_slicing",
        "enable_vae_slicing",
        "enable_vae_tiling",
    ):
        try:
            getattr(pipe, optimization)()
        except Exception:
            pass
    return f"Initialized WanVACEPipeline on {device} ({dtype})."
def _build_conditioning_inputs(init_image, width, height, num_frames):
    """Build the VACE conditioning video and mask sequences.

    Frame 0 carries the conditioning image; the remaining frames are blank
    (white) placeholders to be generated. Mask convention: black = keep /
    condition on this frame, white = generate this frame.

    Returns:
        tuple[list, list]: (video frames, mask frames), each of length
        ``num_frames``.
    """
    blank = Image.new("RGB", (width, height), (255, 255, 255))
    video_in = [init_image] + [blank] * (num_frames - 1)
    black_mask = Image.new("RGB", (width, height), (0, 0, 0))
    white_mask = Image.new("RGB", (width, height), (255, 255, 255))
    mask_in = [black_mask] + [white_mask] * (num_frames - 1)
    return video_in, mask_in


def generate_demo():
    """Run a fixed image-to-video demo generation.

    Initializes the pipeline if needed, animates a sample image, and writes
    the result to OUT_DIR/test.mp4 (overwritten on every call).

    Returns:
        tuple[str, str]: (init status message, path to the generated mp4).
    """
    msg = init_pipe()
    image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
    init_image = load_image(image_url)
    prompt = "A realistic video. Subtle natural motion, gentle camera movement, stable subject, cinematic lighting."
    out_path = os.path.join(OUT_DIR, "test.mp4")
    # L4-safe settings: dimensions divisible by 16, and (num_frames - 1)
    # divisible by 4, as required by the Wan temporal VAE.
    height, width = 320, 576
    num_frames = 13
    # Resize the conditioning image exactly to the generation size.
    init_image = init_image.resize((width, height))
    video_in, mask_in = _build_conditioning_inputs(init_image, width, height, num_frames)
    result = pipe(
        prompt=prompt,
        video=video_in,
        mask=mask_in,
        reference_images=[init_image],
        conditioning_scale=2.0,
        num_frames=num_frames,
        height=height,
        width=width,
        guidance_scale=5.0,
        num_inference_steps=20,
        output_type="pil",
    )
    # Diffusers may return an output object exposing `.frames` or a plain dict,
    # depending on version; handle both.
    frames = result.frames[0] if hasattr(result, "frames") else result["frames"][0]
    export_to_video(frames, out_path, fps=8)
    return msg, out_path
# Two-step UI: the model must be re-downloaded after every restart (the Space
# is stateless), then a fixed demo clip can be generated.
with gr.Blocks(title="Wan2.1 VACE 1.3B — Stateless Server") as demo:
    gr.Markdown(
        "## Wan2.1 VACE 1.3B (Stateless)\nNo persistent storage: download the model after each restart."
    )

    btn_dl = gr.Button("1) Download Model (one-time per restart)")
    dl_out = gr.JSON(label="Download status")

    btn_gen = gr.Button("2) Generate Test Video")
    gen_status = gr.Textbox(label="Init status")
    gen_vid = gr.Video(label="Generated video")

    btn_dl.click(download_model, inputs=[], outputs=[dl_out])
    btn_gen.click(generate_demo, inputs=[], outputs=[gen_status, gen_vid])

if __name__ == "__main__":
    # Bind to all interfaces on the standard Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)