Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import argparse | |
| import os | |
| import sys | |
| import datetime | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
# Download the Wan2.1 VACE 1.3B checkpoint into ./models at import time so the
# default --ckpt_dir used below exists before the pipeline is constructed.
snapshot_download(
    repo_id="Wan-AI/Wan2.1-VACE-1.3B",
    local_dir="./models/Wan2.1-VACE-1.3B",
)

# True only when running as the public "fffiloni" Space. SPACE_ID is expected
# to be provided by the hosting environment; when it is missing (e.g. a local
# run) fall back to '' instead of raising KeyError.
is_shared_ui = "fffiloni/Wan2.1-VACE-1.3B" in os.environ.get('SPACE_ID', '')

# Put the directory that contains this file's parent directory on sys.path
# (the path with its last two components removed) so `import wan` resolves.
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
| import wan | |
| from wan import WanVace, WanVaceMP | |
| from wan.configs import WAN_CONFIGS, SIZE_CONFIGS | |
class FixedSizeQueue:
    """Newest-first list that keeps at most ``max_size`` entries.

    New items go to the front; when an insertion pushes the length past
    ``max_size`` the entry at the back is dropped.
    """

    def __init__(self, max_size):
        # max_size: upper bound on the number of stored entries.
        self.queue = []
        self.max_size = max_size

    def add(self, item):
        """Insert *item* at the front and evict the last entry if over capacity."""
        entries = self.queue
        entries[:0] = [item]
        if len(entries) > self.max_size:
            del entries[-1]

    def get(self):
        """Return the underlying list itself (not a copy), newest entry first."""
        return self.queue

    def __repr__(self):
        return f"{self.queue}"
