import os
import random
import spaces
from datetime import datetime
import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.download_models import prepare_base_model, prepare_image_encoder
from src.utils.util import get_fps, read_frames, save_videos_grid

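# Fetch the pretrained weights needed by the pipeline (base model, image encoder,
# VAE, and the AnimateAnyone checkpoints) into ./pretrained_weights.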
prepare_base_model()
prepare_image_encoder()
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse")
snapshot_download(repo_id="patrolli/AnimateAnyone", local_dir="./pretrained_weights")

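# Wraps the Pose2Video pipeline: the heavy models are built lazily on the first
# animate() call and cached on the controller instance.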
class AnimateController:
    def __init__(self, config_path="./configs/prompts/animation.yaml", weight_dtype=torch.float16):
        self.config = OmegaConf.load(config_path)
        self.pipeline = None
        self.weight_dtype = weight_dtype

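    # Request a GPU from the Spaces ZeroGPU pool for up to 60 s per call.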
    @spaces.GPU(duration=60)
    def animate(self, ref_image, pose_video_path, width=512, height=768, length=24, num_inference_steps=25, cfg=3.5, seed=123):
        generator = torch.manual_seed(seed)
        if isinstance(ref_image, np.ndarray):
            ref_image = Image.fromarray(ref_image)
        if self.pipeline is None:
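            # First call: load the VAE, reference/denoising UNets, pose guider,
            # CLIP image encoder and scheduler, then assemble the pipeline.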
            vae = AutoencoderKL.from_pretrained(self.config.pretrained_vae_path).to("cuda", dtype=self.weight_dtype)
            reference_unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_base_model_path, subfolder="unet").to(dtype=self.weight_dtype, device="cuda")
            infer_config = OmegaConf.load(self.config.inference_config)
            denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                self.config.pretrained_base_model_path,
                self.config.motion_module_path,
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=self.weight_dtype, device="cuda")
            pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=self.weight_dtype, device="cuda")
            image_enc = CLIPVisionModelWithProjection.from_pretrained(self.config.image_encoder_path).to(dtype=self.weight_dtype, device="cuda")
            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)
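            # Load the fine-tuned AnimateAnyone weights on top of the base models.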
            denoising_unet.load_state_dict(torch.load(self.config.denoising_unet_path, map_location="cpu"), strict=False)
            reference_unet.load_state_dict(torch.load(self.config.reference_unet_path, map_location="cpu"))
            pose_guider.load_state_dict(torch.load(self.config.pose_guider_path, map_location="cpu"))
            pipe = Pose2VideoPipeline(
                vae=vae,
                image_encoder=image_enc,
                reference_unet=reference_unet,
                denoising_unet=denoising_unet,
                pose_guider=pose_guider,
                scheduler=scheduler,
            )
            pipe = pipe.to("cuda", dtype=self.weight_dtype)
            self.pipeline = pipe

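        # Read the driving pose frames and cap the clip at the requested length.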
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        total_length = min(length, len(pose_images))
        pose_list = list(pose_images[:total_length])
        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg,
            generator=generator,
        ).videos

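        # Build the preview grid: reference image, pose frames, and the generated
        # video are stacked along the batch dimension and later saved with n_rows=3.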
        new_h, new_w = video.shape[-2:]
        pose_transform = transforms.Compose([transforms.Resize((new_h, new_w)), transforms.ToTensor()])
        pose_tensor_list = [pose_transform(pose_image_pil) for pose_image_pil in pose_images[:total_length]]

        ref_image_tensor = pose_transform(ref_image).unsqueeze(1).unsqueeze(0)
        ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length)
        pose_tensor = torch.stack(pose_tensor_list, dim=0).transpose(0, 1).unsqueeze(0)
        video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)

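        # Write the grid to ./output/gradio as a timestamped MP4 at the source FPS.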
        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y%m%d")
        time_str = datetime.now().strftime("%H%M")
        out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
        save_videos_grid(video, out_path, n_rows=3, fps=src_fps)
        torch.cuda.empty_cache()
        return out_path


controller = AnimateController()


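# Assemble the Gradio Blocks interface.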
def ui():
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="color:#dc5b1c;text-align:center">
                Moore-AnimateAnyone Gradio Demo
            </h1>
            <div style="text-align:center">
                <div style="display: inline-block; text-align: left;">
                    <p>This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo.</p>
                    <p>If you like this project, please consider giving it a star on <a href="https://github.com/MooreThreads/Moore-AnimateAnyone">our GitHub repo</a> 🤗.</p>
                </div>
            </div>
            """
        )
        animation = gr.Video(format="mp4", label="Animation Results", height=448, autoplay=True)
        with gr.Row():
            reference_image = gr.Image(label="Reference Image")
            motion_sequence = gr.Video(format="mp4", label="Motion Sequence", height=512)
            with gr.Column():
                width_slider = gr.Slider(label="Width", minimum=448, maximum=768, value=512, step=64)
                height_slider = gr.Slider(label="Height", minimum=512, maximum=960, value=768, step=64)
                length_slider = gr.Slider(label="Video Length", minimum=24, maximum=128, value=72, step=24)
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    # random.randint needs integer bounds, so use 10**8 instead of the float 1e8.
                    seed_button.click(fn=lambda: gr.update(value=random.randint(1, 10**8)), inputs=[], outputs=[seed_textbox])
                with gr.Row():
                    sampling_steps = gr.Slider(label="Sampling steps", value=15, info="default: 15", step=5, maximum=20, minimum=10)
                    guidance_scale = gr.Slider(label="Guidance scale", value=3.5, info="default: 3.5", step=0.5, maximum=6.5, minimum=2.0)
                submit = gr.Button("Animate")
        motion_sequence.upload(lambda x: x, motion_sequence, motion_sequence, queue=False)
        reference_image.upload(lambda x: Image.fromarray(x), reference_image, reference_image, queue=False)
        submit.click(
            controller.animate,
            [reference_image, motion_sequence, width_slider, height_slider, length_slider, sampling_steps, guidance_scale, seed_textbox],
            animation,
        )
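        # Ready-to-run example inputs shipped with the repo.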
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                ["./configs/inference/ref_images/anyone-5.png", "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 512, 768, 72],
                ["./configs/inference/ref_images/anyone-10.png", "./configs/inference/pose_videos/anyone-video-1_kps.mp4", 512, 768, 72],
                ["./configs/inference/ref_images/anyone-2.png", "./configs/inference/pose_videos/anyone-video-5_kps.mp4", 512, 768, 72],
            ],
            inputs=[reference_image, motion_sequence, width_slider, height_slider, length_slider],
            outputs=animation,
        )
    return demo


demo = ui()
demo.queue(max_size=10)
demo.launch(share=True, show_api=False)