|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
import torch |
|
|
from diffusers import AutoencoderKL, DDIMScheduler |
|
|
from latentsync.models.unet import UNet3DConditionModel |
|
|
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline |
|
|
|
|
|
from accelerate.utils import set_seed |
|
|
from latentsync.whisper.audio2feature import Audio2Feature |
|
|
|
|
|
def main(
    video_path: str,
    audio_path: str,
    video_out_path: str = "./outputs/outvideo.mp4",
    unet_ckpt_path: str = "./checkpoints/latentsync/latentsync_unet.pt",
    vae_path: str = "./checkpoints/sd-vae-ft-mse",
    unet_config_path: str = "configs/unet/second_stage.yaml",
    guidance_scale: float = 1.0,
    seed: int = 1247,
):
    """Build a ``LipsyncPipeline`` and run it once on a video/audio pair.

    Args:
        video_path: Path of the input video handed to the pipeline.
        audio_path: Path of the input audio handed to the pipeline.
        video_out_path: Destination path for the pipeline's output video. A
            companion mask path is derived from it by inserting ``_mask``
            before the file extension.
        unet_ckpt_path: Checkpoint passed to
            ``UNet3DConditionModel.from_pretrained``.
        vae_path: Location passed to ``AutoencoderKL.from_pretrained``.
        unet_config_path: OmegaConf file; this function reads
            ``model.cross_attention_dim``, ``data.num_frames``,
            ``data.resolution`` and ``run.inference_steps`` from it.
        guidance_scale: Forwarded unchanged to the pipeline call.
        seed: Seed passed to ``set_seed``. ``-1`` requests a fresh seed via
            ``torch.seed()`` instead.

    Raises:
        NotImplementedError: If ``model.cross_attention_dim`` is neither 768
            nor 384.
    """
    print(f"Input video path: {video_path}")
    print(f"Input audio path: {audio_path}")
    print(f"Loaded unet checkpoint path: {unet_ckpt_path}")

    config = OmegaConf.load(unet_config_path)
    scheduler = DDIMScheduler.from_pretrained("configs")

    # The whisper checkpoint is selected by the UNet's cross-attention width.
    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device="cuda",
        num_frames=config.data.num_frames,
    )

    vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
    # Set these explicitly rather than relying on the loaded VAE config.
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # Load on CPU first; the whole pipeline is moved to CUDA below.
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        unet_ckpt_path,
        device="cpu",
    )
    unet = unet.to(dtype=torch.float16)

    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to("cuda")

    if seed != -1:
        set_seed(seed)
    else:
        torch.seed()

    print(f"Initial seed: {torch.initial_seed()}")

    # Derive the mask path from the extension only. A plain
    # str.replace(".mp4", ...) would also rewrite ".mp4" inside directory
    # names and would return the output path unchanged (so the mask would
    # share the output's path) for any other extension.
    out_root, out_ext = os.path.splitext(video_out_path)
    video_mask_path = f"{out_root}_mask{out_ext}"

    pipeline(
        video_path=video_path,
        audio_path=audio_path,
        video_out_path=video_out_path,
        video_mask_path=video_mask_path,
        num_frames=config.data.num_frames,
        num_inference_steps=config.run.inference_steps,
        guidance_scale=guidance_scale,
        weight_dtype=torch.float16,
        width=config.data.resolution,
        height=config.data.resolution,
    )
|
|
|
|
|
|
|
|
import os |
|
|
def get_videos_from_path(path):
    """List the ``.mp4`` entries of *path* as names with the extension removed.

    The extension check is case-insensitive and entries keep the order
    reported by ``os.listdir``. If *path* does not exist, a notice is printed
    and an empty list is returned.
    """
    try:
        entries = os.listdir(path)
        return [
            os.path.splitext(entry)[0]
            for entry in entries
            if entry.lower().endswith('.mp4')
        ]
    except FileNotFoundError:
        print(f"Directory {path} not found")
        return []
|
|
|
|
|
def get_audios_from_path(path):
    """List the ``.wav`` entries of *path* as names with the extension removed.

    The extension check is case-insensitive and entries keep the order
    reported by ``os.listdir``. If *path* does not exist, a notice is printed
    and an empty list is returned.
    """
    try:
        entries = os.listdir(path)
        return [
            os.path.splitext(entry)[0]
            for entry in entries
            if entry.lower().endswith('.wav')
        ]
    except FileNotFoundError:
        print(f"Directory {path} not found")
        return []
|
|
if __name__ == "__main__":
    # Batch driver: run main() for every (audio, video) pairing found in the
    # asset directory, writing one output file per pairing.
    file_path = "./assets/edge_cases"
    videos = get_videos_from_path(file_path)
    audios = get_audios_from_path(file_path)

    for audio in audios:
        for video in videos:
            print(video, audio)
            output_path = "./outputs/" + video + "_" + audio + ".mp4"
            try:
                main(
                    f"{file_path}/{video}.mp4",
                    f"{file_path}/{audio}.wav",
                    output_path,
                )
            except Exception as exc:
                # Best-effort batch: report the real failure and continue with
                # the next pairing. Catching Exception (not a bare except)
                # lets Ctrl-C / SystemExit still stop the run.
                print(
                    f"Failed on video={video!r}, audio={audio!r}: "
                    f"{type(exc).__name__}: {exc}"
                )
|
|
|