Spaces:
Runtime error
Runtime error
File size: 4,460 Bytes
67ddbf8 7a05005 f5651ba 67ddbf8 f5651ba 67ddbf8 7a05005 67ddbf8 7a05005 8349be9 7a05005 67ddbf8 8349be9 67ddbf8 8349be9 67ddbf8 7a05005 67ddbf8 f5651ba 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 7a05005 67ddbf8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import MODELS_DIR
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature
from DeepCache import DeepCacheSDHelper
def _pick_weight_dtype():
    """Choose the floating-point dtype for the VAE, UNet and pipeline.

    float16 is selected only when CUDA is available AND the current device's
    compute-capability major version is greater than 7 (i.e. 8.x or newer);
    every other setup falls back to float32.
    """
    if not torch.cuda.is_available():
        return torch.float32
    capability_major = torch.cuda.get_device_capability()[0]
    return torch.float16 if capability_major > 7 else torch.float32


def _whisper_size_for(cross_attention_dim):
    """Return the Whisper model name matching the UNet cross-attention width.

    768 maps to "small" and 384 maps to "tiny"; the name is handed to
    ``Audio2Feature`` as its ``model_path``.

    Raises:
        NotImplementedError: if the width is neither 768 nor 384.
    """
    for width, whisper_size in ((768, "small"), (384, "tiny")):
        if cross_attention_dim == width:
            return whisper_size
    raise NotImplementedError("cross_attention_dim must be 768 or 384")


def main(config, args):
    """Run lip-sync inference for one video/audio pair and write the result.

    Args:
        config: OmegaConf config loaded from ``--unet_config_path``. This
            function reads ``config.model`` (UNet construction and
            ``cross_attention_dim``) and ``config.data`` (``num_frames``,
            ``audio_feat_length``, ``resolution``, ``mask_image_path``).
        args: parsed CLI namespace with the input/output paths, checkpoint
            path, sampling settings, seed and DeepCache flag.

    Raises:
        RuntimeError: if the input video or audio path does not exist.
        NotImplementedError: if ``config.model.cross_attention_dim`` is
            neither 768 nor 384.
    """
    # Validate the inputs before any (slow) model loading happens.
    if not os.path.exists(args.video_path):
        raise RuntimeError(f"Video path '{args.video_path}' not found")
    if not os.path.exists(args.audio_path):
        raise RuntimeError(f"Audio path '{args.audio_path}' not found")

    weight_dtype = _pick_weight_dtype()

    print(f"Input video path: {args.video_path}")
    print(f"Input audio path: {args.audio_path}")
    print(f"Loaded checkpoint path: {args.inference_ckpt_path}")

    # "configs" is a relative path; presumably resolved as a local directory
    # against the current working directory — confirm how the script is launched.
    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    whisper_size = _whisper_size_for(config.model.cross_attention_dim)
    data_cfg = config.data
    audio_encoder = Audio2Feature(
        model_path=whisper_size,
        device="cuda",
        num_frames=data_cfg.num_frames,
        audio_feat_length=data_cfg.audio_feat_length,
    )

    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=weight_dtype, cache_dir=MODELS_DIR
    )
    # Pin the latent scaling/shift explicitly instead of relying on the
    # downloaded config's values.
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # The UNet is built on CPU from the model config + checkpoint, then cast.
    unet_cpu, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        args.inference_ckpt_path,
        device="cpu",
    )
    denoiser = unet_cpu.to(dtype=weight_dtype)

    lipsync = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=denoiser,
        scheduler=noise_scheduler,
    )
    lipsync = lipsync.to("cuda")

    if args.enable_deepcache:
        # Kept in a local so the helper stays referenced for the rest of main().
        deepcache = DeepCacheSDHelper(pipe=lipsync)
        deepcache.set_params(cache_interval=3, cache_branch_id=0)
        deepcache.enable()

    # -1 requests a fresh non-deterministic seed; any other value is applied
    # through accelerate's set_seed.
    if args.seed == -1:
        torch.seed()
    else:
        set_seed(args.seed)
    print(f"Initial seed: {torch.initial_seed()}")

    resolution = data_cfg.resolution
    lipsync(
        video_path=args.video_path,
        audio_path=args.audio_path,
        video_out_path=args.video_out_path,
        num_frames=data_cfg.num_frames,
        num_inference_steps=args.inference_steps,
        guidance_scale=args.guidance_scale,
        weight_dtype=weight_dtype,
        width=resolution,
        height=resolution,
        mask_image_path=data_cfg.mask_image_path,
        temp_dir=args.temp_dir,
    )
def build_arg_parser():
    """Build the command-line parser whose namespace is consumed by ``main``.

    Option names, types, defaults and required-ness match the original CLI;
    the help strings document each option, including the ``--seed -1``
    sentinel, which ``main`` treats as "draw a fresh non-deterministic seed".

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Run LatentSync lip-sync inference on one video/audio pair."
    )
    parser.add_argument(
        "--unet_config_path",
        type=str,
        default="configs/unet.yaml",
        help="OmegaConf YAML whose 'model' and 'data' sections configure the run "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--inference_ckpt_path",
        type=str,
        required=True,
        help="UNet checkpoint passed to UNet3DConditionModel.from_pretrained",
    )
    parser.add_argument(
        "--video_path", type=str, required=True, help="input video file (must exist)"
    )
    parser.add_argument(
        "--audio_path", type=str, required=True, help="input audio file (must exist)"
    )
    parser.add_argument(
        "--video_out_path",
        type=str,
        required=True,
        help="path of the generated output video",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=20,
        help="number of inference steps (default: %(default)s)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="guidance scale passed to the pipeline (default: %(default)s)",
    )
    parser.add_argument(
        "--temp_dir",
        type=str,
        default="temp",
        help="scratch directory passed to the pipeline (default: %(default)s)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1247,
        help="random seed; pass -1 to draw a fresh non-deterministic seed "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--enable_deepcache",
        action="store_true",
        help="wrap the pipeline with DeepCacheSDHelper (cache_interval=3)",
    )
    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the config file, run inference.
    args = build_arg_parser().parse_args()
    config = OmegaConf.load(args.unet_config_path)
    main(config, args)
|