from omegaconf import OmegaConf
import os
import torch
import numpy as np
from PIL import Image
import time
import gc
import cv2
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.pose_guider import PoseGuider
from src.models.motion_encoder.encoder import MotEncoder
from src.models.unet_3d import UNet3DConditionModel
from src.models.mutual_self_attention import ReferenceAttentionControl
from src.scheduler.scheduler_ddim import DDIMScheduler
from src.liveportrait.motion_extractor import MotionExtractor
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from collections import deque
from threading import Lock, Thread
from torchvision import transforms as T
from einops import rearrange
from src.utils.util import draw_keypoints, get_boxes
import torch.nn.functional as F


def map_device(device_or_str):
    return device_or_str if isinstance(device_or_str, torch.device) else torch.device(device_or_str)


class PersonaLive:
    def __init__(self, args, device=None):
        cfg = OmegaConf.load(args.config_path)

        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = map_device(device)

        self.temporal_adaptive_step = cfg.temporal_adaptive_step
        self.temporal_window_size = cfg.temporal_window_size

        if cfg.dtype == "fp16":
            self.numpy_dtype = np.float16
            self.dtype = torch.float16
        elif cfg.dtype == "fp32":
            self.numpy_dtype = np.float32
            self.dtype = torch.float32

        infer_config = OmegaConf.load(cfg.inference_config)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        self.num_inference_steps = cfg.num_inference_steps

        # initialize models
        self.pose_guider = PoseGuider().to(device=self.device, dtype=self.dtype)
        pose_guider_state_dict = torch.load(cfg.pose_guider_path, map_location="cpu")
        self.pose_guider.load_state_dict(pose_guider_state_dict)
        del pose_guider_state_dict

        self.motion_encoder = MotEncoder().to(dtype=self.dtype, device=self.device).eval()
        motion_encoder_state_dict = torch.load(cfg.motion_encoder_path, map_location="cpu")
        self.motion_encoder.load_state_dict(motion_encoder_state_dict)
        del motion_encoder_state_dict

        self.pose_encoder = MotionExtractor(num_kp=21).to(device=self.device, dtype=self.dtype).eval()
        pose_encoder_state_dict = torch.load(cfg.pose_encoder_path, map_location="cpu")
        self.pose_encoder.load_state_dict(pose_encoder_state_dict, strict=False)
        del pose_encoder_state_dict

        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            cfg.pretrained_base_model_path,
            "",
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=self.dtype, device=self.device)

        self.reference_unet = UNet2DConditionModel.from_pretrained(
            cfg.pretrained_base_model_path,
            subfolder="unet",
        ).to(dtype=self.dtype, device=self.device)
        reference_unet_state_dict = torch.load(cfg.reference_unet_weight_path, map_location="cpu")
        self.reference_unet.load_state_dict(reference_unet_state_dict)
        del reference_unet_state_dict

        self.denoising_unet.load_state_dict(
            torch.load(cfg.denoising_unet_path, map_location="cpu"),
            strict=False,
        )
        self.denoising_unet.load_state_dict(
            torch.load(cfg.temporal_module_path, map_location="cpu"),
            strict=False,
        )

        self.reference_control_writer = ReferenceAttentionControl(
            self.reference_unet,
            do_classifier_free_guidance=False,
            mode="write",
            batch_size=cfg.batch_size,
            fusion_blocks="full",
        )
        self.reference_control_reader = ReferenceAttentionControl(
            self.denoising_unet,
            do_classifier_free_guidance=False,
            mode="read",
            batch_size=cfg.batch_size,
            fusion_blocks="full",
            cache_kv=True,
        )

        self.vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
            device=self.device, dtype=self.dtype
        )
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            cfg.image_encoder_path,
        ).to(device=self.device, dtype=self.dtype)

        # miscellaneous
        self.scheduler = DDIMScheduler(**sched_kwargs)
        self.timesteps = torch.tensor([999, 666, 333, 0], device=self.device).long()
        self.scheduler.set_step_length(333)
        self.generator = torch.Generator(self.device)
        self.generator.manual_seed(cfg.seed)
        self.batch_size = cfg.batch_size

        self.vae_scale_factor = 8
        self.ref_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.clip_image_processor = CLIPImageProcessor()
        self.cond_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True
        )

        self.cfg = cfg
        self.reset()
        torch.cuda.empty_cache()

        try:
            self.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print("Failed to enable xformers:", e)

    def reset(self):
        self.first_frame = True
        self.motion_bank = None
        self.count = 0
        self.num_khf = 0
        self.latents_pile = deque([])
        self.pose_pile = deque([])
        self.motion_pile = deque([])
        self.reference_control_writer.clear()
        self.reference_control_reader.clear()

    def enable_xformers_memory_efficient_attention(self):
        self.reference_unet.enable_xformers_memory_efficient_attention()
        self.denoising_unet.enable_xformers_memory_efficient_attention()

    def fast_resize(self, images, target_width, target_height) -> torch.Tensor:
        # F.interpolate expects size as (height, width)
        tgt_cond_tensor = F.interpolate(
            images,
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=False,
        )
        return tgt_cond_tensor

    @torch.no_grad()
    def fuse_reference(self, ref_image):
        # PIL image input
        clip_image = self.clip_image_processor.preprocess(
            ref_image, return_tensors="pt"
        ).pixel_values
        ref_image_tensor = self.ref_image_processor.preprocess(
            ref_image,
            height=self.cfg.reference_image_height,
            width=self.cfg.reference_image_width,
        )  # (bs, c, h, w)
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.image_encoder.device, dtype=self.image_encoder.dtype)
        ).image_embeds
        self.encoder_hidden_states = clip_image_embeds.unsqueeze(1)

        ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
        self.ref_image_tensor = ref_image_tensor.squeeze(0)
        ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
        ref_image_latents = ref_image_latents * 0.18215  # (b, 4, h, w)

        # Run the reference UNet once so its attention states can be read by the denoising UNet.
        self.reference_unet(
            ref_image_latents.to(self.reference_unet.device),
            torch.zeros((self.batch_size,), dtype=self.dtype, device=self.reference_unet.device),
            encoder_hidden_states=self.encoder_hidden_states,
            return_dict=False,
        )
        self.reference_control_reader.update(self.reference_control_writer)
        self.encoder_hidden_states = self.encoder_hidden_states.to(self.device)

        ref_cond_tensor = self.cond_image_processor.preprocess(
            ref_image, height=256, width=256
        ).to(device=self.device, dtype=self.pose_encoder.dtype)  # (1, c, h, w)
        self.ref_cond_tensor = ref_cond_tensor / 2 + 0.5  # to [0, 1]
        self.ref_image_latents = ref_image_latents

        # Pre-fill the latent pile with noised copies of the reference latents so the
        # temporal window is already full when the first driving frames arrive.
        padding_num = (self.temporal_adaptive_step - 1) * self.temporal_window_size
        init_latents = ref_image_latents.unsqueeze(2).repeat(1, 1, padding_num, 1, 1)
        noise = torch.randn_like(init_latents)
        init_timesteps = reversed(self.timesteps).repeat_interleave(self.temporal_window_size, dim=0)
        noisy_latents_first = self.scheduler.add_noise(init_latents, noise, init_timesteps[:padding_num])
        for i in range(self.temporal_adaptive_step - 1):
            l = i * self.temporal_window_size
            r = (i + 1) * self.temporal_window_size
            self.latents_pile.append(noisy_latents_first[:, :, l:r])

    def crop_face(self, image_pil, boxes):
        image = np.array(image_pil)
        left, top, right, bot = boxes
        face_patch = image[int(top):int(bot), int(left):int(right)]
        face_patch = Image.fromarray(face_patch).convert("RGB")
        return face_patch

    def crop_face_tensor(self, image_tensor, boxes):
        left, top, right, bot = boxes
        left, top, right, bottom = map(int, (left, top, right, bot))
        face_patch = image_tensor[:, top:bottom, left:right]
        face_patch = F.interpolate(
            face_patch.unsqueeze(0),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        return face_patch

    def interpolate_tensors(self, a: torch.Tensor, b: torch.Tensor, num: int = 10) -> torch.Tensor:
        """
        Linearly interpolate between tensors a and b.
        Input shape:  (B, 1, D1, D2, ...)
        Output shape: (B, num, D1, D2, ...)
        """
        if a.shape != b.shape:
            raise ValueError(f"Shape mismatch: a.shape={a.shape}, b.shape={b.shape}")

        B, _, *rest = a.shape

        # interpolation coefficients (num,) -> reshaped to (1, num, 1, 1, ...)
        alphas = torch.linspace(0, 1, num, device=a.device, dtype=a.dtype)
        view_shape = (1, num) + (1,) * len(rest)
        alphas = alphas.view(view_shape)  # (1, num, 1, 1, ...)

        # interpolated result: (B, num, D1, D2, ...)
        result = (1 - alphas) * a + alphas * b
        return result

    def calculate_dis(self, A, B, threshold=10.0):
        """
        A: (b, f1, c1, c2) bank
        B: (b, f2, c1, c2) new data
        """
        A_flat = A.view(A.size(1), -1).clone()
        B_flat = B.view(B.size(1), -1).clone()
        dist = torch.cdist(B_flat.to(torch.float32), A_flat.to(torch.float32), p=2)
        min_dist, min_idx = dist.min(dim=1)  # (f2,)

        idx_to_add = torch.nonzero(min_dist[:1] > threshold, as_tuple=False).squeeze(1).tolist()
        if len(idx_to_add) > 0:
            # some new entries are far enough from the bank to be added
            B_to_add = B[:, idx_to_add]  # (1, k, c1, c2)
            A_new = torch.cat([A, B_to_add], dim=1)  # (1, f1+k, c1, c2)
        else:
            A_new = A  # nothing to add
        return idx_to_add, A_new, min_idx

    @torch.no_grad()
    def process_input(self, images):
        batch_size = self.batch_size
        device = self.device
        temporal_window_size = self.temporal_window_size
        temporal_adaptive_step = self.temporal_adaptive_step

        tgt_cond_tensor = self.fast_resize(images, 256, 256)
        tgt_cond_tensor = tgt_cond_tensor / 2 + 0.5

        # Keypoint extraction: on the first call, interpolate from the reference pose to the
        # first driving pose; afterwards, track keypoints relative to the stored reference.
        if self.first_frame:
            mot_bbox_param, kps_ref, kps_frame1, kps_dri = self.pose_encoder.interpolate_kps_online(
                self.ref_cond_tensor, tgt_cond_tensor, num_interp=12 + 1
            )
            self.kps_ref = kps_ref
            self.kps_frame1 = kps_frame1
        else:
            mot_bbox_param, kps_dri = self.pose_encoder.get_kps(
                self.kps_ref, self.kps_frame1, tgt_cond_tensor
            )

        keypoints = draw_keypoints(mot_bbox_param, device=device)
        boxes = get_boxes(kps_dri)
        keypoints = rearrange(keypoints.unsqueeze(2), 'f c b h w -> b c f h w')
        keypoints = keypoints.to(device=device, dtype=self.pose_guider.dtype)

        # Motion embeddings from the cropped face regions.
        if self.first_frame:
            ref_box = get_boxes(mot_bbox_param[:1])
            ref_face = self.crop_face_tensor(self.ref_image_tensor, ref_box[0])
            motion_face = [ref_face]
            for i, frame in enumerate(images):
                motion_face.append(self.crop_face_tensor(frame, boxes[i]))
            pose_cond_tensor = torch.cat(motion_face, dim=0).transpose(0, 1)
            pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
            # pose_cond_tensor = pose_cond_tensor.to(
            #     device=device, dtype=self.motion_encoder.dtype
            # )
            motion_hidden_states = self.motion_encoder(pose_cond_tensor)
            ref_motion = motion_hidden_states[:, :1]
            dri_motion = motion_hidden_states[:, 1:]
            init_motion_hidden_states = self.interpolate_tensors(
                ref_motion, dri_motion[:, :1], num=12 + 1
            )[:, :-1]
            for i in range(temporal_adaptive_step - 1):
                l = i * temporal_window_size
                r = (i + 1) * temporal_window_size
                self.motion_pile.append(init_motion_hidden_states[:, l:r])
            self.motion_pile.append(dri_motion)
            self.motion_bank = ref_motion
        else:
            motion_face = []
            for i, frame in enumerate(images):
                motion_face.append(self.crop_face_tensor(frame, boxes[i]))
            pose_cond_tensor = torch.cat(motion_face, dim=0).transpose(0, 1)
            pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
            motion_hidden_states = self.motion_encoder(pose_cond_tensor)
            self.motion_pile.append(motion_hidden_states)

        pose_fea = self.pose_guider(keypoints)
        if self.first_frame:
            for i in range(temporal_adaptive_step):
                l = i * temporal_window_size
                r = (i + 1) * temporal_window_size
                self.pose_pile.append(pose_fea[:, :, l:r])
            self.first_frame = False
        else:
            self.pose_pile.append(pose_fea)

        # Enqueue a freshly noised window of reference latents for the newest frames.
        latents = self.ref_image_latents.unsqueeze(2).repeat(1, 1, temporal_window_size, 1, 1)
        noise = torch.randn_like(latents)
        latents = self.scheduler.add_noise(latents, noise, self.timesteps[:1])
        self.latents_pile.append(latents)

        jump = 1
        motion_hidden_state = torch.cat(list(self.motion_pile), dim=1)
        pose_cond_fea = torch.cat(list(self.pose_pile), dim=2)

        # Decide whether the incoming motion embedding is novel enough to trigger a keyframe update.
        idx_to_add = []
        if self.count > 8:
            idx_to_add, self.motion_bank, idx_his = self.calculate_dis(
                self.motion_bank, motion_hidden_state, threshold=17.0
            )

        latents_model_input = torch.cat(list(self.latents_pile), dim=2)
        for j in range(jump):
            timesteps = reversed(self.timesteps[j::jump]).repeat_interleave(temporal_window_size, dim=0)
            timesteps = torch.stack([timesteps] * batch_size)  # .to(device)
            timesteps = rearrange(timesteps, 'b f -> (b f)')

            noise_pred = self.denoising_unet(
                latents_model_input,
                timesteps,
                encoder_hidden_states=[self.encoder_hidden_states, motion_hidden_state],
                pose_cond_fea=pose_cond_fea,
                return_dict=False,
            )[0]

            clip_length = noise_pred.shape[2]
            mid_noise_pred = rearrange(noise_pred, 'b c f h w -> (b f) c h w')
            mid_latents = rearrange(latents_model_input, 'b c f h w -> (b f) c h w')
            latents_model_input, pred_original_sample = self.scheduler.step(
                mid_noise_pred, timesteps, mid_latents, generator=self.generator, return_dict=False
            )
            latents_model_input = rearrange(latents_model_input, '(b f) c h w -> b c f h w', f=clip_length)
            pred_original_sample = rearrange(pred_original_sample, '(b f) c h w -> b c f h w', f=clip_length)
            # The oldest window takes the predicted clean sample; the rest stay partially denoised.
            latents_model_input = torch.cat(
                [
                    pred_original_sample[:, :, :temporal_window_size],
                    latents_model_input[:, :, temporal_window_size:],
                ],
                dim=2,
            )
            latents_model_input = latents_model_input.to(dtype=self.dtype)

            if len(idx_to_add) > 0 and self.num_khf < 3:
                # Register the latest denoised frame as an additional reference keyframe.
                self.reference_control_writer.clear()
                self.reference_unet(
                    pred_original_sample[:, :, 0].to(self.reference_unet.dtype),
                    torch.zeros((batch_size,), dtype=self.dtype, device=self.reference_unet.device),
                    encoder_hidden_states=self.encoder_hidden_states,
                    return_dict=False,
                )
                self.reference_control_reader.update_hkf(self.reference_control_writer)
                print('add_keyframes')
                self.num_khf += 1

        for i in range(len(self.latents_pile)):
            self.latents_pile[i] = latents_model_input[
                :, :, i * temporal_adaptive_step : (i + 1) * temporal_adaptive_step, :, :
            ]

        # Drop the oldest window from each pile and decode it to RGB frames.
        self.pose_pile.popleft()
        self.motion_pile.popleft()
        latents = self.latents_pile.popleft()

        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "b c h w -> b h w c")
        video = (video / 2 + 0.5).clamp(0, 1)
        video = video.cpu().numpy()

        self.count += 1
        return video
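

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows one plausible
# way to drive PersonaLive end to end, assuming `process_input` receives
# `temporal_window_size` driving frames per call as an (f, 3, H, W) tensor
# normalized to [-1, 1]. The config path, reference image path, and the random
# frame source below are placeholders, not the project's actual CLI.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(config_path="configs/infer.yaml")  # hypothetical config location
    pipeline = PersonaLive(args)

    # One-time reference fusion from a portrait image (PIL input).
    ref_image = Image.open("assets/reference.png").convert("RGB")  # placeholder path
    pipeline.fuse_reference(ref_image)

    # Streaming loop: each call consumes one temporal window of driving frames
    # and returns the decoded frames for the oldest window in the pile.
    window = pipeline.temporal_window_size
    for step in range(4):  # stand-in for a real frame source (webcam / video reader)
        frames = torch.rand(window, 3, 512, 512, device=pipeline.device, dtype=pipeline.dtype)
        frames = frames * 2 - 1  # scale to [-1, 1], matching the cond preprocessing above
        video = pipeline.process_input(frames)  # (f, H, W, 3) float array in [0, 1]
        print("decoded window", step, video.shape)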