# NOTE: the original paste contained extraction artifacts here
# ("Spaces:" / "Build error" lines); the module content begins below.
| import argparse | |
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from diffusers import AutoencoderKL, DDIMScheduler | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from transformers import CLIPVisionModelWithProjection | |
| from src.models.pose_guider import PoseGuider | |
| from src.models.unet_2d_condition import UNet2DConditionModel | |
| from src.models.unet_3d_edit_bkfill import UNet3DConditionModel | |
| from src.pipelines.pipeline_pose2vid_long_edit_bkfill_roiclip import Pose2VideoPipeline | |
| from src.utils.util import get_fps, read_frames | |
| import cv2 | |
| from tools.human_segmenter import human_segmenter | |
| import imageio | |
| from tools.util import all_file, load_mask_list, crop_img, pad_img, crop_human_clip_auto_context, get_mask, \ | |
| refine_img_prepross, recover_bk | |
| from tools.util import load_video_fixed_fps | |
| import json | |
# Path to the frozen human-matting model used for foreground extraction.
seg_path = './assets/matting_human.pb'
# Module-level singleton: the matting model is loaded once at import time
# so every call to process_seg() reuses the same loaded graph.
segmenter = human_segmenter(model_path=seg_path)
def init_bk(n_frame, tw, th):
    """Build a list of plain white background frames.

    Returns ``n_frame`` independent PIL ``RGB`` images, each of size
    ``(tw, th)``, filled with solid white. Used when a template ships
    no background video.
    """
    white = (255, 255, 255)
    return [Image.new('RGB', (tw, th), white) for _ in range(n_frame)]
def process_seg(img):
    """Matte the person in ``img`` and composite onto a white background.

    Runs the module-level ``segmenter`` (which returns an RGBA-style
    array) and alpha-blends the color channels over solid white.
    Channel order follows whatever the segmenter expects — callers in
    this file flip channels before/after, presumably BGR (verify).

    Returns:
        (color, mask): the uint8 composited image and the raw uint8
        alpha channel produced by the segmenter.
    """
    rgba = segmenter.run(img)
    mask = rgba[:, :, 3]
    fg = rgba[:, :, :3]
    # Broadcast alpha to (H, W, 1) once instead of per-term.
    alpha = (mask / 255)[:, :, np.newaxis]
    white = np.full_like(fg, 255)
    blended = fg * alpha + white * (1 - alpha)
    return blended.astype(np.uint8), mask
