# NOTE: a "Spaces / Runtime error" web-page banner was captured here by the
# scraper; it is not part of the module and has been reduced to this comment.
import os
import subprocess

import cv2
import numpy as np
import torch
import torch.utils.checkpoint
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.dataset.face_align.align import AlignImage
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.utils.util import save_videos_grid, seed_everything
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| def test( | |
| pipe, | |
| config, | |
| wav_enc, | |
| audio_pe, | |
| audio2bucket, | |
| image_encoder, | |
| width, | |
| height, | |
| batch, | |
| ): | |
| """Generate a video tensor for the given batch.""" | |
| for k, v in batch.items(): | |
| if isinstance(v, torch.Tensor): | |
| batch[k] = v.unsqueeze(0).to(pipe.device).float() | |
| ref_img = batch['ref_img'] | |
| clip_img = batch['clip_images'] | |
| face_mask = batch['face_mask'] | |
| image_embeds = image_encoder(clip_img).image_embeds | |
| audio_feature = batch['audio_feature'] | |
| audio_len = batch['audio_len'] | |
| step = int(config.step) | |
| window = 3000 | |
| audio_prompts = [] | |
| last_audio_prompts = [] | |
| for i in range(0, audio_feature.shape[-1], window): | |
| audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window], output_hidden_states=True).hidden_states | |
| last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window]).last_hidden_state | |
| last_audio_prompt = last_audio_prompt.unsqueeze(-2) | |
| audio_prompt = torch.stack(audio_prompt, dim=2) | |
| audio_prompts.append(audio_prompt) | |
| last_audio_prompts.append(last_audio_prompt) | |
| audio_prompts = torch.cat(audio_prompts, dim=1) | |
| audio_prompts = audio_prompts[:, :audio_len*2] | |
| audio_prompts = torch.cat([ | |
| torch.zeros_like(audio_prompts[:, :4]), | |
| audio_prompts, | |
| torch.zeros_like(audio_prompts[:, :6]) | |
| ], 1) | |
| last_audio_prompts = torch.cat(last_audio_prompts, dim=1) | |
| last_audio_prompts = last_audio_prompts[:, :audio_len*2] | |
| last_audio_prompts = torch.cat([ | |
| torch.zeros_like(last_audio_prompts[:, :24]), | |
| last_audio_prompts, | |
| torch.zeros_like(last_audio_prompts[:, :26]) | |
| ], 1) | |
| ref_tensor_list = [] | |
| audio_tensor_list = [] | |
| uncond_audio_tensor_list = [] | |
| motion_buckets = [] | |
| for i in tqdm(range(audio_len//step), ncols=0): | |
| audio_clip = audio_prompts[:, i*2*step:i*2*step+10].unsqueeze(0) | |
| audio_clip_for_bucket = last_audio_prompts[:, i*2*step:i*2*step+50].unsqueeze(0) | |
| motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds) | |
| motion_bucket = motion_bucket * 16 + 16 | |
| motion_buckets.append(motion_bucket[0]) | |
| cond_audio_clip = audio_pe(audio_clip).squeeze(0) | |
| uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0) | |
| ref_tensor_list.append(ref_img[0]) | |
| audio_tensor_list.append(cond_audio_clip[0]) | |
| uncond_audio_tensor_list.append(uncond_audio_clip[0]) | |
| video = pipe( | |
| ref_img, | |
| clip_img, | |
| face_mask, | |
| audio_tensor_list, | |
| uncond_audio_tensor_list, | |
| motion_buckets, | |
| height=height, | |
| width=width, | |
| num_frames=len(audio_tensor_list), | |
| decode_chunk_size=config.decode_chunk_size, | |
| motion_bucket_scale=config.motion_bucket_scale, | |
| fps=config.fps, | |
| noise_aug_strength=config.noise_aug_strength, | |
| min_guidance_scale1=config.min_appearance_guidance_scale, | |
| max_guidance_scale1=config.max_appearance_guidance_scale, | |
| min_guidance_scale2=config.audio_guidance_scale, | |
| max_guidance_scale2=config.audio_guidance_scale, | |
| overlap=config.overlap, | |
| shift_offset=config.shift_offset, | |
| frames_per_batch=config.n_sample_frames, | |
| num_inference_steps=config.num_inference_steps, | |
| i2i_noise_strength=config.i2i_noise_strength | |
| ).frames | |
| video = (video * 0.5 + 0.5).clamp(0, 1) | |
| video = torch.cat([video.to(pipe.device)], dim=0).cpu() | |
| return video | |
class Sonic:
    """High-level interface for the Sonic portrait animation pipeline.

    Typical use: construct once (loads all weights), then call
    :meth:`preprocess` / :meth:`crop_image` to prepare the portrait and
    :meth:`process` to render the talking-head video.
    """

    # NOTE(review): the config is loaded once at class-definition time and is
    # shared — and mutated (seed, num_inference_steps, ...) — by every
    # instance. Two Sonic objects therefore share settings.
    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        """Load every model component onto the target device.

        Args:
            device_id: CUDA device index; any negative value selects CPU.
            enable_interpolate_frame: when True, a RIFE model is loaded and
                used to double the output frame rate after generation.
        """
        config = self.config
        config.use_interframe = enable_interpolate_frame

        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
        self.device = device
        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

        # Core SVD components (fp16 variants of the pretrained checkpoints).
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='vae', variant='fp16')
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='scheduler')
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='image_encoder', variant='fp16')
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='unet', variant='fp16')
        add_ip_adapters(unet, [32], [config.ip_audio_scale])

        # Audio adapters: token projector for cross-attention conditioning and
        # the motion-bucket predictor.
        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024,
                                     output_dim=1024, context_tokens=32).to(device)
        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024,
                                         intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)

        # Fine-tuned weights are loaded strictly so shape/key drift fails fast.
        unet.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location='cpu'), strict=True)
        audio2token.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location='cpu'), strict=True)
        audio2bucket.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location='cpu'), strict=True)

        dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}
        weight_dtype = dtype_map.get(config.weight_dtype)
        if weight_dtype is None:
            raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")

        # Whisper supplies the audio features; inference-only, gradients off.
        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
        whisper.requires_grad_(False)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))
        self.face_det = AlignImage(device, det_path=os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
        if config.use_interframe:
            # RIFE frame interpolator (used by ``process`` to double the fps).
            self.rife = RIFEModel(device=device)
            self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))

        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)

        pipe = SonicPipeline(
            unet=unet, image_encoder=image_encoder, vae=vae, scheduler=val_noise_scheduler)
        self.pipe = pipe.to(device=device, dtype=weight_dtype)
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder

        print('Sonic initialization complete.')

    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
        """Detect the primary face in ``image_path`` and compute a crop bbox.

        Args:
            image_path: path to the portrait image.
            expand_ratio: how much to expand the detected box (forwarded to
                ``process_bbox``; note the upstream kwarg is spelled
                ``expand_radio``).

        Returns:
            dict with ``face_num`` (detection count) and ``crop_bbox``
            (corner-form bbox when a face is found, else an empty list).
        """
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
        bbox_s = []
        if face_num > 0:
            # Detector output is (x, y, width, height); convert to corners.
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
            bbox_s = process_bbox((x1, y1, x2, y2), expand_radio=expand_ratio, height=h, width=w)
        return {'face_num': face_num, 'crop_bbox': bbox_s}

    def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox):
        """Crop the image to ``crop_bbox`` = (x1, y1, x2, y2) and save it."""
        face_image = cv2.imread(input_image_path)
        crop_img = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
        cv2.imwrite(output_image_path, crop_img)

    def process(self, image_path, audio_path, output_path, min_resolution=512,
                inference_steps=25, dynamic_scale=1.0, keep_resolution=False, seed=None):
        """Animate ``image_path`` with ``audio_path``; write ``output_path``.

        Args:
            image_path: portrait image (should already be face-cropped).
            audio_path: driving audio file.
            output_path: destination ``.mp4`` path.
            min_resolution: target size used during preprocessing.
            inference_steps: diffusion denoising steps.
            dynamic_scale: motion-bucket scale (higher = more motion).
            keep_resolution: if True, upscale the final mux back to the raw
                input image resolution (rounded to even for libx264).
            seed: optional RNG seed override.

        Returns:
            0 on success, -1 when preprocessing yields no usable sample.

        Raises:
            subprocess.CalledProcessError: if the final ffmpeg mux fails.
        """
        config = self.config
        device = self.device
        pipe = self.pipe
        whisper = self.whisper
        audio2token = self.audio2token
        audio2bucket = self.audio2bucket
        image_encoder = self.image_encoder

        if seed is not None:
            config.seed = seed
        seed_everything(config.seed)
        config.num_inference_steps = inference_steps
        config.frame_num = config.fps * 60  # hard cap: at most 60 s of output
        config.motion_bucket_scale = dynamic_scale

        # The silent video is rendered first, then muxed with the audio track.
        video_path = output_path.replace('.mp4', '_noaudio.mp4')
        audio_video_path = output_path

        imSrc_ = Image.open(image_path).convert('RGB')
        raw_w, raw_h = imSrc_.size

        test_data = image_audio_to_tensor(
            self.face_det, self.feature_extractor, image_path, audio_path,
            limit=config.frame_num, image_size=min_resolution, area=config.area)
        if test_data is None:
            return -1
        height, width = test_data['ref_img'].shape[-2:]
        # libx264 requires even dimensions, hence the ``//2*2`` rounding.
        resolution = f"{width}x{height}" if not keep_resolution else f"{raw_w//2*2}x{raw_h//2*2}"

        video = test(pipe, config, wav_enc=whisper, audio_pe=audio2token,
                     audio2bucket=audio2bucket, image_encoder=image_encoder,
                     width=width, height=height, batch=test_data)

        if config.use_interframe:
            # Insert one RIFE-interpolated frame between each consecutive
            # pair, doubling the effective frame rate.
            out = video.to(device)
            results = []
            for idx in tqdm(range(out.shape[2]-1), ncols=0):
                I1 = out[:, :, idx]
                I2 = out[:, :, idx+1]
                mid = self.rife.inference(I1, I2).clamp(0,1).detach()
                results.extend([out[:, :, idx], mid])
            results.append(out[:, :, -1])
            video = torch.stack(results, 2).cpu()

        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * (2 if config.use_interframe else 1))

        # Mux audio + video. An argument list with shell=False prevents shell
        # injection through user-supplied paths (the previous os.system call
        # interpolated paths into a shell string), and check=True surfaces
        # ffmpeg failures instead of silently ignoring them.
        subprocess.run(
            ['ffmpeg', '-i', video_path, '-i', audio_path, '-s', resolution,
             '-vcodec', 'libx264', '-acodec', 'aac', '-crf', '18', '-shortest',
             audio_video_path, '-y'],
            check=True)
        os.remove(video_path)  # replaces the trailing ``rm`` of the old shell string
        return 0