class VACEInference:
    """Gradio front-end around a Wan VACE pipeline.

    The object owns the pipeline (``self.pipe``), builds the UI components
    inside the currently active ``gr.Blocks`` context, and wires the buttons
    to :meth:`generate`.

    Args:
        cfg: Namespace-like object. Reads ``save_dir``, ``model_name``,
            ``ckpt_dir`` and, optionally, ``mp``, ``ulysses_size`` and
            ``ring_size``.
        skip_load: When true, no pipeline is built and ``self.pipe`` is
            ``None``; the UI can still be created but generation is refused.
        gallery_share: When true, every generated video is added to a gallery
            history held on this instance and that history is returned.
        gallery_share_limit: Maximum number of entries kept in that history.
    """

    def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        # Always define the attribute so create_ui()/generate() can detect a
        # missing pipeline instead of failing with AttributeError.
        self.pipe = None
        if skip_load:
            return
        # Read the multi-GPU switch from cfg rather than the module-level
        # `args`, which only exists when this file is run as a script.
        if not getattr(cfg, 'mp', False):
            self.pipe = WanVace(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
        else:
            self.pipe = WanVaceMP(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                use_usp=True,
                ulysses_size=cfg.ulysses_size,
                ring_size=cfg.ring_size
            )

    @staticmethod
    def _build_ref_image(index):
        """Create the upload widget for reference image number *index*."""
        return gr.Image(label=f'src_ref_image_{index}',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id=f"src_ref_image_{index}",
                        format='png')

    def create_ui(self, *args, **kwargs):
        """Instantiate all Gradio components and store them on ``self``.

        Must be called inside an active ``gr.Blocks`` context.

        NOTE(review): the nesting of rows/columns below was reconstructed
        from the order of the ``with`` statements; confirm against the
        rendered layout.
        """
        gr.Markdown("# VACE-WAN 1.3B Demo")
        gr.Markdown("All-in-One Video Creation and Editing")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://ali-vilab.github.io/VACE-Page/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/Wan2.1-VACE-1.3B?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)
        # Source video and mask uploads, side by side.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        # Up to three optional reference images.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = self._build_ref_image(1)
                    self.src_ref_image_2 = self._build_ref_image(2)
                    self.src_ref_image_3 = self._build_ref_image(3)
        # Positive / negative prompts.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                # Without a loaded pipeline there is no default negative
                # prompt to show, so start with an empty box.
                default_negative = '' if self.pipe is None else self.pipe.config.sample_neg_prompt
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=default_negative,
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        # Sampling parameters.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    # The step count is locked on the shared public Space.
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=not is_shared_ui)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1,
                                                maximum=10000000,
                                                value=2025,
                                                label="Seed")
        # Output geometry / timing, entered as text and parsed in generate().
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    value=480,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    value=832,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate',
                    value=16,
                    interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames',
                    value=81,
                    interactive=True)
        # Action buttons.
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                # U+1F504: the "refresh" arrows emoji.
                self.refresh_button = gr.Button(value='\U0001f504')
        # Result gallery.
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames, progress=gr.Progress(track_tqdm=True)):
        """Run the pipeline once and save the result as an MP4 in ``save_dir``.

        Args:
            output_gallery: Current gallery value; accepted for the callback
                signature but not read.
            src_video, src_mask: Uploaded source video / mask (may be None).
            src_ref_image_1..3: Uploaded reference images; None entries are
                dropped.
            prompt, negative_prompt: Text prompts forwarded to the pipeline.
            shift_scale, sample_steps, context_scale, guide_scale, infer_seed:
                Sampling parameters forwarded to the pipeline.
            output_height, output_width, frame_rate, num_frames: Textbox
                values; converted to ``int`` here.
            progress: Gradio progress tracker.

        Returns:
            The shared gallery history (newest first) when ``gallery_share``
            is on, otherwise a one-element list with the new video path.

        Raises:
            gr.Error: if no pipeline is loaded, a numeric field is not an
                integer, the resolution is not in ``SIZE_CONFIGS``, or the
                video cannot be written.
        """
        if self.pipe is None:
            raise gr.Error("Model pipeline is not loaded (skip_load=True).")

        # Textbox inputs arrive as free text; turn bad input into a UI error.
        try:
            output_height = int(output_height)
            output_width = int(output_width)
            frame_rate = int(frame_rate)
            num_frames = int(num_frames)
        except (TypeError, ValueError) as e:
            raise gr.Error(
                f"resolutions_height, resolutions_width, frame_rate and num_frames must be integers: {e}") from e

        size_key = f"{output_width}*{output_height}"
        try:
            image_size = SIZE_CONFIGS[size_key]
        except KeyError as e:
            raise gr.Error(f"Unsupported resolution: {size_key}") from e

        src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
                          x is not None]
        # prepare_source takes batched inputs, hence the one-element lists.
        src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
                                                                       [src_mask],
                                                                       [src_ref_images],
                                                                       num_frames=num_frames,
                                                                       image_size=image_size,
                                                                       device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)

        # Portable, zero-padded timestamp (the glibc-only `%-H` flag is avoided).
        name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        # Map values to [0, 255] uint8 and move the first axis last.
        # Presumably (C, T, H, W) in [-1, 1] -> (T, H, W, C); TODO confirm.
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
        try:
            writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
            try:
                for frame in video_frames:
                    writer.append_data(frame)
            finally:
                # Close the writer even when a frame fails to encode.
                writer.close()
        except Exception as e:
            raise gr.Error(f"Video save error: {e}") from e
        print(video_path)

        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        return [video_path]

    def set_callbacks(self, **kwargs):
        """Connect the Run and refresh buttons; call after :meth:`create_ui`."""
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate,
                                   inputs=self.gen_inputs,
                                   outputs=self.gen_outputs,
                                   queue=True)

        def _refresh(current):
            # Show the shared history when enabled, else keep the gallery as is.
            return self.gallery_share_data.get() if self.gallery_share else current

        self.refresh_button.click(_refresh, inputs=[self.output_gallery], outputs=[self.output_gallery])
if __name__ == '__main__':
    # ---- command-line options -------------------------------------------
    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
    # Server / filesystem options.
    parser.add_argument('--server_port', type=int, default=7860, help='')
    parser.add_argument('--server_name', default='0.0.0.0', help='')
    parser.add_argument('--root_path', default=None, help='')
    parser.add_argument('--save_dir', default='cache', help='')
    # Model / parallelism options.
    parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs")
    parser.add_argument(
        "--model_name",
        type=str,
        default="vace-1.3B",
        choices=list(WAN_CONFIGS.keys()),
        help="The model name to run.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.",
    )
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default='models/Wan2.1-VACE-1.3B/',
        help="The path to the checkpoint directory.",
    )
    # NOTE(review): this flag is parsed but nothing in the visible code reads it.
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )
    args = parser.parse_args()

    # Make sure the output directory for generated videos exists.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    # ---- build and serve the UI -----------------------------------------
    # NOTE(review): source indentation was lost; launch() is kept inside the
    # Blocks context here - confirm against the deployed app.
    with gr.Blocks() as demo:
        infer_gr = VACEInference(
            args,
            skip_load=False,
            gallery_share=True,
            gallery_share_limit=5,
        )
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        # Gradio may only serve files from the directories listed here.
        allowed_paths = [args.save_dir]
        queued_demo = demo.queue(status_update_rate=1)
        queued_demo.launch(
            server_name=args.server_name,
            server_port=args.server_port,
            root_path=args.root_path,
            allowed_paths=allowed_paths,
            show_error=True,
            debug=True,
        )