def parse_args():
    """Define and parse the command-line options for MIMO inference.

    Returns the parsed ``argparse.Namespace`` with generation size
    (``W``/``H``), clip length, sampling settings and asset locations.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--config", type=str, default='./configs/prompts/animation_edit.yaml')
    add("-W", type=int, default=784)
    add("-H", type=int, default=784)
    add("-L", type=int, default=64)
    add("--seed", type=int, default=42)
    add("--cfg", type=float, default=3.5)
    add("--steps", type=int, default=25)
    add("--fps", type=int)
    add("--assets_dir", type=str, default='./assets')
    add("--ref_pad", type=int, default=1)
    add("--use_bk", type=int, default=1)
    add("--clip_length", type=int, default=32)
    add("--MAX_FRAME_NUM", type=int, default=150)
    return parser.parse_args()
class MIMO():
    """Character-animation inference pipeline.

    Loads the diffusion components (VAE, reference 2D UNet, denoising
    3D UNet, pose guider, CLIP image encoder) declared in the YAML
    config, then :meth:`run` animates a reference image so it follows
    the motion of a preprocessed video template, compositing the result
    back over the template's background (and occluders, if any).
    """

    def __init__(self, debug_mode=False):
        # NOTE(review): debug_mode is currently unused — confirm before removing.
        args = parse_args()
        config = OmegaConf.load(args.config)
        # Auto-detect device (CPU/CUDA)
        if torch.cuda.is_available():
            self.device = "cuda"
            print("🚀 Using CUDA GPU for inference")
        else:
            self.device = "cpu"
            print("⚠️ CUDA not available, running on CPU (will be slow)")
        # fp16 is only used on CUDA; CPU inference falls back to fp32.
        if config.weight_dtype == "fp16" and self.device == "cuda":
            weight_dtype = torch.float16
        else:
            weight_dtype = torch.float32
        vae = AutoencoderKL.from_pretrained(
            config.pretrained_vae_path,
        ).to(self.device, dtype=weight_dtype)
        reference_unet = UNet2DConditionModel.from_pretrained(
            config.pretrained_base_model_path,
            subfolder="unet",
        ).to(dtype=weight_dtype, device=self.device)
        inference_config_path = config.inference_config
        infer_config = OmegaConf.load(inference_config_path)
        # 3D UNet is built from 2D weights plus a separate motion module.
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_base_model_path,
            config.motion_module_path,
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=weight_dtype, device=self.device)
        pose_guider = PoseGuider(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(
            dtype=weight_dtype, device=self.device
        )
        image_enc = CLIPVisionModelWithProjection.from_pretrained(
            config.image_encoder_path
        ).to(dtype=weight_dtype, device=self.device)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)
        # Fixed seed so repeated runs produce the same video.
        self.generator = torch.manual_seed(args.seed)
        self.width, self.height = args.W, args.H
        # load pretrained weights
        # strict=False — presumably the checkpoint does not cover every
        # submodule of the assembled 3D UNet; confirm against training code.
        denoising_unet.load_state_dict(
            torch.load(config.denoising_unet_path, map_location="cpu"),
            strict=False,
        )
        reference_unet.load_state_dict(
            torch.load(config.reference_unet_path, map_location="cpu"),
        )
        pose_guider.load_state_dict(
            torch.load(config.pose_guider_path, map_location="cpu"),
        )
        self.pipe = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        )
        self.pipe = self.pipe.to(self.device, dtype=weight_dtype)
        self.args = args
        # load mask
        # Feathering mask(s) used when pasting generated crops back into
        # the full frame (see get_mask() in run()).
        mask_path = os.path.join(self.args.assets_dir, 'masks', 'alpha2.png')
        self.mask_list = load_mask_list(mask_path)

    def load_template(self, template_path):
        """Collect the file paths and metadata of one video template.

        A template directory contains vid.mp4 (source video), sdc.mp4
        (pose/skeleton video), bk.mp4 (background video), an optional
        occ.mp4 (occlusion masks) and a config.json with fps and crop
        settings.

        Returns:
            dict with keys: video_path, pose_video_path, bk_video_path,
            occ_video_path (None if absent), target_fps, time_crop,
            frame_crop, layer_recover.
        """
        video_path = os.path.join(template_path, 'vid.mp4')
        pose_video_path = os.path.join(template_path, 'sdc.mp4')
        bk_video_path = os.path.join(template_path, 'bk.mp4')
        occ_video_path = os.path.join(template_path, 'occ.mp4')
        # Occlusion video is optional.
        if not os.path.exists(occ_video_path):
            occ_video_path = None
        config_file = os.path.join(template_path, 'config.json')
        with open(config_file) as f:
            template_data = json.load(f)
        template_info = {}
        template_info['video_path'] = video_path
        template_info['pose_video_path'] = pose_video_path
        template_info['bk_video_path'] = bk_video_path
        template_info['occ_video_path'] = occ_video_path
        template_info['target_fps'] = template_data['fps']
        template_info['time_crop'] = template_data['time_crop']
        template_info['frame_crop'] = template_data['frame_crop']
        template_info['layer_recover'] = template_data['layer_recover']
        return template_info

    def run(self, ref_img_path, template_path):
        """Animate the reference image with the given template's motion.

        Pipeline: segment + crop + pad the reference image; load the
        template's pose/background/occlusion videos; temporally crop and
        cap the frame count; crop per-clip human regions with overlapping
        context windows; run the diffusion pipeline; then paste each
        generated crop back over the original background, re-apply
        occluders, and blend overlapping windows.

        Returns:
            (res_images, target_fps): list of uint8 HxWx3 frames and the
            fps they should be encoded at.
        """
        template_name = os.path.basename(template_path)  # NOTE(review): unused
        template_info = self.load_template(template_path)
        target_fps = template_info['target_fps']
        video_path = template_info['video_path']
        pose_video_path = template_info['pose_video_path']
        bk_video_path = template_info['bk_video_path']
        occ_video_path = template_info['occ_video_path']
        # --- reference image: segment person, crop to it, pad to square-ish.
        ref_image_pil = Image.open(ref_img_path).convert('RGB')
        source_image = np.array(ref_image_pil)
        # Channel flip in and out — segmenter presumably expects BGR; verify.
        source_image, mask = process_seg(source_image[..., ::-1])
        source_image = source_image[..., ::-1]
        source_image = crop_img(source_image, mask)
        source_image, _ = pad_img(source_image, [255, 255, 255])
        ref_image_pil = Image.fromarray(source_image)
        # load tgt
        vid_images = load_video_fixed_fps(video_path, target_fps=target_fps)
        if bk_video_path is None:
            # No background video: use white frames the size of the source video.
            n_frame = len(vid_images)
            tw, th = vid_images[0].size
            bk_images = init_bk(n_frame, tw, th)
        else:
            bk_images = load_video_fixed_fps(bk_video_path, target_fps=target_fps)
        if occ_video_path is not None:
            occ_mask_images = load_video_fixed_fps(occ_video_path, target_fps=target_fps)
            print('load occ from %s' % occ_video_path)
        else:
            occ_mask_images = None
            print('no occ masks')
        pose_images = load_video_fixed_fps(pose_video_path, target_fps=target_fps)
        src_fps = get_fps(pose_video_path)  # NOTE(review): unused
        # Temporal crop: the /30 scaling suggests config indices assume
        # 30 fps — TODO confirm against template preprocessing.
        start_idx, end_idx = template_info['time_crop']['start_idx'], template_info['time_crop']['end_idx']
        start_idx = int(target_fps * start_idx / 30)
        end_idx = int(target_fps * end_idx / 30)
        start_idx = max(0, start_idx)
        end_idx = min(len(pose_images), end_idx)
        pose_images = pose_images[start_idx:end_idx]
        vid_images = vid_images[start_idx:end_idx]
        bk_images = bk_images[start_idx:end_idx]
        if occ_mask_images is not None:
            occ_mask_images = occ_mask_images[start_idx:end_idx]
        self.args.L = len(pose_images)
        # Cap total frames to bound memory/runtime.
        max_n_frames = self.args.MAX_FRAME_NUM
        if self.args.L > max_n_frames:
            pose_images = pose_images[:max_n_frames]
            vid_images = vid_images[:max_n_frames]
            bk_images = bk_images[:max_n_frames]
            if occ_mask_images is not None:
                occ_mask_images = occ_mask_images[:max_n_frames]
        self.args.L = len(pose_images)
        # Keep the uncropped originals for the paste-back stage.
        bk_images_ori = bk_images.copy()
        vid_images_ori = vid_images.copy()
        # Number of frames shared between consecutive context windows;
        # also drives the cross-fade factor in the blending loop below.
        overlay = 4
        pose_images, vid_images, bk_images, bbox_clip, context_list, bbox_clip_list = crop_human_clip_auto_context(
            pose_images, vid_images, bk_images, overlay)
        # Per-frame padding bookkeeping so generated crops can be un-padded.
        clip_pad_list_context = []
        clip_padv_list_context = []
        pose_list_context = []
        vid_bk_list_context = []
        for frame_idx in range(len(pose_images)):
            pose_image_pil = pose_images[frame_idx]
            pose_image = np.array(pose_image_pil)
            pose_image, _ = pad_img(pose_image, color=[0, 0, 0])
            pose_image_pil = Image.fromarray(pose_image)
            pose_list_context.append(pose_image_pil)
            vid_bk = bk_images[frame_idx]
            vid_bk = np.array(vid_bk)
            vid_bk, padding_v = pad_img(vid_bk, color=[255, 255, 255])
            pad_h, pad_w, _ = vid_bk.shape
            clip_pad_list_context.append([pad_h, pad_w])
            clip_padv_list_context.append(padding_v)
            vid_bk_list_context.append(Image.fromarray(vid_bk))
        print('start to infer...')
        video = self.pipe(
            ref_image_pil,
            pose_list_context,
            vid_bk_list_context,
            self.width,
            self.height,
            len(pose_list_context),
            self.args.steps,
            self.args.cfg,
            generator=self.generator,
        ).videos[0]
        # post-process video
        video_idx = 0
        res_images = [None for _ in range(self.args.L)]
        for k, context in enumerate(context_list):
            start_i = context[0]
            bbox = bbox_clip_list[k]
            for i in context:
                bk_image_pil_ori = bk_images_ori[i]
                vid_image_pil_ori = vid_images_ori[i]
                if occ_mask_images is not None:
                    occ_mask = occ_mask_images[i]
                else:
                    occ_mask = None
                canvas = Image.new("RGB", bk_image_pil_ori.size, "white")
                pad_h, pad_w = clip_pad_list_context[video_idx]
                padding_v = clip_padv_list_context[video_idx]
                # Pipeline output is (C, T, H, W) floats in [0, 1]; take frame
                # video_idx and convert to a uint8 image.
                image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy()
                res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
                # Undo the resize and padding applied before inference.
                res_image_pil = res_image_pil.resize((pad_w, pad_h))
                top, bottom, left, right = padding_v
                res_image_pil = res_image_pil.crop((left, top, pad_w - right, pad_h - bottom))
                w_min, w_max, h_min, h_max = bbox
                canvas.paste(res_image_pil, (w_min, h_min))
                # Feathered alpha mask confines the generated crop to its bbox.
                mask_full = np.zeros((bk_image_pil_ori.size[1], bk_image_pil_ori.size[0]), dtype=np.float32)
                mask = get_mask(self.mask_list, bbox, bk_image_pil_ori)
                mask = cv2.resize(mask, res_image_pil.size, interpolation=cv2.INTER_AREA)
                mask_full[h_min:h_min + mask.shape[0], w_min:w_min + mask.shape[1]] = mask
                res_image = np.array(canvas)
                bk_image = np.array(bk_image_pil_ori)
                # Blend the generated crop over the original background.
                res_image = res_image * mask_full[:, :, np.newaxis] + bk_image * (1 - mask_full[:, :, np.newaxis])
                if occ_mask is not None:
                    # Restore occluding foreground objects from the source video.
                    vid_image = np.array(vid_image_pil_ori)
                    occ_mask = np.array(occ_mask)[:, :, 0].astype(np.uint8)  # [0,255]
                    occ_mask = occ_mask / 255.0
                    res_image = res_image * (1 - occ_mask[:, :, np.newaxis]) + vid_image * occ_mask[:, :, np.newaxis]
                if res_images[i] is None:
                    res_images[i] = res_image
                else:
                    # Frame i was also produced by the previous window:
                    # cross-fade linearly across the `overlay` shared frames.
                    factor = (i - start_i + 1) / (overlay + 1)
                    res_images[i] = res_images[i] * (1 - factor) + res_image * factor
                res_images[i] = res_images[i].astype(np.uint8)
                video_idx = video_idx + 1
        return res_images, target_fps
def main():
    """Demo entry point: animate one reference image with one template.

    Builds the MIMO pipeline, runs inference on the hard-coded sample
    inputs, and writes the resulting video to ``output/<template>_<ref>.mp4``.
    """
    model = MIMO()
    ref_img_path = './assets/test_image/sugar.jpg'
    template_path = './assets/video_template/sports_basketball_gym'
    save_dir = 'output'
    # exist_ok avoids the check-then-create race of the old
    # `if not exists: makedirs` pattern.
    os.makedirs(save_dir, exist_ok=True)
    print('refer_img: %s' % ref_img_path)
    print('template_vid: %s' % template_path)
    # splitext strips only the extension, so stems containing dots
    # (e.g. "v1.2.jpg" -> "v1.2") are preserved, unlike split('.')[0].
    ref_name = os.path.splitext(os.path.basename(ref_img_path))[0]
    template_name = os.path.basename(template_path)
    outpath = f"{save_dir}/{template_name}_{ref_name}.mp4"
    res, target_fps = model.run(ref_img_path, template_path)
    # macro_block_size=1 keeps the exact frame size (no 16-px rounding).
    imageio.mimsave(outpath, res, fps=target_fps, quality=8, macro_block_size=1)
    print('save to %s' % outpath)
| if __name__ == "__main__": | |
| main() | |