Initial upload of directory
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +21 -0
- .modal_lipsync_serve.py.swo +0 -0
- .modal_lipsync_serve.py.swp +0 -0
- README.md +21 -10
- assets/edge_cases/3d_model.mp4 +3 -0
- assets/edge_cases/anime.mp4 +3 -0
- assets/edge_cases/cartoon.mp4 +3 -0
- assets/edge_cases/full_body.mp4 +3 -0
- assets/edge_cases/lecun_belinda.wav +3 -0
- assets/edge_cases/lecun_freeman.wav +3 -0
- assets/edge_cases/obama_side.mp4 +3 -0
- assets/edge_cases/sculpture.mp4 +3 -0
- checkpoints/whisper/small.pt +3 -0
- dummy_inference.py +82 -0
- inference.py +131 -0
- latentsync/__init__.py +2 -0
- latentsync/models/__init__.py +20 -0
- latentsync/models/attention.py +492 -0
- latentsync/models/motion_module.py +332 -0
- latentsync/models/resnet.py +234 -0
- latentsync/models/syncnet.py +233 -0
- latentsync/models/syncnet_wav2lip.py +90 -0
- latentsync/models/unet.py +528 -0
- latentsync/models/unet_blocks.py +903 -0
- latentsync/models/utils.py +19 -0
- latentsync/pipelines/__init__.py +6 -0
- latentsync/pipelines/lipsync_pipeline.py +470 -0
- latentsync/utils/__init__.py +16 -0
- latentsync/utils/affine_transform.py +138 -0
- latentsync/utils/audio.py +217 -0
- latentsync/utils/av_reader.py +157 -0
- latentsync/utils/image_processor.py +349 -0
- latentsync/utils/mask.png +0 -0
- latentsync/utils/util.py +365 -0
- latentsync/whisper/__init__.py +6 -0
- latentsync/whisper/audio2feature.py +166 -0
- latentsync/whisper/whisper/__init__.py +119 -0
- latentsync/whisper/whisper/__main__.py +4 -0
- latentsync/whisper/whisper/assets/gpt2/merges.txt +0 -0
- latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json +1 -0
- latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json +1 -0
- latentsync/whisper/whisper/assets/gpt2/vocab.json +0 -0
- latentsync/whisper/whisper/assets/mel_filters.npz +3 -0
- latentsync/whisper/whisper/assets/multilingual/added_tokens.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/merges.txt +0 -0
- latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/vocab.json +0 -0
- latentsync/whisper/whisper/audio.py +132 -0
- latentsync/whisper/whisper/decoding.py +729 -0
.gitattributes
CHANGED
@@ -41,3 +41,24 @@ assets/demo3_audio.wav filter=lfs diff=lfs merge=lfs -text
 assets/demo3_video.mp4 filter=lfs diff=lfs merge=lfs -text
 assets/framework.png filter=lfs diff=lfs merge=lfs -text
 temp/video.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/3d_model.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/anime.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/cartoon.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/full_body.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/lecun_belinda.wav filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/lecun_freeman.wav filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/obama_side.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/edge_cases/sculpture.mp4 filter=lfs diff=lfs merge=lfs -text
+local_video.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/cartoon_lecun_belinda.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/cartoon_lecun_freeman.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/obama_side_lecun_belinda.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/obama_side_lecun_freeman.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/outvideo.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/sculpture_lecun_belinda.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/sculpture_lecun_freeman.mp4 filter=lfs diff=lfs merge=lfs -text
+outvideo.mp4 filter=lfs diff=lfs merge=lfs -text
+remote_video.mp4 filter=lfs diff=lfs merge=lfs -text
+temp/audio.wav filter=lfs diff=lfs merge=lfs -text
+temp_audio.wav filter=lfs diff=lfs merge=lfs -text
+temp_video.mp4 filter=lfs diff=lfs merge=lfs -text
.modal_lipsync_serve.py.swo
ADDED
Binary file (16.4 kB).

.modal_lipsync_serve.py.swp
ADDED
Binary file (20.5 kB).
README.md
CHANGED
@@ -1,15 +1,23 @@
-#
-
-This is a small repo containing all the required files to run inference for LatentSync1.5
-
-
+# OpenLipSync
+
+This is a small repo containing all the required files to run inference for LatentSync 1.5
+
+## Installation
+- clone the repo
+- on Debian-based systems, run bash debian_setup.sh (needed for both local and Modal remote inference)
+- for remote inference with Modal, you must first create the volume by running
+```uv run modal run scripts/modal_download_extras.py``` and
+```uv run modal run scripts/modal_download_models.py```
+
+## Run
+For local inference, modify the inference.py file at the root of the directory
+(add the paths of your video and audio files). Then run with ```uv run inference.py```
+
+For remote inference, modify the modal_lipsync_inference.py file at the root of the directory
+(add the paths of your video and audio files). Then run with ```uv run modal run modal_lipsync_inference.py```
+
+For remote inference with FastAPI endpoints, run ```uv run modal run modal_lipsync_serve.py```
+
+TODO:
 add MuseTalk checkpoints
 add LatentSync16 checkpoints
 
-## Installation
-- clone the repo
-- On debian based systems run bash debian_setup.sh
-
-## Run
-- for inference modify the scritpts/inference.py file add your video and audio path
-- run with uv run python -m scripts.inference
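For reference, the "modify inference.py" step in the new Run section amounts to editing the entrypoint call at the bottom of the script. A minimal sketch, where the media paths are placeholders of ours rather than assets shipped with the repo:

```python
# Hypothetical edit to the bottom of inference.py / dummy_inference.py:
# point main() at your own media (placeholder paths, not repo files).
if __name__ == "__main__":
    main(
        video_path="./assets/my_face.mp4",         # input talking-head video
        audio_path="./assets/my_speech.wav",       # driving audio track
        video_out_path="./outputs/my_result.mp4",  # where the synced video is written
    )
```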
assets/edge_cases/3d_model.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2d88bf3bac42cb5e3356c811670c04468ca471700693eb548553dc071f2a6bc
+size 827929
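These three-line stubs are not video data but Git LFS pointer files: the real bytes live in LFS storage, keyed by the sha256 oid. A minimal sketch of reading one (the helper name is ours, not part of this repo):

```python
# Minimal Git LFS pointer parser -- a sketch, not repo code.
# Pointer files are "key value" lines: version, oid, size.
from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# e.g. read_lfs_pointer("assets/edge_cases/3d_model.mp4")
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:b2d88bf3...', 'size': '827929'}
```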
assets/edge_cases/anime.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2463d1a2d576f2d4029c9f12eaa97b64b7df0a965832e210ab0188103586da5
+size 116889
assets/edge_cases/cartoon.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24ac39f09cde19091e5fb94085353317bd83025dfa25aae950a9a2a95297cea1
+size 628734
assets/edge_cases/full_body.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a2532cabb4db6b62eff944400e2c7d4fa6ac08a51d538be8bbdffe9a58f5e57
+size 399416
assets/edge_cases/lecun_belinda.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2285b59bd3c82907708c231a57db6c58a15e85a4d255ad643f51104da04484ff
+size 647176
assets/edge_cases/lecun_freeman.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9283b35cd77ceed4bea7c31813575d52e8e166ba6e9b55609d4bc60fcf4c7a68
+size 627516
assets/edge_cases/obama_side.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0339eaec4182bc681f72a851ea9347bc8a1e44ae0b4916c9c2b7b334419af29e
+size 930835
assets/edge_cases/sculpture.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3aa2f7ad977c375b681ddcfd25ffac71e0122c7a1663024999a633ef104f7f14
+size 459062
checkpoints/whisper/small.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794
+size 483617219
dummy_inference.py
ADDED
@@ -0,0 +1,82 @@
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from omegaconf import OmegaConf
+import torch
+from diffusers import AutoencoderKL, DDIMScheduler
+from latentsync.models.unet import UNet3DConditionModel
+from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
+# from diffusers.utils.import_utils import is_xformers_available
+from accelerate.utils import set_seed
+from latentsync.whisper.audio2feature import Audio2Feature
+
+def main(video_path, audio_path, video_out_path="./outputs/outvideo.mp4", unet_ckpt_path="./checkpoints/latentsync/latentsync_unet.pt", vae_path="./checkpoints/sd-vae-ft-mse", unet_config_path="configs/unet/second_stage.yaml", guidance_scale=1.0, seed=1247):
+    print(f"Input video path: {video_path}")
+    print(f"Input audio path: {audio_path}")
+    print(f"Loaded unet checkpoint path: {unet_ckpt_path}")
+    config = OmegaConf.load(unet_config_path)
+    scheduler = DDIMScheduler.from_pretrained("configs")
+
+    if config.model.cross_attention_dim == 768:
+        whisper_model_path = "checkpoints/whisper/small.pt"
+    elif config.model.cross_attention_dim == 384:
+        whisper_model_path = "checkpoints/whisper/tiny.pt"
+    else:
+        raise NotImplementedError("cross_attention_dim must be 768 or 384")
+
+    audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)
+
+    vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
+    vae.config.scaling_factor = 0.18215
+    vae.config.shift_factor = 0
+
+    unet, _ = UNet3DConditionModel.from_pretrained(
+        OmegaConf.to_container(config.model),
+        unet_ckpt_path,  # load checkpoint
+        device="cpu",
+    )
+
+    unet = unet.to(dtype=torch.float16)
+
+    pipeline = LipsyncPipeline(
+        vae=vae,
+        audio_encoder=audio_encoder,
+        unet=unet,
+        scheduler=scheduler,
+    ).to("cuda")
+
+    if seed != -1:
+        set_seed(seed)
+    else:
+        torch.seed()
+
+    print(f"Initial seed: {torch.initial_seed()}")
+
+    pipeline(
+        video_path=video_path,
+        audio_path=audio_path,
+        video_out_path=video_out_path,
+        video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
+        num_frames=config.data.num_frames,
+        num_inference_steps=config.run.inference_steps,
+        guidance_scale=guidance_scale,
+        weight_dtype=torch.float16,
+        width=config.data.resolution,
+        height=config.data.resolution,
+    )
+
+
+if __name__ == "__main__":
+    main("./assets/demo2_video.mp4", "./assets/demo1_audio.wav")
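dummy_inference.py picks the Whisper checkpoint from the UNet config's cross-attention width (and inference.py below repeats the same dispatch). Rewritten as a lookup table, the rule is just the following; a sketch of ours, not repo code:

```python
# The Whisper checkpoint is chosen by the UNet's cross-attention width.
# Whisper-small embeds audio at d_model=768, Whisper-tiny at d_model=384.
WHISPER_BY_DIM = {
    768: "checkpoints/whisper/small.pt",
    384: "checkpoints/whisper/tiny.pt",
}
whisper_model_path = WHISPER_BY_DIM.get(config.model.cross_attention_dim)
if whisper_model_path is None:
    raise NotImplementedError("cross_attention_dim must be 768 or 384")
```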
inference.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from omegaconf import OmegaConf
+import torch
+from diffusers import AutoencoderKL, DDIMScheduler
+from latentsync.models.unet import UNet3DConditionModel
+from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
+# from diffusers.utils.import_utils import is_xformers_available
+from accelerate.utils import set_seed
+from latentsync.whisper.audio2feature import Audio2Feature
+
+def main(video_path, audio_path, video_out_path="./outputs/outvideo.mp4", unet_ckpt_path="./checkpoints/latentsync/latentsync_unet.pt", vae_path="./checkpoints/sd-vae-ft-mse", unet_config_path="configs/unet/second_stage.yaml", guidance_scale=1.0, seed=1247):
+    print(f"Input video path: {video_path}")
+    print(f"Input audio path: {audio_path}")
+    print(f"Loaded unet checkpoint path: {unet_ckpt_path}")
+    config = OmegaConf.load(unet_config_path)
+    scheduler = DDIMScheduler.from_pretrained("configs")
+
+    if config.model.cross_attention_dim == 768:
+        whisper_model_path = "checkpoints/whisper/small.pt"
+    elif config.model.cross_attention_dim == 384:
+        whisper_model_path = "checkpoints/whisper/tiny.pt"
+    else:
+        raise NotImplementedError("cross_attention_dim must be 768 or 384")
+
+    audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)
+
+    vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
+    vae.config.scaling_factor = 0.18215
+    vae.config.shift_factor = 0
+
+    unet, _ = UNet3DConditionModel.from_pretrained(
+        OmegaConf.to_container(config.model),
+        unet_ckpt_path,  # load checkpoint
+        device="cpu",
+    )
+
+    unet = unet.to(dtype=torch.float16)
+
+    pipeline = LipsyncPipeline(
+        vae=vae,
+        audio_encoder=audio_encoder,
+        unet=unet,
+        scheduler=scheduler,
+    ).to("cuda")
+
+    if seed != -1:
+        set_seed(seed)
+    else:
+        torch.seed()
+
+    print(f"Initial seed: {torch.initial_seed()}")
+
+    pipeline(
+        video_path=video_path,
+        audio_path=audio_path,
+        video_out_path=video_out_path,
+        video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
+        num_frames=config.data.num_frames,
+        num_inference_steps=config.run.inference_steps,
+        guidance_scale=guidance_scale,
+        weight_dtype=torch.float16,
+        width=config.data.resolution,
+        height=config.data.resolution,
+    )
+
+
+import os
+def get_videos_from_path(path):
+    """Get all video files from a path, returns only filenames without extension"""
+    video_names = []
+
+    try:
+        # List all files in the directory
+        files = os.listdir(path)
+
+        # Filter for mp4 files
+        for file in files:
+            if file.lower().endswith('.mp4'):
+                # Remove the extension
+                name_without_ext = os.path.splitext(file)[0]
+                video_names.append(name_without_ext)
+    except FileNotFoundError:
+        print(f"Directory {path} not found")
+        return []
+
+    return video_names
+
+def get_audios_from_path(path):
+    """Get all audio files from a path, returns only filenames without extension"""
+    audio_names = []
+
+    try:
+        # List all files in the directory
+        files = os.listdir(path)
+
+        # Filter for wav files
+        for file in files:
+            if file.lower().endswith('.wav'):
+                # Remove the extension
+                name_without_ext = os.path.splitext(file)[0]
+                audio_names.append(name_without_ext)
+    except FileNotFoundError:
+        print(f"Directory {path} not found")
+        return []
+
+    return audio_names
+
+if __name__ == "__main__":
+    file_path = "./assets/edge_cases"
+    videos = get_videos_from_path(file_path)  # .mp4 files in the directory, names without extension
+    audios = get_audios_from_path(file_path)  # .wav files in the directory, names without extension
+    for audio in audios:
+        for video in videos:
+            print(video, audio)
+            output_path = "./outputs/" + video + "_" + audio + ".mp4"
+            try:
+                main(f"./assets/edge_cases/{video}.mp4", f"./assets/edge_cases/{audio}.wav", output_path)
+            except Exception as e:  # most often face detection fails on these edge cases
+                print(f"Couldn't detect faces: {e}")
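get_videos_from_path and get_audios_from_path above differ only in the extension they filter on; they could share one helper. A sketch of the consolidation (the name get_stems_by_ext is ours, not what the repo ships):

```python
# Possible consolidation of the two directory-listing helpers above.
import os

def get_stems_by_ext(path, ext):
    """Return filenames (without extension) in `path` ending with `ext`."""
    try:
        return [os.path.splitext(f)[0] for f in os.listdir(path)
                if f.lower().endswith(ext)]
    except FileNotFoundError:
        print(f"Directory {path} not found")
        return []

videos = get_stems_by_ext("./assets/edge_cases", ".mp4")
audios = get_stems_by_ext("./assets/edge_cases", ".wav")
```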
latentsync/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# latentsync/__init__.py
+from .pipelines.lipsync_pipeline import LipsyncPipeline
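This re-export lets callers import the pipeline from the package root instead of the full module path:

```python
# Enabled by the re-export in latentsync/__init__.py:
from latentsync import LipsyncPipeline
# equivalent to: from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
```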
latentsync/models/__init__.py
ADDED
@@ -0,0 +1,20 @@
+# latentsync/models/__init__.py
+
+from .attention import *
+from .resnet import *
+# from .syncnet import *
+from .syncnet_wav2lip import *
+from .unet import *
+from .unet_blocks import *
+from .utils import *
+
+__all__ = [
+    "Attention",
+    "ResNet",
+    "SyncNet",
+    "SyncNetWav2Lip",
+    "UNet",
+    "UNetBlocks",
+    "utils"
+]
+
latentsync/models/attention.py
ADDED
@@ -0,0 +1,492 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+from dataclasses import dataclass
+# from turtle import forward  # stray auto-import: unused, and fails on headless systems; disabled
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import Attention as CrossAttention, FeedForward, AdaLayerNorm
+
+from einops import rearrange, repeat
+from .utils import zero_module
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+    sample: torch.FloatTensor
+
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        use_motion_module: bool = False,
+        unet_use_cross_frame_attention=None,
+        unet_use_temporal_attention=None,
+        add_audio_layer=False,
+        audio_condition_method="cross_attn",
+        custom_audio_layer: bool = False,
+    ):
+        super().__init__()
+        self.use_linear_projection = use_linear_projection
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # Define input layers
+        self.in_channels = in_channels
+
+        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        if use_linear_projection:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        if not custom_audio_layer:
+            # Define transformers blocks
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        inner_dim,
+                        num_attention_heads,
+                        attention_head_dim,
+                        dropout=dropout,
+                        cross_attention_dim=cross_attention_dim,
+                        activation_fn=activation_fn,
+                        num_embeds_ada_norm=num_embeds_ada_norm,
+                        attention_bias=attention_bias,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        use_motion_module=use_motion_module,
+                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+                        unet_use_temporal_attention=unet_use_temporal_attention,
+                        add_audio_layer=add_audio_layer,
+                        custom_audio_layer=custom_audio_layer,
+                        audio_condition_method=audio_condition_method,
+                    )
+                    for d in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    AudioTransformerBlock(
+                        inner_dim,
+                        num_attention_heads,
+                        attention_head_dim,
+                        dropout=dropout,
+                        cross_attention_dim=cross_attention_dim,
+                        activation_fn=activation_fn,
+                        num_embeds_ada_norm=num_embeds_ada_norm,
+                        attention_bias=attention_bias,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        use_motion_module=use_motion_module,
+                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+                        unet_use_temporal_attention=unet_use_temporal_attention,
+                        add_audio_layer=add_audio_layer,
+                    )
+                    for d in range(num_layers)
+                ]
+            )
+
+        # 4. Define output layers
+        if use_linear_projection:
+            self.proj_out = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+        if custom_audio_layer:
+            self.proj_out = zero_module(self.proj_out)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+        # Input
+        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+        video_length = hidden_states.shape[2]
+        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+        # No need to do this for audio input, because different audio samples are independent
+        # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        # Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                timestep=timestep,
+                video_length=video_length,
+            )
+
+        # Output
+        if not self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+        if not return_dict:
+            return (output,)
+
+        return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        use_motion_module: bool = False,
+        unet_use_cross_frame_attention=None,
+        unet_use_temporal_attention=None,
+        add_audio_layer=False,
+        custom_audio_layer=False,
+        audio_condition_method="cross_attn",
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
+        self.unet_use_temporal_attention = unet_use_temporal_attention
+        self.use_motion_module = use_motion_module
+        self.add_audio_layer = add_audio_layer
+
+        # SC-Attn
+        assert unet_use_cross_frame_attention is not None
+        if unet_use_cross_frame_attention:
+            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
+        else:
+            self.attn1 = CrossAttention(
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )
+        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+        # Cross-Attn
+        if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
+            self.audio_cross_attn = AudioCrossAttn(
+                dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=num_attention_heads,
+                attention_head_dim=attention_head_dim,
+                dropout=dropout,
+                attention_bias=attention_bias,
+                upcast_attention=upcast_attention,
+                num_embeds_ada_norm=num_embeds_ada_norm,
+                use_ada_layer_norm=self.use_ada_layer_norm,
+                zero_proj_out=False,
+            )
+        else:
+            self.audio_cross_attn = None
+
+        # Feed-forward
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+        self.norm3 = nn.LayerNorm(dim)
+
+        # Temp-Attn
+        assert unet_use_temporal_attention is not None
+        if unet_use_temporal_attention:
+            self.attn_temp = CrossAttention(
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )
+            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            print("Here is how to install it")
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        if self.audio_cross_attn is not None:
+            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
+                use_memory_efficient_attention_xformers
+            )
+        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+    def forward(
+        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
+    ):
+        # SparseCausal-Attention
+        norm_hidden_states = (
+            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+        )
+
+        # if self.only_cross_attention:
+        #     hidden_states = (
+        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+        #     )
+        # else:
+        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+        # pdb.set_trace()
+        if self.unet_use_cross_frame_attention:
+            hidden_states = (
+                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
+                + hidden_states
+            )
+        else:
+            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
+            hidden_states = self.audio_cross_attn(
+                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+            )
+
+        # Feed-forward
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+        # Temporal-Attention
+        if self.unet_use_temporal_attention:
+            d = hidden_states.shape[1]
+            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+            norm_hidden_states = (
+                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+            )
+            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+        return hidden_states
+
+
+class AudioTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        use_motion_module: bool = False,
+        unet_use_cross_frame_attention=None,
+        unet_use_temporal_attention=None,
+        add_audio_layer=False,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
+        self.unet_use_temporal_attention = unet_use_temporal_attention
+        self.use_motion_module = use_motion_module
+        self.add_audio_layer = add_audio_layer
+
+        # SC-Attn
+        assert unet_use_cross_frame_attention is not None
+        if unet_use_cross_frame_attention:
+            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
+        else:
+            self.attn1 = CrossAttention(
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )
+        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+        self.audio_cross_attn = AudioCrossAttn(
+            dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            dropout=dropout,
+            attention_bias=attention_bias,
+            upcast_attention=upcast_attention,
+            num_embeds_ada_norm=num_embeds_ada_norm,
+            use_ada_layer_norm=self.use_ada_layer_norm,
+            zero_proj_out=False,
+        )
+
+        # Feed-forward
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+        self.norm3 = nn.LayerNorm(dim)
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            print("Here is how to install it")
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        if self.audio_cross_attn is not None:
+            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
+                use_memory_efficient_attention_xformers
+            )
+        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+    def forward(
+        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
+    ):
+        # SparseCausal-Attention
+        norm_hidden_states = (
+            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+        )
+
+        # pdb.set_trace()
+        if self.unet_use_cross_frame_attention:
+            hidden_states = (
+                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
+                + hidden_states
+            )
+        else:
+            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
+            hidden_states = self.audio_cross_attn(
+                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+            )
+
+        # Feed-forward
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+        return hidden_states
+
+
+class AudioCrossAttn(nn.Module):
+    def __init__(
+        self,
+        dim,
+        cross_attention_dim,
+        num_attention_heads,
+        attention_head_dim,
+        dropout,
+        attention_bias,
+        upcast_attention,
+        num_embeds_ada_norm,
+        use_ada_layer_norm,
+        zero_proj_out=False,
+    ):
+        super().__init__()
+
+        self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
+        self.attn = CrossAttention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            upcast_attention=upcast_attention,
+        )
+
+        if zero_proj_out:
+            self.proj_out = zero_module(nn.Linear(dim, dim))
+
+        self.zero_proj_out = zero_proj_out
+        self.use_ada_layer_norm = use_ada_layer_norm
+
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
+        previous_hidden_states = hidden_states
+        hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)
+
+        if encoder_hidden_states.dim() == 4:
+            encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")
+
+        hidden_states = self.attn(
+            hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+        )
+
+        if self.zero_proj_out:
+            hidden_states = self.proj_out(hidden_states)
+        return hidden_states + previous_hidden_states
latentsync/models/motion_module.py
ADDED
@@ -0,0 +1,332 @@
+# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
+
+# Actually we don't use the motion module in the final version of LatentSync
+# When we started the project, we used the codebase of AnimateDiff and tried the motion module
+# But the results were poor, and we decided to leave the code here for possible future usage
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import Attention as CrossAttention, FeedForward
+
+from einops import rearrange, repeat
+import math
+from .utils import zero_module
+
+
+@dataclass
+class TemporalTransformer3DModelOutput(BaseOutput):
+    sample: torch.FloatTensor
+
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
+    if motion_module_type == "Vanilla":
+        return VanillaTemporalModule(
+            in_channels=in_channels,
+            **motion_module_kwargs,
+        )
+    else:
+        raise ValueError
+
+
+class VanillaTemporalModule(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        num_attention_heads=8,
+        num_transformer_block=2,
+        attention_block_types=("Temporal_Self", "Temporal_Self"),
+        cross_frame_attention_mode=None,
+        temporal_position_encoding=False,
+        temporal_position_encoding_max_len=24,
+        temporal_attention_dim_div=1,
+        zero_initialize=True,
+    ):
+        super().__init__()
+
+        self.temporal_transformer = TemporalTransformer3DModel(
+            in_channels=in_channels,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
+            num_layers=num_transformer_block,
+            attention_block_types=attention_block_types,
+            cross_frame_attention_mode=cross_frame_attention_mode,
+            temporal_position_encoding=temporal_position_encoding,
+            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+        )
+
+        if zero_initialize:
+            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+
+    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
+        hidden_states = input_tensor
+        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+
+        output = hidden_states
+        return output
+
+
+class TemporalTransformer3DModel(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        num_attention_heads,
+        attention_head_dim,
+        num_layers,
+        attention_block_types=(
+            "Temporal_Self",
+            "Temporal_Self",
+        ),
+        dropout=0.0,
+        norm_num_groups=32,
+        cross_attention_dim=768,
+        activation_fn="geglu",
+        attention_bias=False,
+        upcast_attention=False,
+        cross_frame_attention_mode=None,
+        temporal_position_encoding=False,
+        temporal_position_encoding_max_len=24,
+    ):
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                TemporalTransformerBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    attention_block_types=attention_block_types,
+                    dropout=dropout,
+                    norm_num_groups=norm_num_groups,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    upcast_attention=upcast_attention,
+                    cross_frame_attention_mode=cross_frame_attention_mode,
+                    temporal_position_encoding=temporal_position_encoding,
+                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                )
+                for d in range(num_layers)
+            ]
+        )
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+        video_length = hidden_states.shape[2]
+        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = self.proj_in(hidden_states)
+
+        # Transformer Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
+            )
+
+        # output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+
+        return output
+
+
+class TemporalTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_attention_heads,
+        attention_head_dim,
+        attention_block_types=(
+            "Temporal_Self",
+            "Temporal_Self",
+        ),
+        dropout=0.0,
+        norm_num_groups=32,
+        cross_attention_dim=768,
+        activation_fn="geglu",
+        attention_bias=False,
+        upcast_attention=False,
+        cross_frame_attention_mode=None,
+        temporal_position_encoding=False,
+        temporal_position_encoding_max_len=24,
+    ):
+        super().__init__()
+
+        attention_blocks = []
+        norms = []
+
+        for block_name in attention_block_types:
+            attention_blocks.append(
+                VersatileAttention(
+                    attention_mode=block_name.split("_")[0],
+                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
+                    query_dim=dim,
+                    heads=num_attention_heads,
+                    dim_head=attention_head_dim,
+                    dropout=dropout,
+                    bias=attention_bias,
+                    upcast_attention=upcast_attention,
+                    cross_frame_attention_mode=cross_frame_attention_mode,
+                    temporal_position_encoding=temporal_position_encoding,
+                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                )
+            )
+            norms.append(nn.LayerNorm(dim))
+
+        self.attention_blocks = nn.ModuleList(attention_blocks)
+        self.norms = nn.ModuleList(norms)
+
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+        self.ff_norm = nn.LayerNorm(dim)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+        for attention_block, norm in zip(self.attention_blocks, self.norms):
+            norm_hidden_states = norm(hidden_states)
+            hidden_states = (
+                attention_block(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+                    video_length=video_length,
+                )
+                + hidden_states
+            )
+
+        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
+
+        output = hidden_states
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, dropout=0.0, max_len=24):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(1, max_len, d_model)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        x = x + self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class VersatileAttention(CrossAttention):
+    def __init__(
+        self,
+        attention_mode=None,
+        cross_frame_attention_mode=None,
+        temporal_position_encoding=False,
+        temporal_position_encoding_max_len=24,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        assert attention_mode == "Temporal"
+
+        self.attention_mode = attention_mode
+        self.is_cross_attention = kwargs["cross_attention_dim"] is not None
+
+        self.pos_encoder = (
+            PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
+            if (temporal_position_encoding and attention_mode == "Temporal")
+            else None
+        )
+
+    def extra_repr(self):
+        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        if self.attention_mode == "Temporal":
+            d = hidden_states.shape[1]
+            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+            if self.pos_encoder is not None:
+                hidden_states = self.pos_encoder(hidden_states)
+
+            encoder_hidden_states = (
+                repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
+                if encoder_hidden_states is not None
+                else encoder_hidden_states
+            )
+        else:
+            raise NotImplementedError
+
+        # encoder_hidden_states = encoder_hidden_states
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = self.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = self.reshape_heads_to_batch_dim(query)
+
+        if self.added_kv_proj_dim is not None:
+            raise NotImplementedError
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
+
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+        # attention, what we cannot get enough of
+        if self._use_memory_efficient_attention_xformers:
+            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
+        else:
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value, attention_mask)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
[diff truncated here]
+
|
| 326 |
+
# dropout
|
| 327 |
+
hidden_states = self.to_out[1](hidden_states)
|
| 328 |
+
|
| 329 |
+
if self.attention_mode == "Temporal":
|
| 330 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
| 331 |
+
|
| 332 |
+
return hidden_states
|
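As a quick orientation to the temporal block above, here is a minimal smoke test (my own sketch, not part of the upload; the shapes are illustrative assumptions and it presumes a diffusers version that still ships `CrossAttention`):

import torch

# Hypothetical sizes: 2 clips x 8 frames of 16x16 latents with 320 channels.
block = TemporalTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    temporal_position_encoding=True,  # enables the sinusoidal PositionalEncoding
)

b, f, d, c = 2, 8, 16 * 16, 320
hidden_states = torch.randn(b * f, d, c)  # frames are folded into the batch axis
out = block(hidden_states, video_length=f)
assert out.shape == hidden_states.shape  # the block is shape-preserving

Internally, `VersatileAttention` regroups the tokens so attention runs along the frame axis for each spatial location, which is what makes the block temporal rather than spatial.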
latentsync/models/resnet.py
ADDED
@@ -0,0 +1,234 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm is not None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            time_emb_proj_out_channels = out_channels
            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "scale_shift":
            self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
        else:
            self.double_len_linear = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            if temb.dim() == 2:
                # input (1, 1280)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                temb = temb[:, :, None, None, None]  # unsqueeze
            else:
                # input (1, 1280, 16)
                temb = temb.permute(0, 2, 1)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                if self.double_len_linear is not None:
                    temb = self.double_len_linear(self.nonlinearity(temb))
                temb = temb.permute(0, 2, 1)
                temb = temb[:, :, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
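A short illustration of the "inflation" trick these classes share (my own toy example, not from the upload): the frame axis is folded into the batch axis so that ordinary 2D convolution weights are applied to every frame independently, which lets pretrained 2D weights be reused on video.

import torch

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)  # weights are plain nn.Conv2d weights
video = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, frames, height, width)
out = conv(video)
print(out.shape)  # torch.Size([2, 8, 16, 32, 32]): spatial-only convolution, per frame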
latentsync/models/syncnet.py
ADDED
@@ -0,0 +1,233 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available

from ..utils.util import cosine_loss


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )

        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )

        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        hidden_states += input_tensor

        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)

        return hidden_states


class AttentionBlock2D(nn.Module):
    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "You have to install xformers to enable memory efficient attention", name="xformers"
            )
        # inner_dim = dim_head * heads
        self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)

        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")

        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)

        self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
        self.attn._use_memory_efficient_attention_xformers = True

    def forward(self, hidden_states):
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.conv_in(hidden_states)
        hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")

        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
        hidden_states = self.conv_out(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class DownEncoder2D(nn.Module):
    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=[64, 128, 256, 256],
        downsample_factors=[2, 2, 2, 2],
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=[1, 1, 1, 1],
        dropout: float = 0.0,
        act_fn="silu",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # in
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = block_out_channel

            down_block = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=output_channels,
                downsample_factor=downsample_factors[i],
                norm_num_groups=norm_num_groups,
                dropout=dropout,
                act_fn=act_fn,
            )

            self.down_blocks.append(down_block)

            if attn_blocks[i] == 1:
                attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
                self.down_blocks.append(attention_block)

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        # post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.act_fn_out(hidden_states)

        return hidden_states
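The unit-normalized embedding pair is typically scored with a cosine-similarity loss; the file imports `cosine_loss` from `latentsync.utils.util` for exactly that. Below is a rough, self-contained stand-in (the [-1, 1] to (0, 1) mapping, the BCE form, and the 2048-dim embeddings are my assumptions, not the repo's exact implementation):

import torch
import torch.nn.functional as F

def cosine_sync_loss(vision_embeds, audio_embeds, labels):
    # Both inputs are already unit vectors, so cosine similarity is a dot product.
    sims = (vision_embeds * audio_embeds).sum(dim=1)
    # Map [-1, 1] similarities into (0, 1) before binary cross-entropy.
    probs = (sims + 1) / 2
    return F.binary_cross_entropy(probs, labels)

v = F.normalize(torch.randn(4, 2048), dim=1)  # embedding width is illustrative
a = F.normalize(torch.randn(4, 2048), dim=1)
y = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 1 = in-sync pair, 0 = shifted pair
print(cosine_sync_loss(v, a, y))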
latentsync/models/syncnet_wav2lip.py
ADDED
@@ -0,0 +1,90 @@
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.

from torch import nn
from torch.nn import functional as F


class SyncNetWav2Lip(nn.Module):
    def __init__(self, act_fn="leaky"):
        super().__init__()

        # input image sequences: (15, 128, 256)
        self.visual_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn),  # (128, 256)
            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn),  # (126, 127)
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (63, 64)
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn),  # (21, 22)
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (11, 11)
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (6, 6)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"),  # (3, 3)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

        # input audio sequences: (1, 80, 16)
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn),  # (27, 16)
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn),  # (9, 6)
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn),  # (3, 3)
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "tanh":
            self.act_fn = nn.Tanh()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        elif act_fn == "leaky":
            self.act_fn = nn.LeakyReLU(0.2, inplace=True)

        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act_fn(out)
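A shape sanity check for this ablation model (toy tensors of mine; the sizes follow the comments in the two encoders, which assume five RGB frames stacked channel-wise and 16 mel-spectrogram steps):

import torch

model = SyncNetWav2Lip(act_fn="leaky").eval()
images = torch.randn(2, 15, 128, 256)  # 5 RGB frames stacked along channels
mels = torch.randn(2, 1, 80, 16)       # 16 steps of an 80-bin mel spectrogram
with torch.no_grad():
    vision_embeds, audio_embeds = model(images, mels)
print(vision_embeds.shape, audio_embeds.shape)  # both (2, 1024), unit-normalized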
latentsync/models/unet.py
ADDED
@@ -0,0 +1,528 @@
| 1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers import UNet2DConditionModel
|
| 14 |
+
from diffusers.utils import BaseOutput, logging
|
| 15 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 16 |
+
from .unet_blocks import (
|
| 17 |
+
CrossAttnDownBlock3D,
|
| 18 |
+
CrossAttnUpBlock3D,
|
| 19 |
+
DownBlock3D,
|
| 20 |
+
UNetMidBlock3DCrossAttn,
|
| 21 |
+
UpBlock3D,
|
| 22 |
+
get_down_block,
|
| 23 |
+
get_up_block,
|
| 24 |
+
)
|
| 25 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
| 26 |
+
|
| 27 |
+
from ..utils.util import zero_rank_log
|
| 28 |
+
from einops import rearrange
|
| 29 |
+
from .utils import zero_module
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class UNet3DConditionOutput(BaseOutput):
|
| 37 |
+
sample: torch.FloatTensor
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
| 41 |
+
_supports_gradient_checkpointing = True
|
| 42 |
+
|
| 43 |
+
@register_to_config
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
sample_size: Optional[int] = None,
|
| 47 |
+
in_channels: int = 4,
|
| 48 |
+
out_channels: int = 4,
|
| 49 |
+
center_input_sample: bool = False,
|
| 50 |
+
flip_sin_to_cos: bool = True,
|
| 51 |
+
freq_shift: int = 0,
|
| 52 |
+
down_block_types: Tuple[str] = (
|
| 53 |
+
"CrossAttnDownBlock3D",
|
| 54 |
+
"CrossAttnDownBlock3D",
|
| 55 |
+
"CrossAttnDownBlock3D",
|
| 56 |
+
"DownBlock3D",
|
| 57 |
+
),
|
| 58 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
| 59 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
| 60 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 61 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
| 62 |
+
layers_per_block: int = 2,
|
| 63 |
+
downsample_padding: int = 1,
|
| 64 |
+
mid_block_scale_factor: float = 1,
|
| 65 |
+
act_fn: str = "silu",
|
| 66 |
+
norm_num_groups: int = 32,
|
| 67 |
+
norm_eps: float = 1e-5,
|
| 68 |
+
cross_attention_dim: int = 1280,
|
| 69 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
| 70 |
+
dual_cross_attention: bool = False,
|
| 71 |
+
use_linear_projection: bool = False,
|
| 72 |
+
class_embed_type: Optional[str] = None,
|
| 73 |
+
num_class_embeds: Optional[int] = None,
|
| 74 |
+
upcast_attention: bool = False,
|
| 75 |
+
resnet_time_scale_shift: str = "default",
|
| 76 |
+
use_inflated_groupnorm=False,
|
| 77 |
+
# Additional
|
| 78 |
+
use_motion_module=False,
|
| 79 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
| 80 |
+
motion_module_mid_block=False,
|
| 81 |
+
motion_module_decoder_only=False,
|
| 82 |
+
motion_module_type=None,
|
| 83 |
+
motion_module_kwargs={},
|
| 84 |
+
unet_use_cross_frame_attention=False,
|
| 85 |
+
unet_use_temporal_attention=False,
|
| 86 |
+
add_audio_layer=False,
|
| 87 |
+
audio_condition_method: str = "cross_attn",
|
| 88 |
+
custom_audio_layer=False,
|
| 89 |
+
):
|
| 90 |
+
super().__init__()
|
| 91 |
+
|
| 92 |
+
self.sample_size = sample_size
|
| 93 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 94 |
+
self.use_motion_module = use_motion_module
|
| 95 |
+
self.add_audio_layer = add_audio_layer
|
| 96 |
+
|
| 97 |
+
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
|
| 98 |
+
|
| 99 |
+
# time
|
| 100 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 101 |
+
timestep_input_dim = block_out_channels[0]
|
| 102 |
+
|
| 103 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 104 |
+
|
| 105 |
+
# class embedding
|
| 106 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 107 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 108 |
+
elif class_embed_type == "timestep":
|
| 109 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 110 |
+
elif class_embed_type == "identity":
|
| 111 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 112 |
+
else:
|
| 113 |
+
self.class_embedding = None
|
| 114 |
+
|
| 115 |
+
self.down_blocks = nn.ModuleList([])
|
| 116 |
+
self.mid_block = None
|
| 117 |
+
self.up_blocks = nn.ModuleList([])
|
| 118 |
+
|
| 119 |
+
if isinstance(only_cross_attention, bool):
|
| 120 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 121 |
+
|
| 122 |
+
if isinstance(attention_head_dim, int):
|
| 123 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 124 |
+
|
| 125 |
+
# down
|
| 126 |
+
output_channel = block_out_channels[0]
|
| 127 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 128 |
+
res = 2**i
|
| 129 |
+
input_channel = output_channel
|
| 130 |
+
output_channel = block_out_channels[i]
|
| 131 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 132 |
+
|
| 133 |
+
down_block = get_down_block(
|
| 134 |
+
down_block_type,
|
| 135 |
+
num_layers=layers_per_block,
|
| 136 |
+
in_channels=input_channel,
|
| 137 |
+
out_channels=output_channel,
|
| 138 |
+
temb_channels=time_embed_dim,
|
| 139 |
+
add_downsample=not is_final_block,
|
| 140 |
+
resnet_eps=norm_eps,
|
| 141 |
+
resnet_act_fn=act_fn,
|
| 142 |
+
resnet_groups=norm_num_groups,
|
| 143 |
+
cross_attention_dim=cross_attention_dim,
|
| 144 |
+
attn_num_head_channels=attention_head_dim[i],
|
| 145 |
+
downsample_padding=downsample_padding,
|
| 146 |
+
dual_cross_attention=dual_cross_attention,
|
| 147 |
+
use_linear_projection=use_linear_projection,
|
| 148 |
+
only_cross_attention=only_cross_attention[i],
|
| 149 |
+
upcast_attention=upcast_attention,
|
| 150 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 151 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 152 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 153 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 154 |
+
use_motion_module=use_motion_module
|
| 155 |
+
and (res in motion_module_resolutions)
|
| 156 |
+
and (not motion_module_decoder_only),
|
| 157 |
+
motion_module_type=motion_module_type,
|
| 158 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 159 |
+
add_audio_layer=add_audio_layer,
|
| 160 |
+
audio_condition_method=audio_condition_method,
|
| 161 |
+
custom_audio_layer=custom_audio_layer,
|
| 162 |
+
)
|
| 163 |
+
self.down_blocks.append(down_block)
|
| 164 |
+
|
| 165 |
+
# mid
|
| 166 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
| 167 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
| 168 |
+
in_channels=block_out_channels[-1],
|
| 169 |
+
temb_channels=time_embed_dim,
|
| 170 |
+
resnet_eps=norm_eps,
|
| 171 |
+
resnet_act_fn=act_fn,
|
| 172 |
+
output_scale_factor=mid_block_scale_factor,
|
| 173 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 174 |
+
cross_attention_dim=cross_attention_dim,
|
| 175 |
+
attn_num_head_channels=attention_head_dim[-1],
|
| 176 |
+
resnet_groups=norm_num_groups,
|
| 177 |
+
dual_cross_attention=dual_cross_attention,
|
| 178 |
+
use_linear_projection=use_linear_projection,
|
| 179 |
+
upcast_attention=upcast_attention,
|
| 180 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 181 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 182 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 183 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
| 184 |
+
motion_module_type=motion_module_type,
|
| 185 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 186 |
+
add_audio_layer=add_audio_layer,
|
| 187 |
+
audio_condition_method=audio_condition_method,
|
| 188 |
+
custom_audio_layer=custom_audio_layer,
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
| 192 |
+
|
| 193 |
+
# count how many layers upsample the videos
|
| 194 |
+
self.num_upsamplers = 0
|
| 195 |
+
|
| 196 |
+
# up
|
| 197 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 198 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
| 199 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 200 |
+
output_channel = reversed_block_out_channels[0]
|
| 201 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 202 |
+
res = 2 ** (3 - i)
|
| 203 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 204 |
+
|
| 205 |
+
prev_output_channel = output_channel
|
| 206 |
+
output_channel = reversed_block_out_channels[i]
|
| 207 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 208 |
+
|
| 209 |
+
# add upsample block for all BUT final layer
|
| 210 |
+
if not is_final_block:
|
| 211 |
+
add_upsample = True
|
| 212 |
+
self.num_upsamplers += 1
|
| 213 |
+
else:
|
| 214 |
+
add_upsample = False
|
| 215 |
+
|
| 216 |
+
up_block = get_up_block(
|
| 217 |
+
up_block_type,
|
| 218 |
+
num_layers=layers_per_block + 1,
|
| 219 |
+
in_channels=input_channel,
|
| 220 |
+
out_channels=output_channel,
|
| 221 |
+
prev_output_channel=prev_output_channel,
|
| 222 |
+
temb_channels=time_embed_dim,
|
| 223 |
+
add_upsample=add_upsample,
|
| 224 |
+
resnet_eps=norm_eps,
|
| 225 |
+
resnet_act_fn=act_fn,
|
| 226 |
+
resnet_groups=norm_num_groups,
|
| 227 |
+
cross_attention_dim=cross_attention_dim,
|
| 228 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
| 229 |
+
dual_cross_attention=dual_cross_attention,
|
| 230 |
+
use_linear_projection=use_linear_projection,
|
| 231 |
+
only_cross_attention=only_cross_attention[i],
|
| 232 |
+
upcast_attention=upcast_attention,
|
| 233 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 234 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 235 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 236 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 237 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
| 238 |
+
motion_module_type=motion_module_type,
|
| 239 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 240 |
+
add_audio_layer=add_audio_layer,
|
| 241 |
+
audio_condition_method=audio_condition_method,
|
| 242 |
+
custom_audio_layer=custom_audio_layer,
|
| 243 |
+
)
|
| 244 |
+
self.up_blocks.append(up_block)
|
| 245 |
+
prev_output_channel = output_channel
|
| 246 |
+
|
| 247 |
+
# out
|
| 248 |
+
if use_inflated_groupnorm:
|
| 249 |
+
self.conv_norm_out = InflatedGroupNorm(
|
| 250 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
| 251 |
+
)
|
| 252 |
+
else:
|
| 253 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 254 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
| 255 |
+
)
|
| 256 |
+
self.conv_act = nn.SiLU()
|
| 257 |
+
|
| 258 |
+
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
|
| 259 |
+
|
| 260 |
+
def set_attention_slice(self, slice_size):
|
| 261 |
+
r"""
|
| 262 |
+
Enable sliced attention computation.
|
| 263 |
+
|
| 264 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
| 265 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 269 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
| 270 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
| 271 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 272 |
+
must be a multiple of `slice_size`.
|
| 273 |
+
"""
|
| 274 |
+
sliceable_head_dims = []
|
| 275 |
+
|
| 276 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
| 277 |
+
if hasattr(module, "set_attention_slice"):
|
| 278 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 279 |
+
|
| 280 |
+
for child in module.children():
|
| 281 |
+
fn_recursive_retrieve_slicable_dims(child)
|
| 282 |
+
|
| 283 |
+
# retrieve number of attention layers
|
| 284 |
+
for module in self.children():
|
| 285 |
+
fn_recursive_retrieve_slicable_dims(module)
|
| 286 |
+
|
| 287 |
+
num_slicable_layers = len(sliceable_head_dims)
|
| 288 |
+
|
| 289 |
+
if slice_size == "auto":
|
| 290 |
+
# half the attention head size is usually a good trade-off between
|
| 291 |
+
# speed and memory
|
| 292 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 293 |
+
elif slice_size == "max":
|
| 294 |
+
# make smallest slice possible
|
| 295 |
+
slice_size = num_slicable_layers * [1]
|
| 296 |
+
|
| 297 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
| 298 |
+
|
| 299 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 302 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
for i in range(len(slice_size)):
|
| 306 |
+
size = slice_size[i]
|
| 307 |
+
dim = sliceable_head_dims[i]
|
| 308 |
+
if size is not None and size > dim:
|
| 309 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 310 |
+
|
| 311 |
+
# Recursively walk through all the children.
|
| 312 |
+
# Any children which exposes the set_attention_slice method
|
| 313 |
+
# gets the message
|
| 314 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
| 315 |
+
if hasattr(module, "set_attention_slice"):
|
| 316 |
+
module.set_attention_slice(slice_size.pop())
|
| 317 |
+
|
| 318 |
+
for child in module.children():
|
| 319 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 320 |
+
|
| 321 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 322 |
+
for module in self.children():
|
| 323 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 324 |
+
|
| 325 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 326 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
| 327 |
+
module.gradient_checkpointing = value
|
| 328 |
+
|
| 329 |
+
def forward(
|
| 330 |
+
self,
|
| 331 |
+
sample: torch.FloatTensor,
|
| 332 |
+
timestep: Union[torch.Tensor, float, int],
|
| 333 |
+
encoder_hidden_states: torch.Tensor,
|
| 334 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 335 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 336 |
+
# support controlnet
|
| 337 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 338 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 339 |
+
return_dict: bool = True,
|
| 340 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
| 341 |
+
r"""
|
| 342 |
+
Args:
|
| 343 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 344 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 345 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 346 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 347 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
| 351 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 352 |
+
returning a tuple, the first element is the sample tensor.
|
| 353 |
+
"""
|
| 354 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 355 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
| 356 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 357 |
+
# on the fly if necessary.
|
| 358 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 359 |
+
|
| 360 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 361 |
+
forward_upsample_size = False
|
| 362 |
+
upsample_size = None
|
| 363 |
+
|
| 364 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 365 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
| 366 |
+
forward_upsample_size = True
|
| 367 |
+
|
| 368 |
+
# prepare attention_mask
|
| 369 |
+
if attention_mask is not None:
|
| 370 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 371 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 372 |
+
|
| 373 |
+
# center input if necessary
|
| 374 |
+
if self.config.center_input_sample:
|
| 375 |
+
sample = 2 * sample - 1.0
|
| 376 |
+
|
| 377 |
+
# time
|
| 378 |
+
timesteps = timestep
|
| 379 |
+
if not torch.is_tensor(timesteps):
|
| 380 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `timesteps` does not contain any weights and will always return f32 tensors,
        # but the time embedding might actually be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )

            down_block_res_samples += res_samples

        # support ControlNet
        down_block_res_samples = list(down_block_res_samples)
        if down_block_additional_residuals is not None:
            for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
                if down_block_additional_residual.dim() == 4:  # broadcast
                    down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
                down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # support ControlNet
        if mid_block_additional_residual is not None:
            if mid_block_additional_residual.dim() == 4:  # broadcast
                mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
            sample = sample + mid_block_additional_residual

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    def load_state_dict(self, state_dict, strict=True):
        # Drop weights whose shapes no longer match the config, e.g. when the loaded
        # checkpoint's in_channels or out_channels differ from the current config.
        temp_state_dict = copy.deepcopy(state_dict)
        if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
            del temp_state_dict["conv_in.weight"]
            del temp_state_dict["conv_in.bias"]
        if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
            del temp_state_dict["conv_out.weight"]
            del temp_state_dict["conv_out.bias"]

        # If the loaded checkpoint's cross_attention_dim is different from the config
        keys_to_remove = []
        for key in temp_state_dict:
            if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
                if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
                    keys_to_remove.append(key)

        for key in keys_to_remove:
            del temp_state_dict[key]

        return super().load_state_dict(state_dict=temp_state_dict, strict=strict)

    @classmethod
    def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
        unet = cls.from_config(model_config).to(device)
        if ckpt_path != "":
            zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
            ckpt = torch.load(ckpt_path, map_location=device)
            if "global_step" in ckpt:
                zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
                resume_global_step = ckpt["global_step"]
            else:
                resume_global_step = 0
            state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
            unet.load_state_dict(state_dict, strict=False)
        else:
            resume_global_step = 0

        return unet, resume_global_step
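A minimal sketch of loading this UNet through the classmethod above; the config path and the `model` key are illustrative assumptions, not files documented by this upload. Note that `load_state_dict(strict=False)` silently drops the shape-mismatched `conv_in`/`conv_out`/audio cross-attention weights it deleted, which is what makes loading checkpoints with a different `in_channels` or `cross_attention_dim` possible.

```python
from omegaconf import OmegaConf
from latentsync.models.unet import UNet3DConditionModel

# "configs/unet.yaml" and the ".model" key are hypothetical; substitute the repo's real UNet config.
unet_config = OmegaConf.to_container(OmegaConf.load("configs/unet.yaml").model)

# A non-empty ckpt_path resumes from the checkpoint's global_step; "" skips loading (step 0).
unet, global_step = UNet3DConditionModel.from_pretrained(
    unet_config, "checkpoints/latentsync_unet.pt", device="cpu"  # hypothetical checkpoint path
)
unet.eval()
```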
latentsync/models/unet_blocks.py
ADDED
@@ -0,0 +1,903 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py

import torch
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
    audio_condition_method="cross_attn",
    custom_audio_layer=False,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer,
            audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{down_block_type} does not exist.")

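For orientation, a sketch of how this factory might be invoked to build one audio-conditioned encoder stage; every dimension and flag below is an illustrative assumption, not a value from a shipped config.

```python
# Illustrative call; all sizes are assumptions, not values from the repo's configs.
block = get_down_block(
    "CrossAttnDownBlock3D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    attn_num_head_channels=8,
    resnet_groups=32,
    cross_attention_dim=384,  # e.g. an audio-feature width (assumption)
    downsample_padding=1,
    add_audio_layer=True,
)
```
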
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
    audio_condition_method="cross_attn",
    custom_audio_layer=False,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer,
            audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{up_block_type} does not exist.")

class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer: bool = False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        attentions = []
        audio_attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, audio_attn, resnet, motion_module in zip(
            self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
        ):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = (
                audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                if audio_attn is not None
                else hidden_states
            )
            hidden_states = (
                motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                if motion_module is not None
                else hidden_states
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states

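The per-layer ordering inside the mid block is easiest to read as a trace; the shapes below are illustrative assumptions, not values from a shipped config.

```python
# Pseudo-trace of UNetMidBlock3DCrossAttn.forward with num_layers=1 (shapes assumed):
#   x: (b, c, f, h, w) = (1, 1280, 16, 4, 4)
#   x = resnets[0](x, temb)                             # always runs first
#   x = attn(x, encoder_hidden_states).sample           # spatial (+ optional audio) transformer
#   x = audio_attn(x, encoder_hidden_states).sample     # only when custom_audio_layer=True
#   x = motion_module(x, temb, encoder_hidden_states)   # only when use_motion_module is set
#   x = resnets[1](x, temb)
# Every step preserves the (1, 1280, 16, 4, 4) shape.
```
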
class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer: bool = False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        output_states = ()

        for resnet, attn, audio_attn, motion_module in zip(
            self.resnets, self.attentions, self.audio_attentions, self.motion_modules
        ):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                hidden_states = (
                    audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                    if audio_attn is not None
                    else hidden_states
                )

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

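Checkpointing in these blocks is gated by a plain `gradient_checkpointing` attribute rather than diffusers' `enable_gradient_checkpointing()`. A small helper a trainer could use to flip it everywhere; the helper itself is an assumption for illustration, not part of this upload:

```python
def set_gradient_checkpointing(unet, enabled: bool = True) -> None:
    # DownBlock3D / CrossAttnDownBlock3D / UpBlock3D / CrossAttnUpBlock3D all
    # define this flag; it only takes effect while unet.training is True.
    for module in unet.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enabled
```
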
class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer,
                    audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, audio_attn, motion_module in zip(
            self.resnets, self.attentions, self.audio_attentions, self.motion_modules
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                hidden_states = (
                    audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                    if audio_attn is not None
                    else hidden_states
                )

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

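The skip-connection channel bookkeeping above is clearer with concrete numbers; the channel sizes below are illustrative assumptions, not values from a shipped config.

```python
# Assume prev_output_channel=1280, out_channels=640, in_channels=320, num_layers=3:
#   i=0: resnet_in=1280 (previous stage), skip=640 -> ResnetBlock3D(1920 -> 640)
#   i=1: resnet_in=640,                   skip=640 -> ResnetBlock3D(1280 -> 640)
#   i=2: resnet_in=640,                   skip=320 -> ResnetBlock3D(960  -> 640)
# matching the three tensors popped off res_hidden_states_tuple in forward().
```
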
class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
latentsync/models/utils.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module
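`zero_module` is the standard trick for making a residual branch start as the identity: the zeroed layer contributes nothing at step 0, so training starts from the unmodified backbone output. A self-contained sketch; the 1×1 conv is a stand-in, not a module from this repo:

```python
import torch
from torch import nn
from latentsync.models.utils import zero_module

# Zero-initialized projection: with weight and bias both zeroed, the branch
# outputs exactly zero until gradients move it away from the identity.
proj_out = zero_module(nn.Conv2d(320, 320, kernel_size=1))
x = torch.randn(1, 320, 32, 32)
assert proj_out(x).abs().sum() == 0
```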
latentsync/pipelines/__init__.py
ADDED
@@ -0,0 +1,6 @@
# latentsync/pipelines/__init__.py

from .lipsync_pipeline import *

__all__ = ["LipsyncPipeline"]
latentsync/pipelines/lipsync_pipeline.py
ADDED
@@ -0,0 +1,470 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py

import inspect
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess

import numpy as np
import torch
import torchvision

from diffusers.utils import is_accelerate_available
from packaging import version

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging

from einops import rearrange

from ..models.unet import UNet3DConditionModel
from ..utils.image_processor import ImageProcessor
from ..utils.util import read_video, read_audio, write_video
from ..whisper.audio2feature import Audio2Feature
import tqdm
import soundfile as sf

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LipsyncPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        audio_encoder: Audio2Feature,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has `clip_sample` enabled."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly, as leaving `clip_sample` enabled in the config might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            audio_encoder=audio_encoder,
            unet=unet,
            scheduler=scheduler,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.set_progress_bar_config(desc="Steps")

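A minimal sketch of wiring the pipeline together; every quoted path, model ID, and constructor argument below is a placeholder assumption, not a documented entry point of this repo — check its inference script for the real wiring.

```python
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature

# Placeholder components; paths and constructor signatures are assumptions.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
audio_encoder = Audio2Feature(model_path="checkpoints/whisper/small.pt")  # args assumed
unet, _ = UNet3DConditionModel.from_pretrained(unet_config, "checkpoints/latentsync_unet.pt")  # unet_config as in the earlier sketch

pipeline = LipsyncPipeline(vae=vae, audio_encoder=audio_encoder, unet=unet, scheduler=scheduler)
```
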
    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # This pipeline has no text encoder; offload only the modules it actually owns.
        for cpu_offloaded_model in [self.unet, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        decoded_latents = self.vae.decode(latents).sample
        return decoded_latents

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature;
        # eta (η) is only used with the DDIMScheduler and will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
        # and should be in [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, height, width, callback_steps):
        assert height == width, "Height and width must be equal"

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
        shape = (
            batch_size,
            num_channels_latents,
            1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        rand_device = "cpu" if device.type == "mps" else device
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        latents = latents.repeat(1, 1, num_frames, 1, 1)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

| 203 |
+
def prepare_mask_latents(
|
| 204 |
+
self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 205 |
+
):
|
| 206 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 207 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 208 |
+
# and half precision
|
| 209 |
+
mask = torch.nn.functional.interpolate(
|
| 210 |
+
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 211 |
+
)
|
| 212 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
| 213 |
+
|
| 214 |
+
# encode the mask image into latents space so we can concatenate it to the latents
|
| 215 |
+
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
| 216 |
+
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 217 |
+
|
| 218 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
| 219 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
| 220 |
+
mask = mask.to(device=device, dtype=dtype)
|
| 221 |
+
|
| 222 |
+
# assume batch size = 1
|
| 223 |
+
mask = rearrange(mask, "f c h w -> 1 c f h w")
|
| 224 |
+
masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
|
| 225 |
+
|
| 226 |
+
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
| 227 |
+
masked_image_latents = (
|
| 228 |
+
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
| 229 |
+
)
|
| 230 |
+
return mask, masked_image_latents
|
| 231 |
+
|
| 232 |
+
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
|
| 233 |
+
images = images.to(device=device, dtype=dtype)
|
| 234 |
+
image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
|
| 235 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 236 |
+
image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
|
| 237 |
+
image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
| 238 |
+
|
| 239 |
+
return image_latents
|
| 240 |
+
|
| 241 |
+
def set_progress_bar_config(self, **kwargs):
|
| 242 |
+
if not hasattr(self, "_progress_bar_config"):
|
| 243 |
+
self._progress_bar_config = {}
|
| 244 |
+
self._progress_bar_config.update(kwargs)
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
|
| 248 |
+
# Paste the surrounding pixels back, because we only want to change the mouth region
|
| 249 |
+
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
|
| 250 |
+
masks = masks.to(device=device, dtype=weight_dtype)
|
| 251 |
+
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
|
| 252 |
+
return combined_pixel_values
|
| 253 |
+
|
| 254 |
+
@staticmethod
|
| 255 |
+
def pixel_values_to_images(pixel_values: torch.Tensor):
|
| 256 |
+
pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
|
| 257 |
+
pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
|
| 258 |
+
images = (pixel_values * 255).to(torch.uint8)
|
| 259 |
+
images = images.cpu().numpy()
|
| 260 |
+
return images
|
| 261 |
+
|
| 262 |
+
def affine_transform_video(self, video_path):
|
| 263 |
+
video_frames = read_video(video_path, use_decord=False)
|
| 264 |
+
faces = []
|
| 265 |
+
boxes = []
|
| 266 |
+
affine_matrices = []
|
| 267 |
+
print(f"Affine transforming {len(video_frames)} faces...")
|
| 268 |
+
for frame in tqdm.tqdm(video_frames):
|
| 269 |
+
face, box, affine_matrix = self.image_processor.affine_transform(frame)
|
| 270 |
+
faces.append(face)
|
| 271 |
+
boxes.append(box)
|
| 272 |
+
affine_matrices.append(affine_matrix)
|
| 273 |
+
|
| 274 |
+
faces = torch.stack(faces)
|
| 275 |
+
return faces, video_frames, boxes, affine_matrices
|
| 276 |
+
|
| 277 |
+
def restore_video(self, faces, video_frames, boxes, affine_matrices):
|
| 278 |
+
video_frames = video_frames[: faces.shape[0]]
|
| 279 |
+
out_frames = []
|
| 280 |
+
for index, face in enumerate(faces):
|
| 281 |
+
x1, y1, x2, y2 = boxes[index]
|
| 282 |
+
height = int(y2 - y1)
|
| 283 |
+
width = int(x2 - x1)
|
| 284 |
+
face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
|
| 285 |
+
face = rearrange(face, "c h w -> h w c")
|
| 286 |
+
face = (face / 2 + 0.5).clamp(0, 1)
|
| 287 |
+
face = (face * 255).to(torch.uint8).cpu().numpy()
|
| 288 |
+
out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
|
| 289 |
+
out_frames.append(out_frame)
|
| 290 |
+
return np.stack(out_frames, axis=0)
|
| 291 |
+
|
| 292 |
+
@torch.no_grad()
|
| 293 |
+
def __call__(
|
| 294 |
+
self,
|
| 295 |
+
video_path: str,
|
| 296 |
+
audio_path: str,
|
| 297 |
+
video_out_path: str,
|
| 298 |
+
video_mask_path: str = None,
|
| 299 |
+
num_frames: int = 16,
|
| 300 |
+
video_fps: int = 25,
|
| 301 |
+
audio_sample_rate: int = 16000,
|
| 302 |
+
height: Optional[int] = None,
|
| 303 |
+
width: Optional[int] = None,
|
| 304 |
+
num_inference_steps: int = 20,
|
| 305 |
+
guidance_scale: float = 1.5,
|
| 306 |
+
weight_dtype: Optional[torch.dtype] = torch.float16,
|
| 307 |
+
eta: float = 0.0,
|
| 308 |
+
mask: str = "fix_mask",
|
| 309 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 310 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 311 |
+
callback_steps: Optional[int] = 1,
|
| 312 |
+
**kwargs,
|
| 313 |
+
):
|
| 314 |
+
is_train = self.unet.training
|
| 315 |
+
self.unet.eval()
|
| 316 |
+
|
| 317 |
+
# 0. Define call parameters
|
| 318 |
+
batch_size = 1
|
| 319 |
+
device = self._execution_device
|
| 320 |
+
self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
|
| 321 |
+
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
|
| 322 |
+
|
| 323 |
+
video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
|
| 324 |
+
audio_samples = read_audio(audio_path)
|
| 325 |
+
|
| 326 |
+
# 1. Default height and width to unet
|
| 327 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 328 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 329 |
+
|
| 330 |
+
# 2. Check inputs
|
| 331 |
+
self.check_inputs(height, width, callback_steps)
|
| 332 |
+
|
| 333 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 334 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 335 |
+
# corresponds to doing no classifier free guidance.
|
| 336 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 337 |
+
|
| 338 |
+
# 3. set timesteps
|
| 339 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 340 |
+
timesteps = self.scheduler.timesteps
|
| 341 |
+
|
| 342 |
+
# 4. Prepare extra step kwargs.
|
| 343 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 344 |
+
|
| 345 |
+
self.video_fps = video_fps
|
| 346 |
+
|
| 347 |
+
if self.unet.add_audio_layer:
|
| 348 |
+
whisper_feature = self.audio_encoder.audio2feat(audio_path)
|
| 349 |
+
whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
|
| 350 |
+
|
| 351 |
+
num_inferences = min(len(video_frames), len(whisper_chunks)) // num_frames
|
| 352 |
+
else:
|
| 353 |
+
num_inferences = len(video_frames) // num_frames
|
| 354 |
+
|
| 355 |
+
synced_video_frames = []
|
| 356 |
+
masked_video_frames = []
|
| 357 |
+
|
| 358 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 359 |
+
|
| 360 |
+
# Prepare latent variables
|
| 361 |
+
all_latents = self.prepare_latents(
|
| 362 |
+
batch_size,
|
| 363 |
+
num_frames * num_inferences,
|
| 364 |
+
num_channels_latents,
|
| 365 |
+
height,
|
| 366 |
+
width,
|
| 367 |
+
weight_dtype,
|
| 368 |
+
device,
|
| 369 |
+
generator,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
|
| 373 |
+
if self.unet.add_audio_layer:
|
| 374 |
+
audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
|
| 375 |
+
audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
|
| 376 |
+
if do_classifier_free_guidance:
|
| 377 |
+
empty_audio_embeds = torch.zeros_like(audio_embeds)
|
| 378 |
+
audio_embeds = torch.cat([empty_audio_embeds, audio_embeds])
|
| 379 |
+
else:
|
| 380 |
+
audio_embeds = None
|
| 381 |
+
inference_video_frames = video_frames[i * num_frames : (i + 1) * num_frames]
|
| 382 |
+
latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
|
| 383 |
+
pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
|
| 384 |
+
inference_video_frames, affine_transform=False
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# 7. Prepare mask latent variables
|
| 388 |
+
mask_latents, masked_image_latents = self.prepare_mask_latents(
|
| 389 |
+
masks,
|
| 390 |
+
masked_pixel_values,
|
| 391 |
+
height,
|
| 392 |
+
width,
|
| 393 |
+
weight_dtype,
|
| 394 |
+
device,
|
| 395 |
+
generator,
|
| 396 |
+
do_classifier_free_guidance,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
# 8. Prepare image latents
|
| 400 |
+
image_latents = self.prepare_image_latents(
|
| 401 |
+
pixel_values,
|
| 402 |
+
device,
|
| 403 |
+
weight_dtype,
|
| 404 |
+
generator,
|
| 405 |
+
do_classifier_free_guidance,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# 9. Denoising loop
|
| 409 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 410 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 411 |
+
for j, t in enumerate(timesteps):
|
| 412 |
+
# expand the latents if we are doing classifier free guidance
|
| 413 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 414 |
+
|
| 415 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
| 416 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 417 |
+
latent_model_input = torch.cat(
|
| 418 |
+
[latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# predict the noise residual
|
| 422 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample
|
| 423 |
+
|
| 424 |
+
# perform guidance
|
| 425 |
+
if do_classifier_free_guidance:
|
| 426 |
+
noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
|
| 427 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
|
| 428 |
+
|
| 429 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 430 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 431 |
+
|
| 432 |
+
# call the callback, if provided
|
| 433 |
+
if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
|
| 434 |
+
progress_bar.update()
|
| 435 |
+
if callback is not None and j % callback_steps == 0:
|
| 436 |
+
callback(j, t, latents)
|
| 437 |
+
|
| 438 |
+
# Recover the pixel values
|
| 439 |
+
decoded_latents = self.decode_latents(latents)
|
| 440 |
+
decoded_latents = self.paste_surrounding_pixels_back(
|
| 441 |
+
decoded_latents, pixel_values, 1 - masks, device, weight_dtype
|
| 442 |
+
)
|
| 443 |
+
synced_video_frames.append(decoded_latents)
|
| 444 |
+
masked_video_frames.append(masked_pixel_values)
|
| 445 |
+
|
| 446 |
+
synced_video_frames = self.restore_video(
|
| 447 |
+
torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
|
| 448 |
+
)
|
| 449 |
+
masked_video_frames = self.restore_video(
|
| 450 |
+
torch.cat(masked_video_frames), original_video_frames, boxes, affine_matrices
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
|
| 454 |
+
audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
|
| 455 |
+
|
| 456 |
+
if is_train:
|
| 457 |
+
self.unet.train()
|
| 458 |
+
|
| 459 |
+
temp_dir = "temp"
|
| 460 |
+
if os.path.exists(temp_dir):
|
| 461 |
+
shutil.rmtree(temp_dir)
|
| 462 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 463 |
+
|
| 464 |
+
write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
|
| 465 |
+
# write_video(video_mask_path, masked_video_frames, fps=25)
|
| 466 |
+
|
| 467 |
+
sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
|
| 468 |
+
|
| 469 |
+
command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
|
| 470 |
+
subprocess.run(command, shell=True)
|
latentsync/utils/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# latentsync/utils/__init__.py
|
| 2 |
+
|
| 3 |
+
from .affine_transform import *
|
| 4 |
+
from .audio import *
|
| 5 |
+
from .av_reader import *
|
| 6 |
+
from .image_processor import *
|
| 7 |
+
from .util import *
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"AffineTransform",
|
| 11 |
+
"AudioUtils",
|
| 12 |
+
"AVReader",
|
| 13 |
+
"ImageProcessor",
|
| 14 |
+
"Util"
|
| 15 |
+
]
|
| 16 |
+
|
latentsync/utils/affine_transform.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def transformation_from_points(points1, points0, smooth=True, p_bias=None):
|
| 8 |
+
points2 = np.array(points0)
|
| 9 |
+
points2 = points2.astype(np.float64)
|
| 10 |
+
points1 = points1.astype(np.float64)
|
| 11 |
+
c1 = np.mean(points1, axis=0)
|
| 12 |
+
c2 = np.mean(points2, axis=0)
|
| 13 |
+
points1 -= c1
|
| 14 |
+
points2 -= c2
|
| 15 |
+
s1 = np.std(points1)
|
| 16 |
+
s2 = np.std(points2)
|
| 17 |
+
points1 /= s1
|
| 18 |
+
points2 /= s2
|
| 19 |
+
U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
|
| 20 |
+
R = (np.matmul(U, Vt)).T
|
| 21 |
+
sR = (s2 / s1) * R
|
| 22 |
+
T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
|
| 23 |
+
M = np.concatenate((sR, T), axis=1)
|
| 24 |
+
if smooth:
|
| 25 |
+
bias = points2[2] - points1[2]
|
| 26 |
+
if p_bias is None:
|
| 27 |
+
p_bias = bias
|
| 28 |
+
else:
|
| 29 |
+
bias = p_bias * 0.2 + bias * 0.8
|
| 30 |
+
p_bias = bias
|
| 31 |
+
M[:, 2] = M[:, 2] + bias
|
| 32 |
+
return M, p_bias
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AlignRestore(object):
|
| 36 |
+
def __init__(self, align_points=3):
|
| 37 |
+
if align_points == 3:
|
| 38 |
+
self.upscale_factor = 1
|
| 39 |
+
self.crop_ratio = (2.8, 2.8)
|
| 40 |
+
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
|
| 41 |
+
self.face_template = self.face_template * 2.8
|
| 42 |
+
# self.face_size = (int(100 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
|
| 43 |
+
self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
|
| 44 |
+
self.p_bias = None
|
| 45 |
+
|
| 46 |
+
def process(self, img, lmk_align=None, smooth=True, align_points=3):
|
| 47 |
+
aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
|
| 48 |
+
restored_img = self.restore_img(img, aligned_face, affine_matrix)
|
| 49 |
+
cv2.imwrite("restored.jpg", restored_img)
|
| 50 |
+
cv2.imwrite("aligned.jpg", aligned_face)
|
| 51 |
+
return aligned_face, restored_img
|
| 52 |
+
|
| 53 |
+
def align_warp_face(self, img, lmks3, smooth=True, border_mode="constant"):
|
| 54 |
+
affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
|
| 55 |
+
if border_mode == "constant":
|
| 56 |
+
border_mode = cv2.BORDER_CONSTANT
|
| 57 |
+
elif border_mode == "reflect101":
|
| 58 |
+
border_mode = cv2.BORDER_REFLECT101
|
| 59 |
+
elif border_mode == "reflect":
|
| 60 |
+
border_mode = cv2.BORDER_REFLECT
|
| 61 |
+
cropped_face = cv2.warpAffine(
|
| 62 |
+
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]
|
| 63 |
+
)
|
| 64 |
+
return cropped_face, affine_matrix
|
| 65 |
+
|
| 66 |
+
def align_warp_face2(self, img, landmark, border_mode="constant"):
|
| 67 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
|
| 68 |
+
if border_mode == "constant":
|
| 69 |
+
border_mode = cv2.BORDER_CONSTANT
|
| 70 |
+
elif border_mode == "reflect101":
|
| 71 |
+
border_mode = cv2.BORDER_REFLECT101
|
| 72 |
+
elif border_mode == "reflect":
|
| 73 |
+
border_mode = cv2.BORDER_REFLECT
|
| 74 |
+
cropped_face = cv2.warpAffine(
|
| 75 |
+
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)
|
| 76 |
+
)
|
| 77 |
+
return cropped_face, affine_matrix
|
| 78 |
+
|
| 79 |
+
def restore_img(self, input_img, face, affine_matrix):
|
| 80 |
+
h, w, _ = input_img.shape
|
| 81 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
| 82 |
+
upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
| 83 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
| 84 |
+
inverse_affine *= self.upscale_factor
|
| 85 |
+
if self.upscale_factor > 1:
|
| 86 |
+
extra_offset = 0.5 * self.upscale_factor
|
| 87 |
+
else:
|
| 88 |
+
extra_offset = 0
|
| 89 |
+
inverse_affine[:, 2] += extra_offset
|
| 90 |
+
inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
|
| 91 |
+
mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
|
| 92 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
| 93 |
+
inv_mask_erosion = cv2.erode(
|
| 94 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)
|
| 95 |
+
)
|
| 96 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
| 97 |
+
total_face_area = np.sum(inv_mask_erosion)
|
| 98 |
+
w_edge = int(total_face_area**0.5) // 20
|
| 99 |
+
erosion_radius = w_edge * 2
|
| 100 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
| 101 |
+
blur_size = w_edge * 2
|
| 102 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
| 103 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
| 104 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
| 105 |
+
if np.max(upsample_img) > 256:
|
| 106 |
+
upsample_img = upsample_img.astype(np.uint16)
|
| 107 |
+
else:
|
| 108 |
+
upsample_img = upsample_img.astype(np.uint8)
|
| 109 |
+
return upsample_img
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class laplacianSmooth:
|
| 113 |
+
def __init__(self, smoothAlpha=0.3):
|
| 114 |
+
self.smoothAlpha = smoothAlpha
|
| 115 |
+
self.pts_last = None
|
| 116 |
+
|
| 117 |
+
def smooth(self, pts_cur):
|
| 118 |
+
if self.pts_last is None:
|
| 119 |
+
self.pts_last = pts_cur.copy()
|
| 120 |
+
return pts_cur.copy()
|
| 121 |
+
x1 = min(pts_cur[:, 0])
|
| 122 |
+
x2 = max(pts_cur[:, 0])
|
| 123 |
+
y1 = min(pts_cur[:, 1])
|
| 124 |
+
y2 = max(pts_cur[:, 1])
|
| 125 |
+
width = x2 - x1
|
| 126 |
+
pts_update = []
|
| 127 |
+
for i in range(len(pts_cur)):
|
| 128 |
+
x_new, y_new = pts_cur[i]
|
| 129 |
+
x_old, y_old = self.pts_last[i]
|
| 130 |
+
tmp = (x_new - x_old) ** 2 + (y_new - y_old) ** 2
|
| 131 |
+
w = np.exp(-tmp / (width * self.smoothAlpha))
|
| 132 |
+
x = x_old * w + x_new * (1 - w)
|
| 133 |
+
y = y_old * w + y_new * (1 - w)
|
| 134 |
+
pts_update.append([x, y])
|
| 135 |
+
pts_update = np.array(pts_update)
|
| 136 |
+
self.pts_last = pts_update.copy()
|
| 137 |
+
|
| 138 |
+
return pts_update
|
latentsync/utils/audio.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
|
| 2 |
+
|
| 3 |
+
import librosa
|
| 4 |
+
import librosa.filters
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy import signal
|
| 7 |
+
from scipy.io import wavfile
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
# Create config directly from your YAML structure
|
| 12 |
+
#audio_config_path = "configs/audio.yaml"
|
| 13 |
+
config = OmegaConf.create({
|
| 14 |
+
"audio": {
|
| 15 |
+
"num_mels": 80,
|
| 16 |
+
"rescale": True,
|
| 17 |
+
"rescaling_max": 0.9,
|
| 18 |
+
"use_lws": False,
|
| 19 |
+
"n_fft": 800,
|
| 20 |
+
"hop_size": 200,
|
| 21 |
+
"win_size": 800,
|
| 22 |
+
"sample_rate": 16000,
|
| 23 |
+
"frame_shift_ms": None,
|
| 24 |
+
"signal_normalization": True,
|
| 25 |
+
"allow_clipping_in_normalization": True,
|
| 26 |
+
"symmetric_mels": True,
|
| 27 |
+
"max_abs_value": 4.0,
|
| 28 |
+
"preemphasize": True,
|
| 29 |
+
"preemphasis": 0.97,
|
| 30 |
+
"min_level_db": -100,
|
| 31 |
+
"ref_level_db": 20,
|
| 32 |
+
"fmin": 55,
|
| 33 |
+
"fmax": 7600
|
| 34 |
+
}
|
| 35 |
+
})
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_wav(path, sr):
|
| 40 |
+
return librosa.core.load(path, sr=sr)[0]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def save_wav(wav, path, sr):
|
| 44 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
| 45 |
+
# proposed by @dsmiller
|
| 46 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def save_wavenet_wav(wav, path, sr):
|
| 50 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def preemphasis(wav, k, preemphasize=True):
|
| 54 |
+
if preemphasize:
|
| 55 |
+
return signal.lfilter([1, -k], [1], wav)
|
| 56 |
+
return wav
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
| 60 |
+
if inv_preemphasize:
|
| 61 |
+
return signal.lfilter([1], [1, -k], wav)
|
| 62 |
+
return wav
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_hop_size():
|
| 66 |
+
hop_size = config.audio.hop_size
|
| 67 |
+
if hop_size is None:
|
| 68 |
+
assert config.audio.frame_shift_ms is not None
|
| 69 |
+
hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
|
| 70 |
+
return hop_size
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def linearspectrogram(wav):
|
| 74 |
+
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
| 75 |
+
S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
|
| 76 |
+
|
| 77 |
+
if config.audio.signal_normalization:
|
| 78 |
+
return _normalize(S)
|
| 79 |
+
return S
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def melspectrogram(wav):
|
| 83 |
+
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
| 84 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
|
| 85 |
+
|
| 86 |
+
if config.audio.signal_normalization:
|
| 87 |
+
return _normalize(S)
|
| 88 |
+
return S
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _lws_processor():
|
| 92 |
+
import lws
|
| 93 |
+
|
| 94 |
+
return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _stft(y):
|
| 98 |
+
if config.audio.use_lws:
|
| 99 |
+
return _lws_processor(config.audio).stft(y).T
|
| 100 |
+
else:
|
| 101 |
+
return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
##########################################################
|
| 105 |
+
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
| 106 |
+
def num_frames(length, fsize, fshift):
|
| 107 |
+
"""Compute number of time frames of spectrogram"""
|
| 108 |
+
pad = fsize - fshift
|
| 109 |
+
if length % fshift == 0:
|
| 110 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
| 111 |
+
else:
|
| 112 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
| 113 |
+
return M
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def pad_lr(x, fsize, fshift):
|
| 117 |
+
"""Compute left and right padding"""
|
| 118 |
+
M = num_frames(len(x), fsize, fshift)
|
| 119 |
+
pad = fsize - fshift
|
| 120 |
+
T = len(x) + 2 * pad
|
| 121 |
+
r = (M - 1) * fshift + fsize - T
|
| 122 |
+
return pad, pad + r
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
##########################################################
|
| 126 |
+
# Librosa correct padding
|
| 127 |
+
def librosa_pad_lr(x, fsize, fshift):
|
| 128 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Conversions
|
| 132 |
+
_mel_basis = None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _linear_to_mel(spectogram):
|
| 136 |
+
global _mel_basis
|
| 137 |
+
if _mel_basis is None:
|
| 138 |
+
_mel_basis = _build_mel_basis()
|
| 139 |
+
return np.dot(_mel_basis, spectogram)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _build_mel_basis():
|
| 143 |
+
assert config.audio.fmax <= config.audio.sample_rate // 2
|
| 144 |
+
return librosa.filters.mel(
|
| 145 |
+
sr=config.audio.sample_rate,
|
| 146 |
+
n_fft=config.audio.n_fft,
|
| 147 |
+
n_mels=config.audio.num_mels,
|
| 148 |
+
fmin=config.audio.fmin,
|
| 149 |
+
fmax=config.audio.fmax,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _amp_to_db(x):
|
| 154 |
+
min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
|
| 155 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _db_to_amp(x):
|
| 159 |
+
return np.power(10.0, (x) * 0.05)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _normalize(S):
|
| 163 |
+
if config.audio.allow_clipping_in_normalization:
|
| 164 |
+
if config.audio.symmetric_mels:
|
| 165 |
+
return np.clip(
|
| 166 |
+
(2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
| 167 |
+
- config.audio.max_abs_value,
|
| 168 |
+
-config.audio.max_abs_value,
|
| 169 |
+
config.audio.max_abs_value,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
return np.clip(
|
| 173 |
+
config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
|
| 174 |
+
0,
|
| 175 |
+
config.audio.max_abs_value,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
|
| 179 |
+
if config.audio.symmetric_mels:
|
| 180 |
+
return (2 * config.audio.max_abs_value) * (
|
| 181 |
+
(S - config.audio.min_level_db) / (-config.audio.min_level_db)
|
| 182 |
+
) - config.audio.max_abs_value
|
| 183 |
+
else:
|
| 184 |
+
return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _denormalize(D):
|
| 188 |
+
if config.audio.allow_clipping_in_normalization:
|
| 189 |
+
if config.audio.symmetric_mels:
|
| 190 |
+
return (
|
| 191 |
+
(np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
|
| 192 |
+
* -config.audio.min_level_db
|
| 193 |
+
/ (2 * config.audio.max_abs_value)
|
| 194 |
+
) + config.audio.min_level_db
|
| 195 |
+
else:
|
| 196 |
+
return (
|
| 197 |
+
np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
|
| 198 |
+
) + config.audio.min_level_db
|
| 199 |
+
|
| 200 |
+
if config.audio.symmetric_mels:
|
| 201 |
+
return (
|
| 202 |
+
(D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
|
| 203 |
+
) + config.audio.min_level_db
|
| 204 |
+
else:
|
| 205 |
+
return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_melspec_overlap(audio_samples, melspec_length=52):
|
| 209 |
+
mel_spec_overlap = melspectrogram(audio_samples.numpy())
|
| 210 |
+
mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
|
| 211 |
+
i = 0
|
| 212 |
+
mel_spec_overlap_list = []
|
| 213 |
+
while i + melspec_length < mel_spec_overlap.shape[1] - 3:
|
| 214 |
+
mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
|
| 215 |
+
i += 3
|
| 216 |
+
mel_spec_overlap = torch.stack(mel_spec_overlap_list)
|
| 217 |
+
return mel_spec_overlap
|
latentsync/utils/av_reader.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# We modified the original AVReader class of decord to solve the problem of memory leak.
|
| 2 |
+
# For more details, refer to: https://github.com/dmlc/decord/issues/208
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from decord.video_reader import VideoReader
|
| 6 |
+
from decord.audio_reader import AudioReader
|
| 7 |
+
|
| 8 |
+
from decord.ndarray import cpu
|
| 9 |
+
from decord import ndarray as _nd
|
| 10 |
+
from decord.bridge import bridge_out
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AVReader(object):
|
| 14 |
+
"""Individual audio video reader with convenient indexing function.
|
| 15 |
+
|
| 16 |
+
Parameters
|
| 17 |
+
----------
|
| 18 |
+
uri: str
|
| 19 |
+
Path of file.
|
| 20 |
+
ctx: decord.Context
|
| 21 |
+
The context to decode the file, can be decord.cpu() or decord.gpu().
|
| 22 |
+
sample_rate: int, default is -1
|
| 23 |
+
Desired output sample rate of the audio, unchanged if `-1` is specified.
|
| 24 |
+
mono: bool, default is True
|
| 25 |
+
Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
|
| 26 |
+
width : int, default is -1
|
| 27 |
+
Desired output width of the video, unchanged if `-1` is specified.
|
| 28 |
+
height : int, default is -1
|
| 29 |
+
Desired output height of the video, unchanged if `-1` is specified.
|
| 30 |
+
num_threads : int, default is 0
|
| 31 |
+
Number of decoding thread, auto if `0` is specified.
|
| 32 |
+
fault_tol : int, default is -1
|
| 33 |
+
The threshold of corupted and recovered frames. This is to prevent silent fault
|
| 34 |
+
tolerance when for example 50% frames of a video cannot be decoded and duplicate
|
| 35 |
+
frames are returned. You may find the fault tolerant feature sweet in many cases,
|
| 36 |
+
but not for training models. Say `N = # recovered frames`
|
| 37 |
+
If `fault_tol` < 0, nothing will happen.
|
| 38 |
+
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
|
| 39 |
+
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
|
| 44 |
+
):
|
| 45 |
+
self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
|
| 46 |
+
self.__audio_reader.add_padding()
|
| 47 |
+
if hasattr(uri, "read"):
|
| 48 |
+
uri.seek(0)
|
| 49 |
+
self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
|
| 50 |
+
self.__video_reader.seek(0)
|
| 51 |
+
|
| 52 |
+
def __len__(self):
|
| 53 |
+
"""Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
|
| 54 |
+
we always follow what FFMPEG reports.
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
int
|
| 58 |
+
The number of frames in the video file.
|
| 59 |
+
"""
|
| 60 |
+
return len(self.__video_reader)
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, idx):
|
| 63 |
+
"""Get audio samples and video frame at `idx`.
|
| 64 |
+
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
idx : int or slice
|
| 68 |
+
The frame index, can be negative which means it will index backwards,
|
| 69 |
+
or slice of frame indices.
|
| 70 |
+
|
| 71 |
+
Returns
|
| 72 |
+
-------
|
| 73 |
+
(ndarray/list of ndarray, ndarray)
|
| 74 |
+
First element is samples of shape CxS or a list of length N containing samples of shape CxS,
|
| 75 |
+
where N is the number of frames, C is the number of channels,
|
| 76 |
+
S is the number of samples of the corresponding frame.
|
| 77 |
+
|
| 78 |
+
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
| 79 |
+
where N is the length of the slice.
|
| 80 |
+
"""
|
| 81 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
| 82 |
+
if isinstance(idx, slice):
|
| 83 |
+
return self.get_batch(range(*idx.indices(len(self.__video_reader))))
|
| 84 |
+
if idx < 0:
|
| 85 |
+
idx += len(self.__video_reader)
|
| 86 |
+
if idx >= len(self.__video_reader) or idx < 0:
|
| 87 |
+
raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
|
| 88 |
+
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
| 89 |
+
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
| 90 |
+
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
| 91 |
+
results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
|
| 92 |
+
self.__video_reader.seek(0)
|
| 93 |
+
return results
|
| 94 |
+
|
| 95 |
+
def get_batch(self, indices):
|
| 96 |
+
"""Get entire batch of audio samples and video frames.
|
| 97 |
+
|
| 98 |
+
Parameters
|
| 99 |
+
----------
|
| 100 |
+
indices : list of integers
|
| 101 |
+
A list of frame indices. If negative indices detected, the indices will be indexed from backward
|
| 102 |
+
Returns
|
| 103 |
+
-------
|
| 104 |
+
(list of ndarray, ndarray)
|
| 105 |
+
First element is a list of length N containing samples of shape CxS,
|
| 106 |
+
where N is the number of frames, C is the number of channels,
|
| 107 |
+
S is the number of samples of the corresponding frame.
|
| 108 |
+
|
| 109 |
+
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
| 110 |
+
where N is the length of the slice.
|
| 111 |
+
|
| 112 |
+
"""
|
| 113 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
| 114 |
+
indices = self._validate_indices(indices)
|
| 115 |
+
audio_arr = []
|
| 116 |
+
prev_video_idx = None
|
| 117 |
+
prev_audio_end_idx = None
|
| 118 |
+
for idx in list(indices):
|
| 119 |
+
frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
|
| 120 |
+
# timestamp and sample conversion could have some error that could cause non-continuous audio
|
| 121 |
+
# we detect if retrieving continuous frame and make the audio continuous
|
| 122 |
+
if prev_video_idx and idx == prev_video_idx + 1:
|
| 123 |
+
audio_start_idx = prev_audio_end_idx
|
| 124 |
+
else:
|
| 125 |
+
audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
|
| 126 |
+
audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
|
| 127 |
+
audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
|
| 128 |
+
prev_video_idx = idx
|
| 129 |
+
prev_audio_end_idx = audio_end_idx
|
| 130 |
+
results = (audio_arr, self.__video_reader.get_batch(indices))
|
| 131 |
+
self.__video_reader.seek(0)
|
| 132 |
+
return results
|
| 133 |
+
|
| 134 |
+
def _get_slice(self, sl):
|
| 135 |
+
audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
|
| 136 |
+
for idx in list(sl):
|
| 137 |
+
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
| 138 |
+
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
| 139 |
+
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
| 140 |
+
audio_arr = np.concatenate(
|
| 141 |
+
(audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
|
| 142 |
+
)
|
| 143 |
+
results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
|
| 144 |
+
self.__video_reader.seek(0)
|
| 145 |
+
return results
|
| 146 |
+
|
| 147 |
+
def _validate_indices(self, indices):
|
| 148 |
+
"""Validate int64 integers and convert negative integers to positive by backward search"""
|
| 149 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
| 150 |
+
indices = np.array(indices, dtype=np.int64)
|
| 151 |
+
# process negative indices
|
| 152 |
+
indices[indices < 0] += len(self.__video_reader)
|
| 153 |
+
if not (indices >= 0).all():
|
| 154 |
+
raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
|
| 155 |
+
if not (indices < len(self.__video_reader)).all():
|
| 156 |
+
raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
|
| 157 |
+
return indices
|
latentsync/utils/image_processor.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from torchvision import transforms
|
| 16 |
+
import cv2
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
import mediapipe as mp
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from typing import Union
|
| 22 |
+
from .affine_transform import AlignRestore, laplacianSmooth
|
| 23 |
+
import face_alignment
|
| 24 |
+
import base64
|
| 25 |
+
"""
|
| 26 |
+
If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
|
| 27 |
+
https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_fixed_mask(resolution: int) -> torch.Tensor:
|
| 32 |
+
base64_mask="""iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAAHFklEQVR4Ae3BAWpcQYIFwcz7HzoHFhYGxiW3bclW/3oRVszcyoqZW1kxcysrZm5lxcytrJi5lRUzt7Ji5lZWzNzKiplbWTFzKytmbmXFzK2smLmVFTO3smLmVlbM3MqKmVtZMXMrK2ZuZcXMrayYuZUVM7eyYuZWVszcyoqZW1kxcysrZm5lxcytrJi5lRUzt7Ji5lZWzNzKiplbWTFzKytmbmXFzK2smLmVFTO3smLmVlbM3MqKmVtZMXMrK2ZuZcXMrayYuZUVM7eyYuZWVszcyoqZW1kxcysrZm5lxcytrJi5lRUzt7Ji5lZWzNzKiplbWTFzKytmbmXFzK2smLmVFTO3smLmVlbM3MqKmVtZMXMrK2ZuZcXMrayYuZUVM7eyYuZWVszcyoqZW1kxcysrZm5lxcytrJi5lRUzt7Ji5lZWzNzKiplbWTFzKytmbmXFzK2smLmVFTO3smLmVlbM3MqKmVtZMXMrK2ZuZcXMrayYuZUVM7ey4kDlQSrmNSoPUnFgxYHKg1TMa1QepOLAigOVB6mY16g8SMWBFQcqD1Ixr1F5kIoDKw5UHqRiXqPyIBUHVhyoPEjFvEblQSoOrDhQeZCKeY3Kg1QcWHGg8iAV8xqVB6k4sOJA5UEq5jUqD1JxYMWByoNUzGtUHqTiwIoDlQepmNeoPEjFgRUHKg9SMa9ReZCKAysOVB6kYl6j8iAVB1YcqDxIxbxG5UEqDqw4UHmQinmNyoNUHFhxoPIgFfMalQepOLDiQOVBKuY1Kg9ScWDFgcqDVMxrVB6k4sCKA5UHqZjXqDxIxYEVByoPUjGvUXmQigMrDlQepGJeo/IgFQdWHKg8SMW8RuVBKg6sOFB5kIp5jcqDVBxYcaDyIBXzGpUHqTiw4kDlQSrmNSoPUnFgxYHKg1TMa1QepOLAigOVB6mY16g8SMWBFQcqD1Ixr1F5kIoDKw5UHqRiXqPyIBUHVhyoPEjFvEblQSoOrDhQeZCKeY3Kg1QcWHGg8iAV8xqVB6k4sOJA5UEq5jUqD1JxYMWByswjVPyIFQcqM49Q8SNWHKjMPELFj1hxoDLzCBU/YsWByswjVPyIFQcqM49Q8SNWHKjMPELFj1hxoDLz/ioOrDhTmXlzFQdWnKnMvLmKAyvOVGbeXMWBFWcqM2+u4sCKM5WZN1dxYMWZysybqziw4kxl5s1VHFhxpjLz5ioOrDhTmXlzFQdWnKnMvLmKAyvOVGbeXMWBFWcqM2+u4sCKM5WZN1dxYMWZysw7qziz4kMqM2+r4syKD6nMvK2KMys+pDLztirOrPiQyszbqjiz4kMqM2+r4syKD6nMvK2KMys+pDLztirOrPiQyszbqjiz4kMqM2+r4syKD6nMvK2KMys+pDLztirOrPiQyszbqjiz4kMqM2+r4syKD6nMvKeKD1nxMyozb6jiQ1b8jMrMG6r4kBU/ozLzhio+ZMXPqMy8oYoPWfEzKjNvqOJDVvyMyswbqviQFT+jMvNuKn7GiheozLyVip+x4gUqM++j4gVWvEZl5k1UvMCK16jMvIOK11jxMpWZb6/iNVa8TGXme6t4mRW/QmXmG6t4mRW/QmXmu6r4FVb8IpWZ76fiF1nx61RmvpmKX2TFr1OZ+U4qfp0Vv0Vl5nuo+C1W/BaVme+h4rdY8btUZv61it9lxR9Qmfl3Kv6AFX9AZebfqfgDVvwZlZl/oeLPWPFnVGb+uoo/ZsUfU5n5uyr+mBWfQWXmb6n4DFZ8EpWZr1fxSaz4PCozX6zik1jxeVRmvlLF57HiU6nMfI2KT2XFZ1OZ+QIVn8qKL6Ay86kqPpsVX0Nl5pNUfAErvozKzB+r+BpWfCWVmT9Q8WWs+GIqM7+l4itZ8fVUZn5dxVey4uupzPyiii9mxV+hMvOyiq9nxd+iMvOCir/Cir9IZeZDFX+LFX+XysxBxV9kxV+nMvM/Kv4uK/4FlZn/UvHXWfHvqMxAxb9gxT+lMner+Ees+NdU5lYV/44V34DK3Kfin7Li21CZa1T8a1Z8JypzgYpvwIpvRmWeq+LbsOJbUpnHqfhOrPiuVOZBKr4ZK743lXlzFd+SFe9AZd5TxXdlxZtQmXdT8Y1Z8VZU5h1UfHtWvBuV+d4q3oEV70llvp+K92HFO1OZb6PirVjx/lTmn6p4Q1Y8hcr8dRVvy4pnUZm/peKdWfFEKvOVKt6fFc+lMp+t4imseDqV+QwVz2LFHVTmd1U8kRU3UZmXVTyaFfdRmQ9VXMCKi6nMf6m4iRXXU7lexX2smP+ncp+KW1kx/0Pl6SquZ8UcqDxRxfwfK+YFKu+vYv6LFfPrVN5ExRxYMX9G5ZupmBdYMd+Ayp+pmF9kxcytrJi5lRUzt7Ji5lZWzNzKiplbWTFzKytmbmXFzK2smLmVFTO3smLmVlbM3Oo/KaAdTJ6mh88AAAAASUVORK5CYII="""
|
| 33 |
+
# Decode base64 to bytes
|
| 34 |
+
image_data = base64.b64decode(base64_mask)
|
| 35 |
+
# Convert bytes to numpy array
|
| 36 |
+
np_array = np.frombuffer(image_data, np.uint8)
|
| 37 |
+
# Load image with OpenCV
|
| 38 |
+
mask_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
|
| 39 |
+
#mask_image = cv2.imread("latentsync/utils/mask.png")
|
| 40 |
+
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
| 41 |
+
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
|
| 42 |
+
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
|
| 43 |
+
return mask_image
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ImageProcessor:
|
| 47 |
+
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
|
| 48 |
+
self.resolution = resolution
|
| 49 |
+
self.resize = transforms.Resize(
|
| 50 |
+
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
|
| 51 |
+
)
|
| 52 |
+
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
|
| 53 |
+
self.mask = mask
|
| 54 |
+
|
| 55 |
+
if mask in ["mouth", "face", "eye"]:
|
| 56 |
+
self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
|
| 57 |
+
if mask == "fix_mask":
|
| 58 |
+
self.face_mesh = None
|
| 59 |
+
self.smoother = laplacianSmooth()
|
| 60 |
+
self.restorer = AlignRestore()
|
| 61 |
+
|
| 62 |
+
if mask_image is None:
|
| 63 |
+
self.mask_image = load_fixed_mask(resolution)
|
| 64 |
+
else:
|
| 65 |
+
self.mask_image = mask_image
|
| 66 |
+
|
| 67 |
+
if device != "cpu":
|
| 68 |
+
self.fa = face_alignment.FaceAlignment(
|
| 69 |
+
face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
|
| 70 |
+
)
|
| 71 |
+
self.face_mesh = None
|
| 72 |
+
else:
|
| 73 |
+
# self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
|
| 74 |
+
self.face_mesh = None
|
| 75 |
+
self.fa = None
|
| 76 |
+
|
| 77 |
+
def detect_facial_landmarks(self, image: np.ndarray):
|
| 78 |
+
height, width, _ = image.shape
|
| 79 |
+
results = self.face_mesh.process(image)
|
| 80 |
+
if not results.multi_face_landmarks: # Face not detected
|
| 81 |
+
raise RuntimeError("Face not detected")
|
| 82 |
+
face_landmarks = results.multi_face_landmarks[0] # Only use the first face in the image
|
| 83 |
+
landmark_coordinates = [
|
| 84 |
+
(int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
|
| 85 |
+
] # x means width, y means height
|
| 86 |
+
return landmark_coordinates
|
| 87 |
+
|
| 88 |
+
def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
|
| 89 |
+
image = self.resize(image)
|
| 90 |
+
|
| 91 |
+
if self.mask == "mouth" or self.mask == "face":
|
| 92 |
+
landmark_coordinates = self.detect_facial_landmarks(image)
|
| 93 |
+
if self.mask == "mouth":
|
| 94 |
+
surround_landmarks = mouth_surround_landmarks
|
| 95 |
+
else:
|
| 96 |
+
surround_landmarks = face_surround_landmarks
|
| 97 |
+
|
| 98 |
+
points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
|
| 99 |
+
points = np.array(points)
|
| 100 |
+
mask = np.ones((self.resolution, self.resolution))
|
| 101 |
+
mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
|
| 102 |
+
mask = torch.from_numpy(mask)
|
| 103 |
+
mask = mask.unsqueeze(0)
|
| 104 |
+
elif self.mask == "half":
|
| 105 |
+
mask = torch.ones((self.resolution, self.resolution))
|
| 106 |
+
height = mask.shape[0]
|
| 107 |
+
mask[height // 2 :, :] = 0
|
| 108 |
+
mask = mask.unsqueeze(0)
|
| 109 |
+
elif self.mask == "eye":
|
| 110 |
+
mask = torch.ones((self.resolution, self.resolution))
|
| 111 |
+
landmark_coordinates = self.detect_facial_landmarks(image)
|
| 112 |
+
y = landmark_coordinates[195][1]
|
| 113 |
+
mask[y:, :] = 0
|
| 114 |
+
mask = mask.unsqueeze(0)
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError("Invalid mask type")
|
| 117 |
+
|
| 118 |
+
image = image.to(dtype=torch.float32)
|
| 119 |
+
pixel_values = self.normalize(image / 255.0)
|
| 120 |
+
masked_pixel_values = pixel_values * mask
|
| 121 |
+
mask = 1 - mask
|
| 122 |
+
|
| 123 |
+
return pixel_values, masked_pixel_values, mask
|
| 124 |
+
|
| 125 |
+
def affine_transform(self, image: torch.Tensor) -> np.ndarray:
|
| 126 |
+
# image = rearrange(image, "c h w-> h w c").numpy()
|
| 127 |
+
if self.fa is None:
|
| 128 |
+
landmark_coordinates = np.array(self.detect_facial_landmarks(image))
|
| 129 |
+
lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
|
| 130 |
+
else:
|
| 131 |
+
detected_faces = self.fa.get_landmarks(image)
|
| 132 |
+
if detected_faces is None:
|
| 133 |
+
raise RuntimeError("Face not detected")
|
| 134 |
+
lm68 = detected_faces[0]
|
| 135 |
+
|
| 136 |
+
points = self.smoother.smooth(lm68)
|
| 137 |
+
lmk3_ = np.zeros((3, 2))
|
| 138 |
+
lmk3_[0] = points[17:22].mean(0)
|
| 139 |
+
lmk3_[1] = points[22:27].mean(0)
|
| 140 |
+
lmk3_[2] = points[27:36].mean(0)
|
| 141 |
+
# print(lmk3_)
|
| 142 |
+
face, affine_matrix = self.restorer.align_warp_face(
|
| 143 |
+
image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
|
| 144 |
+
)
|
| 145 |
+
box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
|
| 146 |
+
face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
|
| 147 |
+
face = rearrange(torch.from_numpy(face), "h w c -> c h w")
|
| 148 |
+
return face, box, affine_matrix
|
| 149 |
+
|
| 150 |
+
def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
|
| 151 |
+
if affine_transform:
|
| 152 |
+
image, _, _ = self.affine_transform(image)
|
| 153 |
+
else:
|
| 154 |
+
image = self.resize(image)
|
| 155 |
+
pixel_values = self.normalize(image / 255.0)
|
| 156 |
+
masked_pixel_values = pixel_values * self.mask_image
|
| 157 |
+
return pixel_values, masked_pixel_values, self.mask_image[0:1]
|
| 158 |
+
|
| 159 |
+
def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
|
| 160 |
+
if isinstance(images, np.ndarray):
|
| 161 |
+
images = torch.from_numpy(images)
|
| 162 |
+
if images.shape[3] == 3:
|
| 163 |
+
images = rearrange(images, "b h w c -> b c h w")
|
| 164 |
+
if self.mask == "fix_mask":
|
| 165 |
+
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
|
| 166 |
+
else:
|
| 167 |
+
results = [self.preprocess_one_masked_image(image) for image in images]
|
| 168 |
+
|
| 169 |
+
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
|
| 170 |
+
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
|
+
+    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
+        if isinstance(images, np.ndarray):
+            images = torch.from_numpy(images)
+        if images.shape[3] == 3:
+            images = rearrange(images, "b h w c -> b c h w")
+        images = self.resize(images)
+        pixel_values = self.normalize(images / 255.0)
+        return pixel_values
+
+    def close(self):
+        if self.face_mesh is not None:
+            self.face_mesh.close()
+
+
+def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
+    """
+    lm478: [B, 478, 3] or [478, 3]
+    """
+    # lm478[..., 0] *= W
+    # lm478[..., 1] *= H
+    landmarks_extracted = []
+    for index in landmark_points_68:
+        x = lm478[index][0]
+        y = lm478[index][1]
+        landmarks_extracted.append((x, y))
+    return np.array(landmarks_extracted)
+
+
+landmark_points_68 = [
+    162, 234, 93, 58, 172, 136, 149, 148, 152, 377, 378, 365, 397, 288, 323, 454, 389,  # jawline (17)
+    71, 63, 105, 66, 107,  # eyebrow (5)
+    336, 296, 334, 293, 301,  # eyebrow (5)
+    168, 197, 5, 4,  # nose bridge (4)
+    75, 97, 2, 326, 305,  # nose bottom (5)
+    33, 160, 158, 133, 153, 144,  # eye (6)
+    362, 385, 387, 263, 373, 380,  # eye (6)
+    61, 39, 37, 0, 267, 269, 291, 405, 314, 17, 84, 181,  # outer lips (12)
+    78, 82, 13, 312, 308, 317, 14, 87,  # inner lips (8)
+]
+
+
+# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
+mouth_surround_landmarks = [
+    164, 165, 167, 92, 186, 57, 43, 106, 182, 83,
+    18, 313, 406, 335, 273, 287, 410, 322, 391, 393,
+]
+
+face_surround_landmarks = [
+    152, 377, 400, 378, 379, 365, 397, 288, 435, 433,
+    411, 425, 423, 327, 326, 94, 97, 98, 203, 205,
+    187, 213, 215, 58, 172, 136, 150, 149, 176, 148,
+]
+
+
+if __name__ == "__main__":
+    image_processor = ImageProcessor(512, mask="fix_mask")
+    video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
+    while True:
+        ret, frame = video.read()
+        # if not ret:
+        #     break
+
+        # cv2.imwrite("image.jpg", frame)
+
+        frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
+        # face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
+        face, _, _ = image_processor.affine_transform(frame)
+
+        break
+
+    face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
+    cv2.imwrite("face.jpg", face)
+
+    # masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
+    # cv2.imwrite("masked_face.jpg", masked_face)
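
For context, a minimal sketch of how the 478-to-68 landmark mapping above can be fed from MediaPipe FaceMesh; the `mp.solutions.face_mesh` usage and the input image path are assumptions for illustration, not part of this upload:

    import cv2
    import mediapipe as mp
    import numpy as np

    image = cv2.imread("face.jpg")  # hypothetical input image
    h, w = image.shape[:2]
    with mp.solutions.face_mesh.FaceMesh(static_image_mode=True) as face_mesh:
        results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if results.multi_face_landmarks:
        # Scale normalized coordinates to pixels, then pick the 68 indices listed above
        lm478 = np.array([(lm.x * w, lm.y * h, lm.z) for lm in results.multi_face_landmarks[0].landmark])
        lm68 = mediapipe_lm478_to_face_alignment_lm68(lm478)  # shape (68, 2)
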
latentsync/utils/mask.png
ADDED

latentsync/utils/util.py
ADDED

@@ -0,0 +1,365 @@
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import imageio
+import numpy as np
+import json
+from typing import Union
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import torch.distributed as dist
+from torchvision import transforms
+
+from tqdm import tqdm
+from einops import rearrange
+import cv2
+from decord import AudioReader, VideoReader
+import shutil
+import subprocess
+
+
+# Machine epsilon for a float32 (single precision)
+eps = np.finfo(np.float32).eps
+
+
+def read_json(filepath: str):
+    with open(filepath) as f:
+        json_dict = json.load(f)
+    return json_dict
+
+
+def read_video(video_path: str, change_fps=True, use_decord=True):
+    if change_fps:
+        temp_dir = "temp"
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        os.makedirs(temp_dir, exist_ok=True)
+        command = (
+            f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
+        )
+        subprocess.run(command, shell=True)
+        target_video_path = os.path.join(temp_dir, "video.mp4")
+    else:
+        target_video_path = video_path
+
+    if use_decord:
+        return read_video_decord(target_video_path)
+    else:
+        return read_video_cv2(target_video_path)
+
+
+def read_video_decord(video_path: str):
+    vr = VideoReader(video_path)
+    video_frames = vr[:].asnumpy()
+    vr.seek(0)
+    return video_frames
+
+
+def read_video_cv2(video_path: str):
+    # Open the video file
+    cap = cv2.VideoCapture(video_path)
+
+    # Check if the video was opened successfully
+    if not cap.isOpened():
+        print("Error: Could not open video.")
+        return np.array([])
+
+    frames = []
+
+    while True:
+        # Read a frame
+        ret, frame = cap.read()
+
+        # If the frame is read correctly, ret is True
+        if not ret:
+            break
+
+        # Convert BGR to RGB
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+        frames.append(frame_rgb)
+
+    # Release the video capture object
+    cap.release()
+
+    return np.array(frames)
+
+
+def read_audio(audio_path: str, audio_sample_rate: int = 16000):
+    if audio_path is None:
+        raise ValueError("Audio path is required.")
+    ar = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
+
+    # To access the audio samples
+    audio_samples = torch.from_numpy(ar[:].asnumpy())
+    audio_samples = audio_samples.squeeze(0)
+
+    return audio_samples
+
+
+def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
+    height, width = video_frames[0].shape[:2]
+    out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
+    # out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
+    for frame in video_frames:
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        out.write(frame)
+    out.release()
+
+
+def init_dist(backend="nccl", **kwargs):
+    """Initializes the distributed environment."""
+    rank = int(os.environ["RANK"])
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        raise RuntimeError("No GPUs available for training.")
+    local_rank = rank % num_gpus
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(backend=backend, **kwargs)
+
+    return local_rank
+
+
+def zero_rank_print(s):
+    if dist.is_initialized() and dist.get_rank() == 0:
+        print("### " + s)
+
+
+def zero_rank_log(logger, message: str):
+    if dist.is_initialized() and dist.get_rank() == 0:
+        logger.info(message)
+
+
+def make_audio_window(audio_embeddings: torch.Tensor, window_size: int):
+    audio_window = []
+    end_idx = audio_embeddings.shape[1] - window_size + 1
+    for i in range(end_idx):
+        audio_window.append(audio_embeddings[:, i : i + window_size, :])
+    audio_window = torch.stack(audio_window)
+    audio_window = rearrange(audio_window, "f b w d -> b f w d")
+    return audio_window
+
+
+def check_video_fps(video_path: str):
+    cam = cv2.VideoCapture(video_path)
+    fps = cam.get(cv2.CAP_PROP_FPS)
+    if fps != 25:
+        raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
+
+
+def tailor_tensor_to_length(tensor: torch.Tensor, length: int):
+    if len(tensor) == length:
+        return tensor
+    elif len(tensor) > length:
+        return tensor[:length]
+    else:
+        return torch.cat([tensor, tensor[-1].repeat(length - len(tensor))])
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+    videos = rearrange(videos, "b c f h w -> f b c h w")
+    outputs = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = (x * 255).numpy().astype(np.uint8)
+        outputs.append(x)
+
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    imageio.mimsave(path, outputs, fps=fps)
+
+
+def interpolate_features(features: torch.Tensor, output_len: int) -> torch.Tensor:
+    features = features.cpu().numpy()
+    input_len, num_features = features.shape
+
+    input_timesteps = np.linspace(0, 10, input_len)
+    output_timesteps = np.linspace(0, 10, output_len)
+    output_features = np.zeros((output_len, num_features))
+    for feat in range(num_features):
+        output_features[:, feat] = np.interp(output_timesteps, input_timesteps, features[:, feat])
+    return torch.from_numpy(output_features)
+
+
+# DDIM Inversion
+@torch.no_grad()
+def init_prompt(prompt, pipeline):
+    uncond_input = pipeline.tokenizer(
+        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
+    )
+    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+    text_input = pipeline.tokenizer(
+        [prompt],
+        padding="max_length",
+        max_length=pipeline.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+    context = torch.cat([uncond_embeddings, text_embeddings])
+
+    return context
+
+
+def reversed_forward(ddim_scheduler, pred_noise, timesteps, x_t):
+    # Compute alphas, betas
+    alpha_prod_t = ddim_scheduler.alphas_cumprod[timesteps]
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise, also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    if ddim_scheduler.config.prediction_type == "epsilon":
+        beta_prod_t = beta_prod_t[:, None, None, None, None]
+        alpha_prod_t = alpha_prod_t[:, None, None, None, None]
+        pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
+    else:
+        raise NotImplementedError("This prediction type is not implemented yet")
+
+    # Clip "predicted x_0"
+    if ddim_scheduler.config.clip_sample:
+        pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+    return pred_original_sample
+
+
+def next_step(
+    model_output: Union[torch.FloatTensor, np.ndarray],
+    timestep: int,
+    sample: Union[torch.FloatTensor, np.ndarray],
+    ddim_scheduler,
+):
+    timestep, next_timestep = (
+        min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
+        timestep,
+    )
+    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+    beta_prod_t = 1 - alpha_prod_t
+    next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+    next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
+    return next_sample
+
+
+def get_noise_pred_single(latents, t, context, unet):
+    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
+    return noise_pred
+
+
+@torch.no_grad()
+def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+    context = init_prompt(prompt, pipeline)
+    uncond_embeddings, cond_embeddings = context.chunk(2)
+    all_latent = [latent]
+    latent = latent.clone().detach()
+    for i in tqdm(range(num_inv_steps)):
+        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+        latent = next_step(noise_pred, t, latent, ddim_scheduler)
+        all_latent.append(latent)
+    return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+    return ddim_latents
+
+
+def plot_loss_chart(save_path: str, *args):
+    # Creating the plot
+    plt.figure()
+    for loss_line in args:
+        plt.plot(loss_line[1], loss_line[2], label=loss_line[0])
+    plt.xlabel("Step")
+    plt.ylabel("Loss")
+    plt.legend()
+
+    # Save the figure to a file
+    plt.savefig(save_path)
+
+    # Close the figure to free memory
+    plt.close()
+
+
+CRED = "\033[91m"
+CEND = "\033[0m"
+
+
+def red_text(text: str):
+    return f"{CRED}{text}{CEND}"
+
+
+log_loss = nn.BCELoss(reduction="none")
+
+
+def cosine_loss(vision_embeds, audio_embeds, y):
+    sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
+    # sims[sims != sims] = 0  # remove nan
+    # sims = sims.clamp(0, 1)
+    loss = log_loss(sims.unsqueeze(1), y).squeeze()
+    return loss
+
+
+def save_image(image, save_path):
+    # input size (C, H, W)
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = (image * 255).to(torch.uint8)
+    image = transforms.ToPILImage()(image)
+    # Save the image copy
+    image.save(save_path)
+
+    # Close the image file
+    image.close()
+
+
+def gather_loss(loss, device):
+    # Sum the local loss across all processes
+    local_loss = loss.item()
+    global_loss = torch.tensor(local_loss, dtype=torch.float32).to(device)
+    dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
+
+    # Calculate the average loss across all processes
+    global_average_loss = global_loss.item() / dist.get_world_size()
+    return global_average_loss
+
+
+def gather_video_paths_recursively(input_dir):
+    print(f"Recursively gathering video paths of {input_dir} ...")
+    paths = []
+    gather_video_paths(input_dir, paths)
+    return paths
+
+
+def gather_video_paths(input_dir, paths):
+    for file in sorted(os.listdir(input_dir)):
+        if file.endswith(".mp4"):
+            filepath = os.path.join(input_dir, file)
+            paths.append(filepath)
+        elif os.path.isdir(os.path.join(input_dir, file)):
+            gather_video_paths(os.path.join(input_dir, file), paths)
+
+
+def count_video_time(video_path):
+    video = cv2.VideoCapture(video_path)
+
+    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
+    fps = video.get(cv2.CAP_PROP_FPS)
+    return frame_count / fps
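
Taken together, the I/O helpers above give a simple read/process/write loop; a minimal sketch (the file names are placeholders, and the ffmpeg CLI must be on PATH for the 25 FPS resampling):

    from latentsync.utils.util import read_video, read_audio, write_video

    frames = read_video("input.mp4", change_fps=True)          # (F, H, W, 3) uint8 RGB at 25 FPS
    audio = read_audio("speech.wav", audio_sample_rate=16000)  # 1-D float tensor at 16 kHz
    write_video("output.mp4", frames, fps=25)                  # re-encodes the frames with mp4v
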
latentsync/whisper/__init__.py
ADDED

@@ -0,0 +1,6 @@
+# latentsync/whisper/__init__.py
+
+from .whisper import *
+
+__all__ = ["Transcription"]
+
latentsync/whisper/audio2feature.py
ADDED

@@ -0,0 +1,166 @@
+# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
+
+from .whisper import load_model
+import numpy as np
+import torch
+import os
+
+
+class Audio2Feature:
+    def __init__(
+        self,
+        model_path="checkpoints/whisper/tiny.pt",
+        device=None,
+        audio_embeds_cache_dir=None,
+        num_frames=16,
+    ):
+        self.model = load_model(model_path, device)
+        self.audio_embeds_cache_dir = audio_embeds_cache_dir
+        self.num_frames = num_frames
+        self.embedding_dim = self.model.dims.n_audio_state
+
+    def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
+        """
+        Get sliced features based on a given video-frame index
+        :param feature_array: the whole audio feature array
+        :param vid_idx: the video-frame index to center the slice on
+        :param audio_feat_length: number of context frames to the left and right
+        :return: the sliced features and the audio indices they came from
+        """
+        length = len(feature_array)
+        selected_feature = []
+        selected_idx = []
+
+        center_idx = int(vid_idx * 50 / fps)
+        left_idx = center_idx - audio_feat_length[0] * 2
+        right_idx = center_idx + (audio_feat_length[1] + 1) * 2
+
+        for idx in range(left_idx, right_idx):
+            idx = max(0, idx)
+            idx = min(length - 1, idx)
+            x = feature_array[idx]
+            selected_feature.append(x)
+            selected_idx.append(idx)
+
+        selected_feature = torch.cat(selected_feature, dim=0)
+        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # 50*384
+        return selected_feature, selected_idx
+
+    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
+        """
+        Get sliced features based on a given video-frame index (sparse variant)
+        :param feature_array: the whole audio feature array
+        :param vid_idx: the video-frame index to center the slice on
+        :param audio_feat_length: number of context frames to the left and right
+        :return: the sliced features and the audio indices they came from
+        """
+        length = len(feature_array)
+        selected_feature = []
+        selected_idx = []
+
+        for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
+            left_idx = int((vid_idx + dt) * 50 / fps)
+            if left_idx < 1 or left_idx > length - 1:
+                left_idx = max(0, left_idx)
+                left_idx = min(length - 1, left_idx)
+
+                x = feature_array[left_idx]
+                x = x[np.newaxis, :, :]
+                x = np.repeat(x, 2, axis=0)
+                selected_feature.append(x)
+                selected_idx.append(left_idx)
+                selected_idx.append(left_idx)
+            else:
+                x = feature_array[left_idx - 1 : left_idx + 1]
+                selected_feature.append(x)
+                selected_idx.append(left_idx - 1)
+                selected_idx.append(left_idx)
+        selected_feature = np.concatenate(selected_feature, axis=0)
+        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # 50*384
+        selected_feature = torch.from_numpy(selected_feature)
+        return selected_feature, selected_idx
+
+    def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
+        whisper_chunks = []
+        whisper_idx_multiplier = 50.0 / fps
+        i = 0
+        print(f"video in {fps} FPS, audio idx in 50 FPS")
+
+        while True:
+            start_idx = int(i * whisper_idx_multiplier)
+            selected_feature, selected_idx = self.get_sliced_feature(
+                feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
+            )
+            # print(f"i:{i}, selected_idx {selected_idx}")
+            whisper_chunks.append(selected_feature)
+            i += 1
+            if start_idx > len(feature_array):
+                break
+
+        return whisper_chunks
+
+    def _audio2feat(self, audio_path: str):
+        # Run Whisper on the audio and collect the encoder embeddings of each segment
+        result = self.model.transcribe(audio_path)
+        embed_list = []
+        for emb in result["segments"]:
+            encoder_embeddings = emb["encoder_embeddings"]
+            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
+            encoder_embeddings = encoder_embeddings.squeeze(0)
+            start_idx = int(emb["start"])
+            end_idx = int(emb["end"])
+            emb_end_idx = int((end_idx - start_idx) / 2)
+            embed_list.append(encoder_embeddings[:emb_end_idx])
+        concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
+        return concatenated_array
+
+    def audio2feat(self, audio_path):
+        if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
+            return self._audio2feat(audio_path)
+
+        audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")
+
+        if os.path.isfile(audio_embeds_cache_path):
+            try:
+                audio_feat = torch.load(audio_embeds_cache_path)
+            except Exception as e:
+                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
+                os.remove(audio_embeds_cache_path)
+                audio_feat = self._audio2feat(audio_path)
+                torch.save(audio_feat, audio_embeds_cache_path)
+        else:
+            audio_feat = self._audio2feat(audio_path)
+            torch.save(audio_feat, audio_embeds_cache_path)
+
+        return audio_feat
+
+    def crop_overlap_audio_window(self, audio_feat, start_index):
+        selected_feature_list = []
+        for i in range(start_index, start_index + self.num_frames):
+            selected_feature, selected_idx = self.get_sliced_feature(
+                feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25
+            )
+            selected_feature_list.append(selected_feature)
+        mel_overlap = torch.stack(selected_feature_list)
+        return mel_overlap
+
+
+if __name__ == "__main__":
+    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
+    audio_path = "assets/demo1_audio.wav"
+    array = audio_encoder.audio2feat(audio_path)
+    print(array.shape)
+    fps = 25
+    whisper_idx_multiplier = 50.0 / fps
+
+    i = 0
+    print(f"video in {fps} FPS, audio idx in 50 FPS")
+    while True:
+        start_idx = int(i * whisper_idx_multiplier)
+        selected_feature, selected_idx = audio_encoder.get_sliced_feature(
+            feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
+        )
+        print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
+        i += 1
+        if start_idx > len(array):
+            break
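
The slicing above hinges on Whisper features arriving at 50 steps per second while video runs at 25 FPS, so each video frame owns two feature steps; a standalone sketch of the index arithmetic in `get_sliced_feature` (the clamping to the array bounds done in the real code is omitted here):

    fps = 25
    audio_feat_length = [2, 2]
    for vid_idx in range(3):
        center_idx = int(vid_idx * 50 / fps)                    # 2 feature steps per video frame
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2
        print(vid_idx, list(range(left_idx, right_idx)))        # 10 feature indices per frame
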
latentsync/whisper/whisper/__init__.py
ADDED

@@ -0,0 +1,119 @@
+import hashlib
+import io
+import os
+import urllib
+import warnings
+from typing import List, Optional, Union
+
+import torch
+from tqdm import tqdm
+
+from .audio import load_audio, log_mel_spectrogram, pad_or_trim
+from .decoding import DecodingOptions, DecodingResult, decode, detect_language
+from .model import Whisper, ModelDimensions
+from .transcribe import transcribe
+
+
+_MODELS = {
+    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
+    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
+    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
+    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
+    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
+    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
+    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
+    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
+    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
+    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+}
+
+
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
+    os.makedirs(root, exist_ok=True)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, os.path.basename(url))
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        model_bytes = open(download_target, "rb").read()
+        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+            return model_bytes if in_memory else download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    model_bytes = open(download_target, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    return model_bytes if in_memory else download_target
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models"""
+    return list(_MODELS.keys())
+
+
+def load_model(
+    name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
+) -> Whisper:
+    """
+    Load a Whisper ASR model
+
+    Parameters
+    ----------
+    name : str
+        one of the official model names listed by `whisper.available_models()`, or
+        path to a model checkpoint containing the model dimensions and the model state_dict.
+    device : Union[str, torch.device]
+        the PyTorch device to put the model into
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory
+
+    Returns
+    -------
+    model : Whisper
+        The Whisper ASR model instance
+    """
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
+
+    if name in _MODELS:
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+    elif os.path.isfile(name):
+        checkpoint_file = open(name, "rb").read() if in_memory else name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
+
+    dims = ModelDimensions(**checkpoint["dims"])
+    model = Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    return model.to(device)
latentsync/whisper/whisper/__main__.py
ADDED

@@ -0,0 +1,4 @@
+from .transcribe import cli
+
+
+cli()
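
Because `__main__.py` simply calls `cli()`, the vendored package can be invoked as a module; assuming the `cli` keeps upstream OpenAI Whisper's argument format (not verified here), a hypothetical invocation might look like:

    python -m latentsync.whisper.whisper speech.wav --model tiny
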
latentsync/whisper/whisper/assets/gpt2/merges.txt
ADDED
(diff too large to render)

latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json
ADDED

@@ -0,0 +1 @@
+{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json
ADDED

@@ -0,0 +1 @@
+{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

latentsync/whisper/whisper/assets/gpt2/vocab.json
ADDED
(diff too large to render)

latentsync/whisper/whisper/assets/mel_filters.npz
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd2cc75e70e36fcbdd8ffbc2499062f30094093e6bf2cbafa9859f59972b420b
+size 2048

latentsync/whisper/whisper/assets/multilingual/added_tokens.json
ADDED

@@ -0,0 +1 @@
+{"<|endoftext|>": 50257}

latentsync/whisper/whisper/assets/multilingual/merges.txt
ADDED
(diff too large to render)

latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json
ADDED

@@ -0,0 +1 @@
+{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json
ADDED

@@ -0,0 +1 @@
+{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

latentsync/whisper/whisper/assets/multilingual/vocab.json
ADDED
(diff too large to render)
latentsync/whisper/whisper/audio.py
ADDED

@@ -0,0 +1,132 @@
+import os
+from functools import lru_cache
+from typing import Union
+
+import ffmpeg
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .utils import exact_div
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+    """
+    Open an audio file and read it as a mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sr: int
+        The sample rate to resample the audio to, if necessary
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+    try:
+        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+        out, _ = (
+            ffmpeg.input(file, threads=0)
+            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+        )
+    except ffmpeg.Error as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if torch.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(dim=axis, index=torch.arange(length))
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+    """
+    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling the librosa dependency; saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+        )
+    """
+    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    hard_path = "/data/data/assets/mel_filters.npz"
+    try:
+        with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
+            return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+    except FileNotFoundError:
+        print("Bundled mel_filters.npz not found; trying the hard-coded demo path")
+        try:
+            with np.load(hard_path) as f:
+                return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+        except Exception as e:
+            print(e)
+            raise  # re-raise instead of silently returning None
+
+
+def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+    """
+    Compute the log-Mel spectrogram of the audio
+
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+        The path to the audio, or either a NumPy array or Tensor containing the audio waveform at 16 kHz
+
+    n_mels: int
+        The number of Mel-frequency filters, only 80 is supported
+
+    Returns
+    -------
+    torch.Tensor, shape = (80, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+
+    window = torch.hann_window(N_FFT).to(audio.device)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+
+    magnitudes = stft[:, :-1].abs() ** 2
+
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes
+
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    return log_spec
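
End to end, the helpers above turn an audio file into the encoder's fixed-size input; a minimal sketch (the wav path is a placeholder, and the ffmpeg CLI must be installed):

    from latentsync.whisper.whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram

    waveform = load_audio("speech.wav")   # float32 mono waveform at 16 kHz
    waveform = pad_or_trim(waveform)      # pad or cut to 30 s (480000 samples)
    mel = log_mel_spectrogram(waveform)   # torch.Tensor of shape (80, 3000)
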
latentsync/whisper/whisper/decoding.py
ADDED

@@ -0,0 +1,729 @@
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.distributions import Categorical
|
| 9 |
+
|
| 10 |
+
from .audio import CHUNK_LENGTH
|
| 11 |
+
from .tokenizer import Tokenizer, get_tokenizer
|
| 12 |
+
from .utils import compression_ratio
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from .model import Whisper
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.no_grad()
|
| 19 |
+
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
| 20 |
+
"""
|
| 21 |
+
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
| 22 |
+
of the most probable language tokens and the probability distribution over all language tokens.
|
| 23 |
+
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
language_tokens : Tensor, shape = (n_audio,)
|
| 28 |
+
ids of the most probable language tokens, which appears after the startoftranscript token.
|
| 29 |
+
language_probs : List[Dict[str, float]], length = n_audio
|
| 30 |
+
list of dictionaries containing the probability distribution over all languages.
|
| 31 |
+
"""
|
| 32 |
+
if tokenizer is None:
|
| 33 |
+
tokenizer = get_tokenizer(model.is_multilingual)
|
| 34 |
+
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
| 35 |
+
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
| 36 |
+
|
| 37 |
+
single = mel.ndim == 2
|
| 38 |
+
if single:
|
| 39 |
+
mel = mel.unsqueeze(0)
|
| 40 |
+
|
| 41 |
+
# skip encoder forward pass if already-encoded audio features were given
|
| 42 |
+
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
| 43 |
+
mel = model.encoder(mel)
|
| 44 |
+
|
| 45 |
+
# forward pass using a single token, startoftranscript
|
| 46 |
+
n_audio = mel.shape[0]
|
| 47 |
+
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
| 48 |
+
logits = model.logits(x, mel)[:, 0]
|
| 49 |
+
|
| 50 |
+
# collect detected languages; suppress all non-language tokens
|
| 51 |
+
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
| 52 |
+
mask[list(tokenizer.all_language_tokens)] = False
|
| 53 |
+
logits[:, mask] = -np.inf
|
| 54 |
+
language_tokens = logits.argmax(dim=-1)
|
| 55 |
+
language_token_probs = logits.softmax(dim=-1).cpu()
|
| 56 |
+
language_probs = [
|
| 57 |
+
{
|
| 58 |
+
c: language_token_probs[i, j].item()
|
| 59 |
+
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
| 60 |
+
}
|
| 61 |
+
for i in range(n_audio)
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
if single:
|
| 65 |
+
language_tokens = language_tokens[0]
|
| 66 |
+
language_probs = language_probs[0]
|
| 67 |
+
|
| 68 |
+
return language_tokens, language_probs
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass(frozen=True)
|
| 72 |
+
class DecodingOptions:
|
| 73 |
+
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
| 74 |
+
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
| 75 |
+
|
| 76 |
+
# sampling-related options
|
| 77 |
+
temperature: float = 0.0
|
| 78 |
+
sample_len: Optional[int] = None # maximum number of tokens to sample
|
| 79 |
+
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
| 80 |
+
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
| 81 |
+
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
| 82 |
+
|
| 83 |
+
# options for ranking generations (either beams or best-of-N samples)
|
| 84 |
+
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
| 85 |
+
|
| 86 |
+
# prompt, prefix, and token suppression
|
| 87 |
+
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
| 88 |
+
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
| 89 |
+
suppress_blank: bool = True # this will suppress blank outputs
|
| 90 |
+
|
| 91 |
+
# list of tokens ids (or comma-separated token ids) to suppress
|
| 92 |
+
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
| 93 |
+
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
| 94 |
+
|
| 95 |
+
# timestamp sampling options
|
| 96 |
+
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
| 97 |
+
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
| 98 |
+
|
| 99 |
+
# implementation details
|
| 100 |
+
fp16: bool = True # use fp16 for most of the calculation
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass(frozen=True)
|
| 104 |
+
class DecodingResult:
|
| 105 |
+
audio_features: Tensor
|
| 106 |
+
language: str
|
| 107 |
+
encoder_embeddings: np.ndarray
|
| 108 |
+
decoder_embeddings: np.ndarray
|
| 109 |
+
language_probs: Optional[Dict[str, float]] = None
|
| 110 |
+
tokens: List[int] = field(default_factory=list)
|
| 111 |
+
text: str = ""
|
| 112 |
+
avg_logprob: float = np.nan
|
| 113 |
+
no_speech_prob: float = np.nan
|
| 114 |
+
temperature: float = np.nan
|
| 115 |
+
compression_ratio: float = np.nan
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Inference:
|
| 119 |
+
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
| 120 |
+
"""Perform a forward pass on the decoder and return per-token logits"""
|
| 121 |
+
raise NotImplementedError
|
| 122 |
+
|
| 123 |
+
def rearrange_kv_cache(self, source_indices) -> None:
|
| 124 |
+
"""Update the key-value cache according to the updated beams"""
|
| 125 |
+
raise NotImplementedError
|
| 126 |
+
|
| 127 |
+
def cleanup_caching(self) -> None:
|
| 128 |
+
"""Clean up any resources or hooks after decoding is finished"""
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class PyTorchInference(Inference):
|
| 133 |
+
def __init__(self, model: "Whisper", initial_token_length: int):
|
| 134 |
+
self.model: "Whisper" = model
|
| 135 |
+
self.initial_token_length = initial_token_length
|
| 136 |
+
self.kv_cache = {}
|
| 137 |
+
self.hooks = []
|
| 138 |
+
|
| 139 |
+
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
| 140 |
+
if not self.kv_cache:
|
| 141 |
+
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
| 142 |
+
|
| 143 |
+
if tokens.shape[-1] > self.initial_token_length:
|
| 144 |
+
# only need to use the last token except in the first forward pass
|
| 145 |
+
tokens = tokens[:, -1:]
|
| 146 |
+
|
| 147 |
+
return_val = self.model.decoder(tokens, audio_features,
|
| 148 |
+
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
| 149 |
+
return return_val
|
| 150 |
+
|
| 151 |
+
def cleanup_caching(self):
|
| 152 |
+
for hook in self.hooks:
|
| 153 |
+
hook.remove()
|
| 154 |
+
|
| 155 |
+
self.kv_cache = {}
|
| 156 |
+
self.hooks = []
|
| 157 |
+
|
| 158 |
+
def rearrange_kv_cache(self, source_indices):
|
| 159 |
+
for module, tensor in self.kv_cache.items():
|
| 160 |
+
# update the key/value cache to contain the selected sequences
|
| 161 |
+
self.kv_cache[module] = tensor[source_indices].detach()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class SequenceRanker:
|
| 165 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
| 166 |
+
"""
|
| 167 |
+
Given a list of groups of samples and their cumulative log probabilities,
|
| 168 |
+
return the indices of the samples in each group to select as the final result
|
| 169 |
+
"""
|
| 170 |
+
raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class MaximumLikelihoodRanker(SequenceRanker):
|
| 174 |
+
"""
|
| 175 |
+
Select the sample with the highest log probabilities, penalized using either
|
| 176 |
+
a simple length normalization or Google NMT paper's length penalty
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
def __init__(self, length_penalty: Optional[float]):
|
| 180 |
+
self.length_penalty = length_penalty
|
| 181 |
+
|
| 182 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
| 183 |
+
def scores(logprobs, lengths):
|
| 184 |
+
result = []
|
| 185 |
+
for logprob, length in zip(logprobs, lengths):
|
| 186 |
+
if self.length_penalty is None:
|
| 187 |
+
penalty = length
|
| 188 |
+
else:
|
| 189 |
+
# from the Google NMT paper
|
| 190 |
+
penalty = ((5 + length) / 6) ** self.length_penalty
|
| 191 |
+
result.append(logprob / penalty)
|
| 192 |
+
return result
|
| 193 |
+
|
| 194 |
+
# get the sequence with the highest score
|
| 195 |
+
lengths = [[len(t) for t in s] for s in tokens]
|
| 196 |
+
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above

        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        temperature = self.temperature
        if temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()

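# GreedyDecoder can be exercised in isolation; a minimal sketch with toy values
# (vocab_size=4 and eot=3 are arbitrary choices for illustration):
#
#     >>> decoder = GreedyDecoder(temperature=0.0, eot=3)
#     >>> tokens = torch.tensor([[0, 1]])                 # one sequence, two context tokens
#     >>> logits = torch.tensor([[0.1, 0.2, 2.0, 0.5]])   # argmax is token 2
#     >>> sum_logprobs = torch.zeros(1)
#     >>> tokens, completed = decoder.update(tokens, logits, sum_logprobs)
#     >>> tokens                                          # token 2 appended; not yet completed
#     tensor([[0, 1, 2]])
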
class BeamSearchDecoder(TokenDecoder):
    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if every audio stream has collected enough finished samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs

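# With the default patience of 1.0, the search stops as soon as `beam_size`
# finished candidates exist per audio; a larger patience keeps collecting:
#
#     >>> round(5 * 1.4)   # beam_size=5, patience=1.4 (illustrative values)
#     7
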
class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf

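# Setting a logit to -inf removes the token from the distribution entirely,
# since softmax maps it to probability zero; a minimal sketch:
#
#     >>> masked = torch.tensor([1.0, float("-inf"), 3.0])
#     >>> F.softmax(masked, dim=-1)
#     tensor([0.1192, 0.0000, 0.8808])
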
class ApplyTimestampRules(LogitFilter):
    def __init__(
        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            seq = tokens[k, self.sample_begin :].tolist()
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

        # apply the `max_initial_timestamp` option
        if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
            last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
            logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf

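# The final rule compares the total probability mass of all timestamp tokens
# (via logsumexp) against the single best text token; a sketch with a toy
# 4-token vocabulary where positions 2+ stand in for timestamps:
#
#     >>> lp = F.log_softmax(torch.tensor([2.0, 1.0, 0.5, 0.4]), dim=-1)
#     >>> lp[2:].logsumexp(dim=-1) > lp[:2].max()   # text wins here; nothing is masked
#     tensor(False)
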
class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
            self.logit_filters.append(
                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
            )

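    # With the standard Whisper audio dimensions (CHUNK_LENGTH = 30 seconds and
    # n_audio_ctx = 1500), each encoder position spans 0.02 s, so e.g. a
    # max_initial_timestamp of 1.0 s caps the first timestamp at index 50:
    #
    #     >>> round(1.0 / (30 / 1500))
    #     50
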
    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            )
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

        return tuple(tokens)

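    # With both a prompt and a prefix given, the assembled context is, schematically:
    #
    #     (sot_prev, *prompt_tokens[-(n_ctx // 2 - 1):], *sot_sequence, *prefix_tokens)
    #
    # i.e. the prompt is conditioning history placed before the start-of-transcript
    # sequence, while the prefix is forced text that decoding continues from.
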
    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

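    # A concrete reading of the rules above, with made-up token ids: passing
    # suppress_tokens="-1,123" yields [123] plus tokenizer.non_speech_tokens
    # (-1 expands to the non-speech list), after which the special tokens
    # sot, sot_prev, sot_lm, and no_speech are always appended, and the result
    # is deduplicated and sorted into a tuple.
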
    def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            result = self.model.encoder(mel, include_embeddings)
            if include_embeddings:
                audio_features, embeddings = result
            else:
                audio_features = result

        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        if include_embeddings:
            return audio_features, embeddings
        else:
            return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        completed = False  # ensure the flag exists for the `finally` block even if the loop never runs
        try:
            embeddings = []
            for i in range(self.sample_len):
                logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]
                token_embeddings = token_embeddings[:, :, -1]

                # append this step's decoder embeddings
                embeddings.append(token_embeddings)

                # apply the logit filters, e.g. for suppressing or applying penalty to repetitions
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            if completed:
                # drop the embeddings from the final (EOT) step
                embeddings = embeddings[:-1]
            embeddings = np.stack(embeddings, 2)
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs, embeddings

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        # encoder forward pass
        forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
        audio_features, encoder_embeddings = forward_pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
                encoder_embeddings=encoder_embeddings,
                decoder_embeddings=decoder_embeddings,
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
        ]


@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options).run(mel)

    if single:
        result = result[0]

    return result
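
# A minimal usage sketch (assuming the vendored package keeps upstream Whisper's
# load_model / log_mel_spectrogram / pad_or_trim helpers; the checkpoint path
# points at the copy shipped in this repo):
#
#     >>> import torch
#     >>> from latentsync.whisper.whisper import load_model
#     >>> from latentsync.whisper.whisper.audio import log_mel_spectrogram, pad_or_trim
#     >>> model = load_model("checkpoints/whisper/small.pt")
#     >>> mel = log_mel_spectrogram(pad_or_trim(torch.randn(16000 * 30)))  # toy 30 s signal
#     >>> result = decode(model, mel.to(model.device), DecodingOptions(language="en", fp16=False))
#     >>> result.text, result.decoder_embeddings.shape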