import os
import random
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.download_models import prepare_base_model, prepare_image_encoder
from src.utils.util import get_fps, read_frames, save_videos_grid

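# Fetch the required weights before building the UI: the base diffusion model and
# the CLIP image encoder via the repo's download helpers, plus the fine-tuned VAE
# and the AnimateAnyone checkpoints under ./pretrained_weights.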
prepare_base_model()
prepare_image_encoder()

snapshot_download(
    repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse"
)
snapshot_download(
    repo_id="patrolli/AnimateAnyone",
    local_dir="./pretrained_weights",
)

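# Wraps the Moore-AnimateAnyone inference pipeline. The models are loaded lazily
# on the first call to `animate` and cached on the instance for later requests.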
class AnimateController:
    def __init__(
        self,
        config_path="./configs/prompts/animation.yaml",
        weight_dtype=torch.float16,
    ):
        self.config = OmegaConf.load(config_path)
        self.pipeline = None
        self.weight_dtype = weight_dtype

    def animate(
        self,
        ref_image,
        pose_video_path,
        width=512,
        height=768,
        length=24,
        num_inference_steps=25,
        cfg=3.5,
        seed=123,
    ):
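        """Animate `ref_image` with the pose sequence read from `pose_video_path`
        and return the path of the rendered mp4."""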
        # The seed arrives as a string when the request comes from the Textbox.
        generator = torch.manual_seed(int(seed))
        if isinstance(ref_image, np.ndarray):
            ref_image = Image.fromarray(ref_image)
        if self.pipeline is None:
            vae = AutoencoderKL.from_pretrained(
                self.config.pretrained_vae_path,
            ).to("cuda", dtype=self.weight_dtype)

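            # The reference UNet encodes the appearance of the reference image,
            # while the 3D denoising UNet (2D base weights plus the motion module)
            # generates the temporally consistent video frames.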
            reference_unet = UNet2DConditionModel.from_pretrained(
                self.config.pretrained_base_model_path,
                subfolder="unet",
            ).to(dtype=self.weight_dtype, device="cuda")

            inference_config_path = self.config.inference_config
            infer_config = OmegaConf.load(inference_config_path)
            denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                self.config.pretrained_base_model_path,
                self.config.motion_module_path,
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=self.weight_dtype, device="cuda")

            pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
                dtype=self.weight_dtype, device="cuda"
            )

            image_enc = CLIPVisionModelWithProjection.from_pretrained(
                self.config.image_encoder_path
            ).to(dtype=self.weight_dtype, device="cuda")
            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)

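            # Load the fine-tuned AnimateAnyone checkpoints on top of the base weights.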
            denoising_unet.load_state_dict(
                torch.load(self.config.denoising_unet_path, map_location="cpu"),
                strict=False,
            )
            reference_unet.load_state_dict(
                torch.load(self.config.reference_unet_path, map_location="cpu"),
            )
            pose_guider.load_state_dict(
                torch.load(self.config.pose_guider_path, map_location="cpu"),
            )

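            # Assemble the pose-to-video pipeline and cache it for later calls.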
            pipe = Pose2VideoPipeline(
                vae=vae,
                image_encoder=image_enc,
                reference_unet=reference_unet,
                denoising_unet=denoising_unet,
                pose_guider=pose_guider,
                scheduler=scheduler,
            )
            pipe = pipe.to("cuda", dtype=self.weight_dtype)
            self.pipeline = pipe

        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)

        pose_list = []
        total_length = min(length, len(pose_images))
        for pose_image_pil in pose_images[:total_length]:
            pose_list.append(pose_image_pil)

        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg,
            generator=generator,
        ).videos

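        # Stack reference image, pose sequence, and generated video along the batch
        # dimension so save_videos_grid writes them side by side for comparison.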
        new_h, new_w = video.shape[-2:]
        pose_transform = transforms.Compose(
            [transforms.Resize((new_h, new_w)), transforms.ToTensor()]
        )
        pose_tensor_list = []
        for pose_image_pil in pose_images[:total_length]:
            pose_tensor_list.append(pose_transform(pose_image_pil))

        ref_image_tensor = pose_transform(ref_image)
        ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)
        ref_image_tensor = repeat(
            ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length
        )
        pose_tensor = torch.stack(pose_tensor_list, dim=0)
        pose_tensor = pose_tensor.transpose(0, 1)
        pose_tensor = pose_tensor.unsqueeze(0)
        video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)

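        # Save the grid as an mp4 named by timestamp, at the pose video's frame rate.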
        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y%m%d")
        time_str = datetime.now().strftime("%H%M")
        out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
        save_videos_grid(
            video,
            out_path,
            n_rows=3,
            fps=src_fps,
        )

        torch.cuda.empty_cache()

        return out_path

controller = AnimateController()

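# Build the Gradio Blocks interface. Reference images are offered through a
# gallery of character portraits pulled from the svjack/Genshin-Impact-Item-Image
# dataset; the motion is supplied as a pose keypoint video.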
def ui():
    from datasets import load_dataset
    import io
    from PIL import Image

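    # Load the Genshin Impact image dataset and keep only entries whose tags
    # contain 肖像 ("portrait") and 角色 ("character") to use as reference images.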
    image_ds = load_dataset("svjack/Genshin-Impact-Item-Image")
    image_df = image_ds["train"].to_pandas()
    image_df = image_df[
        image_df["tag"].map(lambda x: "肖像" in x and "角色" in x)
    ]

    def bytes_to_pil_image(byte_data):
        """
        Convert a byte array to a PIL Image.

        :param byte_data: A byte array containing image data.
        :return: A PIL Image object.
        """
        image_stream = io.BytesIO(byte_data)
        pil_image = Image.open(image_stream)
        return pil_image

    image_df["image"] = image_df["image"].map(lambda x: bytes_to_pil_image(x["bytes"]))

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="color:#dc5b1c;text-align:center">
                Moore-AnimateAnyone Gradio Demo
            </h1>
            <div style="text-align:center">
                <div style="display: inline-block; text-align: left;">
                    <p> This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo. </p>
                    <p> If you like this project, please consider giving a star on <a href="https://github.com/MooreThreads/Moore-AnimateAnyone"> our GitHub repo </a> 🤗. </p>
                </div>
            </div>
            """
        )

        with gr.Row():
            gallery = gr.Gallery(
                image_df["image"].tolist(),
                label="Select Reference Image",
                show_label=True,
                elem_id="gallery",
                columns=[2, 3, 4, 5, 6, 6],
                rows=[2, 2, 2, 2, 2, 2],
                height="400px",
                object_fit="contain",
            )

        with gr.Row():
            reference_image = gr.Image(label="Reference Image")
            motion_sequence = gr.Video(
                format="mp4", label="Motion Sequence", height=512
            )

            with gr.Column():
                width_slider = gr.Slider(
                    label="Width", minimum=448, maximum=768, value=512, step=64
                )
                height_slider = gr.Slider(
                    label="Height", minimum=512, maximum=960, value=768, step=64
                )
                length_slider = gr.Slider(
                    label="Video Length", minimum=24, maximum=128, value=72, step=24
                )
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(
                        value="\U0001F3B2", elem_classes="toolbutton"
                    )
                    # randint needs integer bounds; returning the value directly
                    # keeps the callback independent of the Gradio update API.
                    seed_button.click(
                        fn=lambda: random.randint(1, int(1e8)),
                        inputs=[],
                        outputs=[seed_textbox],
                    )
                with gr.Row():
                    sampling_steps = gr.Slider(
                        label="Sampling steps",
                        value=15,
                        info="default: 15",
                        step=5,
                        maximum=20,
                        minimum=10,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        value=3.5,
                        info="default: 3.5",
                        step=0.5,
                        maximum=6.5,
                        minimum=2.0,
                    )
                submit = gr.Button("Animate")

        with gr.Row():
            animation = gr.Video(
                format="mp4",
                label="Animation Results",
                height=448,
                autoplay=True,
            )

        def read_video(video):
            return video

        def read_image(image):
            return Image.fromarray(image)

        def select_image(selection: gr.SelectData):
            print(selection.value["image"])
            return selection.value["image"]["path"]

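        # Wire the events: uploads echo back into their components, the gallery
        # selection fills the reference image, and Animate runs the controller.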
        motion_sequence.upload(
            read_video, motion_sequence, motion_sequence, queue=False
        )
        reference_image.upload(
            read_image, reference_image, reference_image, queue=False
        )
        submit.click(
            controller.animate,
            [
                reference_image,
                motion_sequence,
                width_slider,
                height_slider,
                length_slider,
                sampling_steps,
                guidance_scale,
                seed_textbox,
            ],
            animation,
        )

        gallery.select(fn=select_image, inputs=None, outputs=[reference_image])

        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                [
                    "./configs/inference/ref_images/anyone-5.png",
                    "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
                    512,
                    768,
                    72,
                ],
                [
                    "./configs/inference/ref_images/anyone-10.png",
                    "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
                    512,
                    768,
                    72,
                ],
                [
                    "./configs/inference/ref_images/anyone-2.png",
                    "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
                    512,
                    768,
                    72,
                ],
            ],
            inputs=[
                reference_image,
                motion_sequence,
                width_slider,
                height_slider,
                length_slider,
            ],
            outputs=animation,
        )

    return demo

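# Enable request queuing so concurrent users wait in line, then launch with a
# public share link and the API docs hidden.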
demo = ui()
demo.queue(max_size=10)
demo.launch(share=True, show_api=False)