Upload 41 files
Browse files- audio.py +136 -0
- checkpoints/lipsync_expert.pth +3 -0
- checkpoints/s3fd-619a316812.pth +3 -0
- checkpoints/visual_quality_disc.pth +3 -0
- checkpoints/wav2lip.pth +3 -0
- checkpoints/wav2lip_gan.pth +3 -0
- color_syncnet_train.py +279 -0
- evaluation/gen_videos_from_filelist.py +238 -0
- evaluation/real_videos_inference.py +305 -0
- evaluation/scores_LSE/SyncNetInstance_calc_scores.py +210 -0
- evaluation/scores_LSE/calculate_scores_LRS.py +53 -0
- evaluation/scores_LSE/calculate_scores_real_videos.py +45 -0
- evaluation/scores_LSE/calculate_scores_real_videos.sh +8 -0
- evaluation/test_filelists/ReSyncED/random_pairs.txt +160 -0
- evaluation/test_filelists/ReSyncED/tts_pairs.txt +18 -0
- evaluation/test_filelists/lrs2.txt +0 -0
- evaluation/test_filelists/lrs3.txt +0 -0
- evaluation/test_filelists/lrw.txt +0 -0
- face_detection/__init__.py +7 -0
- face_detection/api.py +79 -0
- face_detection/detection/__init__.py +1 -0
- face_detection/detection/core.py +130 -0
- face_detection/detection/sfd/__init__.py +1 -0
- face_detection/detection/sfd/bbox.py +129 -0
- face_detection/detection/sfd/detect.py +112 -0
- face_detection/detection/sfd/net_s3fd.py +129 -0
- face_detection/detection/sfd/s3fd.pth +3 -0
- face_detection/detection/sfd/sfd_detector.py +59 -0
- face_detection/models.py +261 -0
- face_detection/utils.py +313 -0
- hparams.py +101 -0
- hq_wav2lip_train.py +443 -0
- inference.py +280 -0
- models/__init__.py +2 -0
- models/conv.py +44 -0
- models/syncnet.py +66 -0
- models/wav2lip.py +184 -0
- preprocess.py +113 -0
- requirements.txt +8 -0
- requirementsCPU.txt +9 -0
- wav2lip_train.py +374 -0
audio.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import librosa.filters
|
| 3 |
+
import numpy as np
|
| 4 |
+
# import tensorflow as tf
|
| 5 |
+
from scipy import signal
|
| 6 |
+
from scipy.io import wavfile
|
| 7 |
+
from hparams import hparams as hp
|
| 8 |
+
|
| 9 |
+
def load_wav(path, sr):
|
| 10 |
+
return librosa.core.load(path, sr=sr)[0]
|
| 11 |
+
|
| 12 |
+
def save_wav(wav, path, sr):
|
| 13 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
| 14 |
+
#proposed by @dsmiller
|
| 15 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
| 16 |
+
|
| 17 |
+
def save_wavenet_wav(wav, path, sr):
|
| 18 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
| 19 |
+
|
| 20 |
+
def preemphasis(wav, k, preemphasize=True):
|
| 21 |
+
if preemphasize:
|
| 22 |
+
return signal.lfilter([1, -k], [1], wav)
|
| 23 |
+
return wav
|
| 24 |
+
|
| 25 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
| 26 |
+
if inv_preemphasize:
|
| 27 |
+
return signal.lfilter([1], [1, -k], wav)
|
| 28 |
+
return wav
|
| 29 |
+
|
| 30 |
+
def get_hop_size():
|
| 31 |
+
hop_size = hp.hop_size
|
| 32 |
+
if hop_size is None:
|
| 33 |
+
assert hp.frame_shift_ms is not None
|
| 34 |
+
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
| 35 |
+
return hop_size
|
| 36 |
+
|
| 37 |
+
def linearspectrogram(wav):
|
| 38 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
| 39 |
+
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
| 40 |
+
|
| 41 |
+
if hp.signal_normalization:
|
| 42 |
+
return _normalize(S)
|
| 43 |
+
return S
|
| 44 |
+
|
| 45 |
+
def melspectrogram(wav):
|
| 46 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
| 47 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
| 48 |
+
|
| 49 |
+
if hp.signal_normalization:
|
| 50 |
+
return _normalize(S)
|
| 51 |
+
return S
|
| 52 |
+
|
| 53 |
+
def _lws_processor():
|
| 54 |
+
import lws
|
| 55 |
+
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
| 56 |
+
|
| 57 |
+
def _stft(y):
|
| 58 |
+
if hp.use_lws:
|
| 59 |
+
return _lws_processor(hp).stft(y).T
|
| 60 |
+
else:
|
| 61 |
+
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
| 62 |
+
|
| 63 |
+
##########################################################
|
| 64 |
+
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
| 65 |
+
def num_frames(length, fsize, fshift):
|
| 66 |
+
"""Compute number of time frames of spectrogram
|
| 67 |
+
"""
|
| 68 |
+
pad = (fsize - fshift)
|
| 69 |
+
if length % fshift == 0:
|
| 70 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
| 71 |
+
else:
|
| 72 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
| 73 |
+
return M
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def pad_lr(x, fsize, fshift):
|
| 77 |
+
"""Compute left and right padding
|
| 78 |
+
"""
|
| 79 |
+
M = num_frames(len(x), fsize, fshift)
|
| 80 |
+
pad = (fsize - fshift)
|
| 81 |
+
T = len(x) + 2 * pad
|
| 82 |
+
r = (M - 1) * fshift + fsize - T
|
| 83 |
+
return pad, pad + r
|
| 84 |
+
##########################################################
|
| 85 |
+
#Librosa correct padding
|
| 86 |
+
def librosa_pad_lr(x, fsize, fshift):
|
| 87 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
| 88 |
+
|
| 89 |
+
# Conversions
|
| 90 |
+
_mel_basis = None
|
| 91 |
+
|
| 92 |
+
def _linear_to_mel(spectogram):
|
| 93 |
+
global _mel_basis
|
| 94 |
+
if _mel_basis is None:
|
| 95 |
+
_mel_basis = _build_mel_basis()
|
| 96 |
+
return np.dot(_mel_basis, spectogram)
|
| 97 |
+
|
| 98 |
+
def _build_mel_basis():
|
| 99 |
+
assert hp.fmax <= hp.sample_rate // 2
|
| 100 |
+
return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
|
| 101 |
+
fmin=hp.fmin, fmax=hp.fmax)
|
| 102 |
+
|
| 103 |
+
def _amp_to_db(x):
|
| 104 |
+
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
| 105 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
| 106 |
+
|
| 107 |
+
def _db_to_amp(x):
|
| 108 |
+
return np.power(10.0, (x) * 0.05)
|
| 109 |
+
|
| 110 |
+
def _normalize(S):
|
| 111 |
+
if hp.allow_clipping_in_normalization:
|
| 112 |
+
if hp.symmetric_mels:
|
| 113 |
+
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
| 114 |
+
-hp.max_abs_value, hp.max_abs_value)
|
| 115 |
+
else:
|
| 116 |
+
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
| 117 |
+
|
| 118 |
+
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
| 119 |
+
if hp.symmetric_mels:
|
| 120 |
+
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
| 121 |
+
else:
|
| 122 |
+
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
| 123 |
+
|
| 124 |
+
def _denormalize(D):
|
| 125 |
+
if hp.allow_clipping_in_normalization:
|
| 126 |
+
if hp.symmetric_mels:
|
| 127 |
+
return (((np.clip(D, -hp.max_abs_value,
|
| 128 |
+
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
| 129 |
+
+ hp.min_level_db)
|
| 130 |
+
else:
|
| 131 |
+
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
| 132 |
+
|
| 133 |
+
if hp.symmetric_mels:
|
| 134 |
+
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
| 135 |
+
else:
|
| 136 |
+
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
checkpoints/lipsync_expert.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa1f06d61ae86c47074ff9bc1bb7d0c40ab2d840724dc9258e255a8fab4b3559
|
| 3 |
+
size 134
|
checkpoints/s3fd-619a316812.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7636d0c9d2a8f4759aef537cbcc25c5fa2eb2d5d80b1fada4dcc800e967cf381
|
| 3 |
+
size 133
|
checkpoints/visual_quality_disc.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bdb40d624f6a1e9a07beec0f6bb5a19a91a9fac46ce6bcfd282fd9ccf1c3d3fc
|
| 3 |
+
size 134
|
checkpoints/wav2lip.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8e58726ef72ac961e2fea864e93e10fd64076222e5bd98394736684aa63dd2d
|
| 3 |
+
size 131
|
checkpoints/wav2lip_gan.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:483f94a71bfd57ff73a2464a661b9af5766ce54c2ad1f06def1a2e1d8b8cd78a
|
| 3 |
+
size 134
|
color_syncnet_train.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os.path import dirname, join, basename, isfile
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
from models import SyncNet_color as SyncNet
|
| 5 |
+
import audio
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch import optim
|
| 10 |
+
import torch.backends.cudnn as cudnn
|
| 11 |
+
from torch.utils import data as data_utils
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from glob import glob
|
| 15 |
+
|
| 16 |
+
import os, random, cv2, argparse
|
| 17 |
+
from hparams import hparams, get_image_list
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
|
| 20 |
+
|
| 21 |
+
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
|
| 22 |
+
|
| 23 |
+
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
|
| 24 |
+
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
|
| 25 |
+
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
global_step = 0
|
| 30 |
+
global_epoch = 0
|
| 31 |
+
use_cuda = torch.cuda.is_available()
|
| 32 |
+
print('use_cuda: {}'.format(use_cuda))
|
| 33 |
+
|
| 34 |
+
syncnet_T = 5
|
| 35 |
+
syncnet_mel_step_size = 16
|
| 36 |
+
|
| 37 |
+
class Dataset(object):
|
| 38 |
+
def __init__(self, split):
|
| 39 |
+
self.all_videos = get_image_list(args.data_root, split)
|
| 40 |
+
|
| 41 |
+
def get_frame_id(self, frame):
|
| 42 |
+
return int(basename(frame).split('.')[0])
|
| 43 |
+
|
| 44 |
+
def get_window(self, start_frame):
|
| 45 |
+
start_id = self.get_frame_id(start_frame)
|
| 46 |
+
vidname = dirname(start_frame)
|
| 47 |
+
|
| 48 |
+
window_fnames = []
|
| 49 |
+
for frame_id in range(start_id, start_id + syncnet_T):
|
| 50 |
+
frame = join(vidname, '{}.jpg'.format(frame_id))
|
| 51 |
+
if not isfile(frame):
|
| 52 |
+
return None
|
| 53 |
+
window_fnames.append(frame)
|
| 54 |
+
return window_fnames
|
| 55 |
+
|
| 56 |
+
def crop_audio_window(self, spec, start_frame):
|
| 57 |
+
# num_frames = (T x hop_size * fps) / sample_rate
|
| 58 |
+
start_frame_num = self.get_frame_id(start_frame)
|
| 59 |
+
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
|
| 60 |
+
|
| 61 |
+
end_idx = start_idx + syncnet_mel_step_size
|
| 62 |
+
|
| 63 |
+
return spec[start_idx : end_idx, :]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def __len__(self):
|
| 67 |
+
return len(self.all_videos)
|
| 68 |
+
|
| 69 |
+
def __getitem__(self, idx):
|
| 70 |
+
while 1:
|
| 71 |
+
idx = random.randint(0, len(self.all_videos) - 1)
|
| 72 |
+
vidname = self.all_videos[idx]
|
| 73 |
+
|
| 74 |
+
img_names = list(glob(join(vidname, '*.jpg')))
|
| 75 |
+
if len(img_names) <= 3 * syncnet_T:
|
| 76 |
+
continue
|
| 77 |
+
img_name = random.choice(img_names)
|
| 78 |
+
wrong_img_name = random.choice(img_names)
|
| 79 |
+
while wrong_img_name == img_name:
|
| 80 |
+
wrong_img_name = random.choice(img_names)
|
| 81 |
+
|
| 82 |
+
if random.choice([True, False]):
|
| 83 |
+
y = torch.ones(1).float()
|
| 84 |
+
chosen = img_name
|
| 85 |
+
else:
|
| 86 |
+
y = torch.zeros(1).float()
|
| 87 |
+
chosen = wrong_img_name
|
| 88 |
+
|
| 89 |
+
window_fnames = self.get_window(chosen)
|
| 90 |
+
if window_fnames is None:
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
window = []
|
| 94 |
+
all_read = True
|
| 95 |
+
for fname in window_fnames:
|
| 96 |
+
img = cv2.imread(fname)
|
| 97 |
+
if img is None:
|
| 98 |
+
all_read = False
|
| 99 |
+
break
|
| 100 |
+
try:
|
| 101 |
+
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
|
| 102 |
+
except Exception as e:
|
| 103 |
+
all_read = False
|
| 104 |
+
break
|
| 105 |
+
|
| 106 |
+
window.append(img)
|
| 107 |
+
|
| 108 |
+
if not all_read: continue
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
wavpath = join(vidname, "audio.wav")
|
| 112 |
+
wav = audio.load_wav(wavpath, hparams.sample_rate)
|
| 113 |
+
|
| 114 |
+
orig_mel = audio.melspectrogram(wav).T
|
| 115 |
+
except Exception as e:
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
mel = self.crop_audio_window(orig_mel.copy(), img_name)
|
| 119 |
+
|
| 120 |
+
if (mel.shape[0] != syncnet_mel_step_size):
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
# H x W x 3 * T
|
| 124 |
+
x = np.concatenate(window, axis=2) / 255.
|
| 125 |
+
x = x.transpose(2, 0, 1)
|
| 126 |
+
x = x[:, x.shape[1]//2:]
|
| 127 |
+
|
| 128 |
+
x = torch.FloatTensor(x)
|
| 129 |
+
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
| 130 |
+
|
| 131 |
+
return x, mel, y
|
| 132 |
+
|
| 133 |
+
logloss = nn.BCELoss()
|
| 134 |
+
def cosine_loss(a, v, y):
|
| 135 |
+
d = nn.functional.cosine_similarity(a, v)
|
| 136 |
+
loss = logloss(d.unsqueeze(1), y)
|
| 137 |
+
|
| 138 |
+
return loss
|
| 139 |
+
|
| 140 |
+
def train(device, model, train_data_loader, test_data_loader, optimizer,
|
| 141 |
+
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
| 142 |
+
|
| 143 |
+
global global_step, global_epoch
|
| 144 |
+
resumed_step = global_step
|
| 145 |
+
|
| 146 |
+
while global_epoch < nepochs:
|
| 147 |
+
running_loss = 0.
|
| 148 |
+
prog_bar = tqdm(enumerate(train_data_loader))
|
| 149 |
+
for step, (x, mel, y) in prog_bar:
|
| 150 |
+
model.train()
|
| 151 |
+
optimizer.zero_grad()
|
| 152 |
+
|
| 153 |
+
# Transform data to CUDA device
|
| 154 |
+
x = x.to(device)
|
| 155 |
+
|
| 156 |
+
mel = mel.to(device)
|
| 157 |
+
|
| 158 |
+
a, v = model(mel, x)
|
| 159 |
+
y = y.to(device)
|
| 160 |
+
|
| 161 |
+
loss = cosine_loss(a, v, y)
|
| 162 |
+
loss.backward()
|
| 163 |
+
optimizer.step()
|
| 164 |
+
|
| 165 |
+
global_step += 1
|
| 166 |
+
cur_session_steps = global_step - resumed_step
|
| 167 |
+
running_loss += loss.item()
|
| 168 |
+
|
| 169 |
+
if global_step == 1 or global_step % checkpoint_interval == 0:
|
| 170 |
+
save_checkpoint(
|
| 171 |
+
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
| 172 |
+
|
| 173 |
+
if global_step % hparams.syncnet_eval_interval == 0:
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
|
| 176 |
+
|
| 177 |
+
prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
|
| 178 |
+
|
| 179 |
+
global_epoch += 1
|
| 180 |
+
|
| 181 |
+
def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
|
| 182 |
+
eval_steps = 1400
|
| 183 |
+
print('Evaluating for {} steps'.format(eval_steps))
|
| 184 |
+
losses = []
|
| 185 |
+
while 1:
|
| 186 |
+
for step, (x, mel, y) in enumerate(test_data_loader):
|
| 187 |
+
|
| 188 |
+
model.eval()
|
| 189 |
+
|
| 190 |
+
# Transform data to CUDA device
|
| 191 |
+
x = x.to(device)
|
| 192 |
+
|
| 193 |
+
mel = mel.to(device)
|
| 194 |
+
|
| 195 |
+
a, v = model(mel, x)
|
| 196 |
+
y = y.to(device)
|
| 197 |
+
|
| 198 |
+
loss = cosine_loss(a, v, y)
|
| 199 |
+
losses.append(loss.item())
|
| 200 |
+
|
| 201 |
+
if step > eval_steps: break
|
| 202 |
+
|
| 203 |
+
averaged_loss = sum(losses) / len(losses)
|
| 204 |
+
print(averaged_loss)
|
| 205 |
+
|
| 206 |
+
return
|
| 207 |
+
|
| 208 |
+
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
|
| 209 |
+
|
| 210 |
+
checkpoint_path = join(
|
| 211 |
+
checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
|
| 212 |
+
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
|
| 213 |
+
torch.save({
|
| 214 |
+
"state_dict": model.state_dict(),
|
| 215 |
+
"optimizer": optimizer_state,
|
| 216 |
+
"global_step": step,
|
| 217 |
+
"global_epoch": epoch,
|
| 218 |
+
}, checkpoint_path)
|
| 219 |
+
print("Saved checkpoint:", checkpoint_path)
|
| 220 |
+
|
| 221 |
+
def _load(checkpoint_path):
|
| 222 |
+
if use_cuda:
|
| 223 |
+
checkpoint = torch.load(checkpoint_path)
|
| 224 |
+
else:
|
| 225 |
+
checkpoint = torch.load(checkpoint_path,
|
| 226 |
+
map_location=lambda storage, loc: storage)
|
| 227 |
+
return checkpoint
|
| 228 |
+
|
| 229 |
+
def load_checkpoint(path, model, optimizer, reset_optimizer=False):
|
| 230 |
+
global global_step
|
| 231 |
+
global global_epoch
|
| 232 |
+
|
| 233 |
+
print("Load checkpoint from: {}".format(path))
|
| 234 |
+
checkpoint = _load(path)
|
| 235 |
+
model.load_state_dict(checkpoint["state_dict"])
|
| 236 |
+
if not reset_optimizer:
|
| 237 |
+
optimizer_state = checkpoint["optimizer"]
|
| 238 |
+
if optimizer_state is not None:
|
| 239 |
+
print("Load optimizer state from {}".format(path))
|
| 240 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 241 |
+
global_step = checkpoint["global_step"]
|
| 242 |
+
global_epoch = checkpoint["global_epoch"]
|
| 243 |
+
|
| 244 |
+
return model
|
| 245 |
+
|
| 246 |
+
if __name__ == "__main__":
|
| 247 |
+
checkpoint_dir = args.checkpoint_dir
|
| 248 |
+
checkpoint_path = args.checkpoint_path
|
| 249 |
+
|
| 250 |
+
if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
|
| 251 |
+
|
| 252 |
+
# Dataset and Dataloader setup
|
| 253 |
+
train_dataset = Dataset('train')
|
| 254 |
+
test_dataset = Dataset('val')
|
| 255 |
+
|
| 256 |
+
train_data_loader = data_utils.DataLoader(
|
| 257 |
+
train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
|
| 258 |
+
num_workers=hparams.num_workers)
|
| 259 |
+
|
| 260 |
+
test_data_loader = data_utils.DataLoader(
|
| 261 |
+
test_dataset, batch_size=hparams.syncnet_batch_size,
|
| 262 |
+
num_workers=8)
|
| 263 |
+
|
| 264 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
| 265 |
+
|
| 266 |
+
# Model
|
| 267 |
+
model = SyncNet().to(device)
|
| 268 |
+
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
| 269 |
+
|
| 270 |
+
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
|
| 271 |
+
lr=hparams.syncnet_lr)
|
| 272 |
+
|
| 273 |
+
if checkpoint_path is not None:
|
| 274 |
+
load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
|
| 275 |
+
|
| 276 |
+
train(device, model, train_data_loader, test_data_loader, optimizer,
|
| 277 |
+
checkpoint_dir=checkpoint_dir,
|
| 278 |
+
checkpoint_interval=hparams.syncnet_checkpoint_interval,
|
| 279 |
+
nepochs=hparams.nepochs)
|
evaluation/gen_videos_from_filelist.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir, path
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy, cv2, os, sys, argparse
|
| 4 |
+
import dlib, json, subprocess
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from glob import glob
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
sys.path.append('../')
|
| 10 |
+
import audio
|
| 11 |
+
import face_detection
|
| 12 |
+
from models import Wav2Lip
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser(description='Code to generate results for test filelists')
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--filelist', type=str,
|
| 17 |
+
help='Filepath of filelist file to read', required=True)
|
| 18 |
+
parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
|
| 19 |
+
required=True)
|
| 20 |
+
parser.add_argument('--data_root', type=str, required=True)
|
| 21 |
+
parser.add_argument('--checkpoint_path', type=str,
|
| 22 |
+
help='Name of saved checkpoint to load weights from', required=True)
|
| 23 |
+
|
| 24 |
+
parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0],
|
| 25 |
+
help='Padding (top, bottom, left, right)')
|
| 26 |
+
parser.add_argument('--face_det_batch_size', type=int,
|
| 27 |
+
help='Single GPU batch size for face detection', default=64)
|
| 28 |
+
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
|
| 29 |
+
|
| 30 |
+
# parser.add_argument('--resize_factor', default=1, type=int)
|
| 31 |
+
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
args.img_size = 96
|
| 34 |
+
|
| 35 |
+
def get_smoothened_boxes(boxes, T):
|
| 36 |
+
for i in range(len(boxes)):
|
| 37 |
+
if i + T > len(boxes):
|
| 38 |
+
window = boxes[len(boxes) - T:]
|
| 39 |
+
else:
|
| 40 |
+
window = boxes[i : i + T]
|
| 41 |
+
boxes[i] = np.mean(window, axis=0)
|
| 42 |
+
return boxes
|
| 43 |
+
|
| 44 |
+
def face_detect(images):
|
| 45 |
+
batch_size = args.face_det_batch_size
|
| 46 |
+
|
| 47 |
+
while 1:
|
| 48 |
+
predictions = []
|
| 49 |
+
try:
|
| 50 |
+
for i in range(0, len(images), batch_size):
|
| 51 |
+
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
| 52 |
+
except RuntimeError:
|
| 53 |
+
if batch_size == 1:
|
| 54 |
+
raise RuntimeError('Image too big to run face detection on GPU')
|
| 55 |
+
batch_size //= 2
|
| 56 |
+
args.face_det_batch_size = batch_size
|
| 57 |
+
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
|
| 58 |
+
continue
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
results = []
|
| 62 |
+
pady1, pady2, padx1, padx2 = args.pads
|
| 63 |
+
for rect, image in zip(predictions, images):
|
| 64 |
+
if rect is None:
|
| 65 |
+
raise ValueError('Face not detected!')
|
| 66 |
+
|
| 67 |
+
y1 = max(0, rect[1] - pady1)
|
| 68 |
+
y2 = min(image.shape[0], rect[3] + pady2)
|
| 69 |
+
x1 = max(0, rect[0] - padx1)
|
| 70 |
+
x2 = min(image.shape[1], rect[2] + padx2)
|
| 71 |
+
|
| 72 |
+
results.append([x1, y1, x2, y2])
|
| 73 |
+
|
| 74 |
+
boxes = get_smoothened_boxes(np.array(results), T=5)
|
| 75 |
+
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
| 76 |
+
|
| 77 |
+
return results
|
| 78 |
+
|
| 79 |
+
def datagen(frames, face_det_results, mels):
|
| 80 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 81 |
+
|
| 82 |
+
for i, m in enumerate(mels):
|
| 83 |
+
if i >= len(frames): raise ValueError('Equal or less lengths only')
|
| 84 |
+
|
| 85 |
+
frame_to_save = frames[i].copy()
|
| 86 |
+
face, coords, valid_frame = face_det_results[i].copy()
|
| 87 |
+
if not valid_frame:
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
face = cv2.resize(face, (args.img_size, args.img_size))
|
| 91 |
+
|
| 92 |
+
img_batch.append(face)
|
| 93 |
+
mel_batch.append(m)
|
| 94 |
+
frame_batch.append(frame_to_save)
|
| 95 |
+
coords_batch.append(coords)
|
| 96 |
+
|
| 97 |
+
if len(img_batch) >= args.wav2lip_batch_size:
|
| 98 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 99 |
+
|
| 100 |
+
img_masked = img_batch.copy()
|
| 101 |
+
img_masked[:, args.img_size//2:] = 0
|
| 102 |
+
|
| 103 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 104 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 105 |
+
|
| 106 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 107 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 108 |
+
|
| 109 |
+
if len(img_batch) > 0:
|
| 110 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 111 |
+
|
| 112 |
+
img_masked = img_batch.copy()
|
| 113 |
+
img_masked[:, args.img_size//2:] = 0
|
| 114 |
+
|
| 115 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 116 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 117 |
+
|
| 118 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 119 |
+
|
| 120 |
+
fps = 25
|
| 121 |
+
mel_step_size = 16
|
| 122 |
+
mel_idx_multiplier = 80./fps
|
| 123 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 124 |
+
print('Using {} for inference.'.format(device))
|
| 125 |
+
|
| 126 |
+
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
|
| 127 |
+
flip_input=False, device=device)
|
| 128 |
+
|
| 129 |
+
def _load(checkpoint_path):
|
| 130 |
+
if device == 'cuda':
|
| 131 |
+
checkpoint = torch.load(checkpoint_path)
|
| 132 |
+
else:
|
| 133 |
+
checkpoint = torch.load(checkpoint_path,
|
| 134 |
+
map_location=lambda storage, loc: storage)
|
| 135 |
+
return checkpoint
|
| 136 |
+
|
| 137 |
+
def load_model(path):
|
| 138 |
+
model = Wav2Lip()
|
| 139 |
+
print("Load checkpoint from: {}".format(path))
|
| 140 |
+
checkpoint = _load(path)
|
| 141 |
+
s = checkpoint["state_dict"]
|
| 142 |
+
new_s = {}
|
| 143 |
+
for k, v in s.items():
|
| 144 |
+
new_s[k.replace('module.', '')] = v
|
| 145 |
+
model.load_state_dict(new_s)
|
| 146 |
+
|
| 147 |
+
model = model.to(device)
|
| 148 |
+
return model.eval()
|
| 149 |
+
|
| 150 |
+
model = load_model(args.checkpoint_path)
|
| 151 |
+
|
| 152 |
+
def main():
|
| 153 |
+
assert args.data_root is not None
|
| 154 |
+
data_root = args.data_root
|
| 155 |
+
|
| 156 |
+
if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
|
| 157 |
+
|
| 158 |
+
with open(args.filelist, 'r') as filelist:
|
| 159 |
+
lines = filelist.readlines()
|
| 160 |
+
|
| 161 |
+
for idx, line in enumerate(tqdm(lines)):
|
| 162 |
+
audio_src, video = line.strip().split()
|
| 163 |
+
|
| 164 |
+
audio_src = os.path.join(data_root, audio_src) + '.mp4'
|
| 165 |
+
video = os.path.join(data_root, video) + '.mp4'
|
| 166 |
+
|
| 167 |
+
command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
|
| 168 |
+
subprocess.call(command, shell=True)
|
| 169 |
+
temp_audio = '../temp/temp.wav'
|
| 170 |
+
|
| 171 |
+
wav = audio.load_wav(temp_audio, 16000)
|
| 172 |
+
mel = audio.melspectrogram(wav)
|
| 173 |
+
if np.isnan(mel.reshape(-1)).sum() > 0:
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
mel_chunks = []
|
| 177 |
+
i = 0
|
| 178 |
+
while 1:
|
| 179 |
+
start_idx = int(i * mel_idx_multiplier)
|
| 180 |
+
if start_idx + mel_step_size > len(mel[0]):
|
| 181 |
+
break
|
| 182 |
+
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
|
| 183 |
+
i += 1
|
| 184 |
+
|
| 185 |
+
video_stream = cv2.VideoCapture(video)
|
| 186 |
+
|
| 187 |
+
full_frames = []
|
| 188 |
+
while 1:
|
| 189 |
+
still_reading, frame = video_stream.read()
|
| 190 |
+
if not still_reading or len(full_frames) > len(mel_chunks):
|
| 191 |
+
video_stream.release()
|
| 192 |
+
break
|
| 193 |
+
full_frames.append(frame)
|
| 194 |
+
|
| 195 |
+
if len(full_frames) < len(mel_chunks):
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
full_frames = full_frames[:len(mel_chunks)]
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
face_det_results = face_detect(full_frames.copy())
|
| 202 |
+
except ValueError as e:
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
batch_size = args.wav2lip_batch_size
|
| 206 |
+
gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
|
| 207 |
+
|
| 208 |
+
for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
|
| 209 |
+
if i == 0:
|
| 210 |
+
frame_h, frame_w = full_frames[0].shape[:-1]
|
| 211 |
+
out = cv2.VideoWriter('../temp/result.avi',
|
| 212 |
+
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
|
| 213 |
+
|
| 214 |
+
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
| 215 |
+
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
| 216 |
+
|
| 217 |
+
with torch.no_grad():
|
| 218 |
+
pred = model(mel_batch, img_batch)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
|
| 222 |
+
|
| 223 |
+
for pl, f, c in zip(pred, frames, coords):
|
| 224 |
+
y1, y2, x1, x2 = c
|
| 225 |
+
pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
|
| 226 |
+
f[y1:y2, x1:x2] = pl
|
| 227 |
+
out.write(f)
|
| 228 |
+
|
| 229 |
+
out.release()
|
| 230 |
+
|
| 231 |
+
vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
|
| 232 |
+
|
| 233 |
+
command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
|
| 234 |
+
'../temp/result.avi', vid)
|
| 235 |
+
subprocess.call(command, shell=True)
|
| 236 |
+
|
| 237 |
+
if __name__ == '__main__':
|
| 238 |
+
main()
|
evaluation/real_videos_inference.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir, path
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy, cv2, os, sys, argparse
|
| 4 |
+
import dlib, json, subprocess
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from glob import glob
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
sys.path.append('../')
|
| 10 |
+
import audio
|
| 11 |
+
import face_detection
|
| 12 |
+
from models import Wav2Lip
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set')
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--mode', type=str,
|
| 17 |
+
help='random | dubbed | tts', required=True)
|
| 18 |
+
|
| 19 |
+
parser.add_argument('--filelist', type=str,
|
| 20 |
+
help='Filepath of filelist file to read', default=None)
|
| 21 |
+
|
| 22 |
+
parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
|
| 23 |
+
required=True)
|
| 24 |
+
parser.add_argument('--data_root', type=str, required=True)
|
| 25 |
+
parser.add_argument('--checkpoint_path', type=str,
|
| 26 |
+
help='Name of saved checkpoint to load weights from', required=True)
|
| 27 |
+
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
|
| 28 |
+
help='Padding (top, bottom, left, right)')
|
| 29 |
+
|
| 30 |
+
parser.add_argument('--face_det_batch_size', type=int,
|
| 31 |
+
help='Single GPU batch size for face detection', default=16)
|
| 32 |
+
|
| 33 |
+
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
|
| 34 |
+
parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180)
|
| 35 |
+
parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480)
|
| 36 |
+
parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720)
|
| 37 |
+
# parser.add_argument('--resize_factor', default=1, type=int)
|
| 38 |
+
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
args.img_size = 96
|
| 41 |
+
|
| 42 |
+
def get_smoothened_boxes(boxes, T):
|
| 43 |
+
for i in range(len(boxes)):
|
| 44 |
+
if i + T > len(boxes):
|
| 45 |
+
window = boxes[len(boxes) - T:]
|
| 46 |
+
else:
|
| 47 |
+
window = boxes[i : i + T]
|
| 48 |
+
boxes[i] = np.mean(window, axis=0)
|
| 49 |
+
return boxes
|
| 50 |
+
|
| 51 |
+
def rescale_frames(images):
|
| 52 |
+
rect = detector.get_detections_for_batch(np.array([images[0]]))[0]
|
| 53 |
+
if rect is None:
|
| 54 |
+
raise ValueError('Face not detected!')
|
| 55 |
+
h, w = images[0].shape[:-1]
|
| 56 |
+
|
| 57 |
+
x1, y1, x2, y2 = rect
|
| 58 |
+
|
| 59 |
+
face_size = max(np.abs(y1 - y2), np.abs(x1 - x2))
|
| 60 |
+
|
| 61 |
+
diff = np.abs(face_size - args.face_res)
|
| 62 |
+
for factor in range(2, 16):
|
| 63 |
+
downsampled_res = face_size // factor
|
| 64 |
+
if min(h//factor, w//factor) < args.min_frame_res: break
|
| 65 |
+
if np.abs(downsampled_res - args.face_res) >= diff: break
|
| 66 |
+
|
| 67 |
+
factor -= 1
|
| 68 |
+
if factor == 1: return images
|
| 69 |
+
|
| 70 |
+
return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def face_detect(images):
|
| 74 |
+
batch_size = args.face_det_batch_size
|
| 75 |
+
images = rescale_frames(images)
|
| 76 |
+
|
| 77 |
+
while 1:
|
| 78 |
+
predictions = []
|
| 79 |
+
try:
|
| 80 |
+
for i in range(0, len(images), batch_size):
|
| 81 |
+
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
| 82 |
+
except RuntimeError:
|
| 83 |
+
if batch_size == 1:
|
| 84 |
+
raise RuntimeError('Image too big to run face detection on GPU')
|
| 85 |
+
batch_size //= 2
|
| 86 |
+
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
|
| 87 |
+
continue
|
| 88 |
+
break
|
| 89 |
+
|
| 90 |
+
results = []
|
| 91 |
+
pady1, pady2, padx1, padx2 = args.pads
|
| 92 |
+
for rect, image in zip(predictions, images):
|
| 93 |
+
if rect is None:
|
| 94 |
+
raise ValueError('Face not detected!')
|
| 95 |
+
|
| 96 |
+
y1 = max(0, rect[1] - pady1)
|
| 97 |
+
y2 = min(image.shape[0], rect[3] + pady2)
|
| 98 |
+
x1 = max(0, rect[0] - padx1)
|
| 99 |
+
x2 = min(image.shape[1], rect[2] + padx2)
|
| 100 |
+
|
| 101 |
+
results.append([x1, y1, x2, y2])
|
| 102 |
+
|
| 103 |
+
boxes = get_smoothened_boxes(np.array(results), T=5)
|
| 104 |
+
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
| 105 |
+
|
| 106 |
+
return results, images
|
| 107 |
+
|
| 108 |
+
def datagen(frames, face_det_results, mels):
|
| 109 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 110 |
+
|
| 111 |
+
for i, m in enumerate(mels):
|
| 112 |
+
if i >= len(frames): raise ValueError('Equal or less lengths only')
|
| 113 |
+
|
| 114 |
+
frame_to_save = frames[i].copy()
|
| 115 |
+
face, coords, valid_frame = face_det_results[i].copy()
|
| 116 |
+
if not valid_frame:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
face = cv2.resize(face, (args.img_size, args.img_size))
|
| 120 |
+
|
| 121 |
+
img_batch.append(face)
|
| 122 |
+
mel_batch.append(m)
|
| 123 |
+
frame_batch.append(frame_to_save)
|
| 124 |
+
coords_batch.append(coords)
|
| 125 |
+
|
| 126 |
+
if len(img_batch) >= args.wav2lip_batch_size:
|
| 127 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 128 |
+
|
| 129 |
+
img_masked = img_batch.copy()
|
| 130 |
+
img_masked[:, args.img_size//2:] = 0
|
| 131 |
+
|
| 132 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 133 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 134 |
+
|
| 135 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 136 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 137 |
+
|
| 138 |
+
if len(img_batch) > 0:
|
| 139 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 140 |
+
|
| 141 |
+
img_masked = img_batch.copy()
|
| 142 |
+
img_masked[:, args.img_size//2:] = 0
|
| 143 |
+
|
| 144 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 145 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 146 |
+
|
| 147 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 148 |
+
|
| 149 |
+
def increase_frames(frames, l):
|
| 150 |
+
## evenly duplicating frames to increase length of video
|
| 151 |
+
while len(frames) < l:
|
| 152 |
+
dup_every = float(l) / len(frames)
|
| 153 |
+
|
| 154 |
+
final_frames = []
|
| 155 |
+
next_duplicate = 0.
|
| 156 |
+
|
| 157 |
+
for i, f in enumerate(frames):
|
| 158 |
+
final_frames.append(f)
|
| 159 |
+
|
| 160 |
+
if int(np.ceil(next_duplicate)) == i:
|
| 161 |
+
final_frames.append(f)
|
| 162 |
+
|
| 163 |
+
next_duplicate += dup_every
|
| 164 |
+
|
| 165 |
+
frames = final_frames
|
| 166 |
+
|
| 167 |
+
return frames[:l]
|
| 168 |
+
|
| 169 |
+
mel_step_size = 16
|
| 170 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 171 |
+
print('Using {} for inference.'.format(device))
|
| 172 |
+
|
| 173 |
+
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
|
| 174 |
+
flip_input=False, device=device)
|
| 175 |
+
|
| 176 |
+
def _load(checkpoint_path):
|
| 177 |
+
if device == 'cuda':
|
| 178 |
+
checkpoint = torch.load(checkpoint_path)
|
| 179 |
+
else:
|
| 180 |
+
checkpoint = torch.load(checkpoint_path,
|
| 181 |
+
map_location=lambda storage, loc: storage)
|
| 182 |
+
return checkpoint
|
| 183 |
+
|
| 184 |
+
def load_model(path):
|
| 185 |
+
model = Wav2Lip()
|
| 186 |
+
print("Load checkpoint from: {}".format(path))
|
| 187 |
+
checkpoint = _load(path)
|
| 188 |
+
s = checkpoint["state_dict"]
|
| 189 |
+
new_s = {}
|
| 190 |
+
for k, v in s.items():
|
| 191 |
+
new_s[k.replace('module.', '')] = v
|
| 192 |
+
model.load_state_dict(new_s)
|
| 193 |
+
|
| 194 |
+
model = model.to(device)
|
| 195 |
+
return model.eval()
|
| 196 |
+
|
| 197 |
+
model = load_model(args.checkpoint_path)
|
| 198 |
+
|
| 199 |
+
def main():
|
| 200 |
+
if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
|
| 201 |
+
|
| 202 |
+
if args.mode == 'dubbed':
|
| 203 |
+
files = listdir(args.data_root)
|
| 204 |
+
lines = ['{} {}'.format(f, f) for f in files]
|
| 205 |
+
|
| 206 |
+
else:
|
| 207 |
+
assert args.filelist is not None
|
| 208 |
+
with open(args.filelist, 'r') as filelist:
|
| 209 |
+
lines = filelist.readlines()
|
| 210 |
+
|
| 211 |
+
for idx, line in enumerate(tqdm(lines)):
|
| 212 |
+
video, audio_src = line.strip().split()
|
| 213 |
+
|
| 214 |
+
audio_src = os.path.join(args.data_root, audio_src)
|
| 215 |
+
video = os.path.join(args.data_root, video)
|
| 216 |
+
|
| 217 |
+
command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
|
| 218 |
+
subprocess.call(command, shell=True)
|
| 219 |
+
temp_audio = '../temp/temp.wav'
|
| 220 |
+
|
| 221 |
+
wav = audio.load_wav(temp_audio, 16000)
|
| 222 |
+
mel = audio.melspectrogram(wav)
|
| 223 |
+
|
| 224 |
+
if np.isnan(mel.reshape(-1)).sum() > 0:
|
| 225 |
+
raise ValueError('Mel contains nan!')
|
| 226 |
+
|
| 227 |
+
video_stream = cv2.VideoCapture(video)
|
| 228 |
+
|
| 229 |
+
fps = video_stream.get(cv2.CAP_PROP_FPS)
|
| 230 |
+
mel_idx_multiplier = 80./fps
|
| 231 |
+
|
| 232 |
+
full_frames = []
|
| 233 |
+
while 1:
|
| 234 |
+
still_reading, frame = video_stream.read()
|
| 235 |
+
if not still_reading:
|
| 236 |
+
video_stream.release()
|
| 237 |
+
break
|
| 238 |
+
|
| 239 |
+
if min(frame.shape[:-1]) > args.max_frame_res:
|
| 240 |
+
h, w = frame.shape[:-1]
|
| 241 |
+
scale_factor = min(h, w) / float(args.max_frame_res)
|
| 242 |
+
h = int(h/scale_factor)
|
| 243 |
+
w = int(w/scale_factor)
|
| 244 |
+
|
| 245 |
+
frame = cv2.resize(frame, (w, h))
|
| 246 |
+
full_frames.append(frame)
|
| 247 |
+
|
| 248 |
+
mel_chunks = []
|
| 249 |
+
i = 0
|
| 250 |
+
while 1:
|
| 251 |
+
start_idx = int(i * mel_idx_multiplier)
|
| 252 |
+
if start_idx + mel_step_size > len(mel[0]):
|
| 253 |
+
break
|
| 254 |
+
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
|
| 255 |
+
i += 1
|
| 256 |
+
|
| 257 |
+
if len(full_frames) < len(mel_chunks):
|
| 258 |
+
if args.mode == 'tts':
|
| 259 |
+
full_frames = increase_frames(full_frames, len(mel_chunks))
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError('#Frames, audio length mismatch')
|
| 262 |
+
|
| 263 |
+
else:
|
| 264 |
+
full_frames = full_frames[:len(mel_chunks)]
|
| 265 |
+
|
| 266 |
+
try:
|
| 267 |
+
face_det_results, full_frames = face_detect(full_frames.copy())
|
| 268 |
+
except ValueError as e:
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
batch_size = args.wav2lip_batch_size
|
| 272 |
+
gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
|
| 273 |
+
|
| 274 |
+
for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
|
| 275 |
+
if i == 0:
|
| 276 |
+
frame_h, frame_w = full_frames[0].shape[:-1]
|
| 277 |
+
|
| 278 |
+
out = cv2.VideoWriter('../temp/result.avi',
|
| 279 |
+
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
|
| 280 |
+
|
| 281 |
+
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
| 282 |
+
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
| 283 |
+
|
| 284 |
+
with torch.no_grad():
|
| 285 |
+
pred = model(mel_batch, img_batch)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
|
| 289 |
+
|
| 290 |
+
for pl, f, c in zip(pred, frames, coords):
|
| 291 |
+
y1, y2, x1, x2 = c
|
| 292 |
+
pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
|
| 293 |
+
f[y1:y2, x1:x2] = pl
|
| 294 |
+
out.write(f)
|
| 295 |
+
|
| 296 |
+
out.release()
|
| 297 |
+
|
| 298 |
+
vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
|
| 299 |
+
command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav',
|
| 300 |
+
'../temp/result.avi', vid)
|
| 301 |
+
subprocess.call(command, shell=True)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
if __name__ == '__main__':
|
| 305 |
+
main()
|
evaluation/scores_LSE/SyncNetInstance_calc_scores.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
#-*- coding: utf-8 -*-
|
| 3 |
+
# Video 25 FPS, Audio 16000HZ
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy
|
| 7 |
+
import time, pdb, argparse, subprocess, os, math, glob
|
| 8 |
+
import cv2
|
| 9 |
+
import python_speech_features
|
| 10 |
+
|
| 11 |
+
from scipy import signal
|
| 12 |
+
from scipy.io import wavfile
|
| 13 |
+
from SyncNetModel import *
|
| 14 |
+
from shutil import rmtree
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ==================== Get OFFSET ====================
|
| 18 |
+
|
| 19 |
+
def calc_pdist(feat1, feat2, vshift=10):
|
| 20 |
+
|
| 21 |
+
win_size = vshift*2+1
|
| 22 |
+
|
| 23 |
+
feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))
|
| 24 |
+
|
| 25 |
+
dists = []
|
| 26 |
+
|
| 27 |
+
for i in range(0,len(feat1)):
|
| 28 |
+
|
| 29 |
+
dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))
|
| 30 |
+
|
| 31 |
+
return dists
|
| 32 |
+
|
| 33 |
+
# ==================== MAIN DEF ====================
|
| 34 |
+
|
| 35 |
+
class SyncNetInstance(torch.nn.Module):
|
| 36 |
+
|
| 37 |
+
def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
|
| 38 |
+
super(SyncNetInstance, self).__init__();
|
| 39 |
+
|
| 40 |
+
self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
|
| 41 |
+
|
| 42 |
+
def evaluate(self, opt, videofile):
|
| 43 |
+
|
| 44 |
+
self.__S__.eval();
|
| 45 |
+
|
| 46 |
+
# ========== ==========
|
| 47 |
+
# Convert files
|
| 48 |
+
# ========== ==========
|
| 49 |
+
|
| 50 |
+
if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)):
|
| 51 |
+
rmtree(os.path.join(opt.tmp_dir,opt.reference))
|
| 52 |
+
|
| 53 |
+
os.makedirs(os.path.join(opt.tmp_dir,opt.reference))
|
| 54 |
+
|
| 55 |
+
command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
|
| 56 |
+
output = subprocess.call(command, shell=True, stdout=None)
|
| 57 |
+
|
| 58 |
+
command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
|
| 59 |
+
output = subprocess.call(command, shell=True, stdout=None)
|
| 60 |
+
|
| 61 |
+
# ========== ==========
|
| 62 |
+
# Load video
|
| 63 |
+
# ========== ==========
|
| 64 |
+
|
| 65 |
+
images = []
|
| 66 |
+
|
| 67 |
+
flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
|
| 68 |
+
flist.sort()
|
| 69 |
+
|
| 70 |
+
for fname in flist:
|
| 71 |
+
img_input = cv2.imread(fname)
|
| 72 |
+
img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE
|
| 73 |
+
images.append(img_input)
|
| 74 |
+
|
| 75 |
+
im = numpy.stack(images,axis=3)
|
| 76 |
+
im = numpy.expand_dims(im,axis=0)
|
| 77 |
+
im = numpy.transpose(im,(0,3,4,1,2))
|
| 78 |
+
|
| 79 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
| 80 |
+
|
| 81 |
+
# ========== ==========
|
| 82 |
+
# Load audio
|
| 83 |
+
# ========== ==========
|
| 84 |
+
|
| 85 |
+
sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))
|
| 86 |
+
mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
|
| 87 |
+
mfcc = numpy.stack([numpy.array(i) for i in mfcc])
|
| 88 |
+
|
| 89 |
+
cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
|
| 90 |
+
cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
|
| 91 |
+
|
| 92 |
+
# ========== ==========
|
| 93 |
+
# Check audio and video input length
|
| 94 |
+
# ========== ==========
|
| 95 |
+
|
| 96 |
+
#if (float(len(audio))/16000) != (float(len(images))/25) :
|
| 97 |
+
# print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))
|
| 98 |
+
|
| 99 |
+
min_length = min(len(images),math.floor(len(audio)/640))
|
| 100 |
+
|
| 101 |
+
# ========== ==========
|
| 102 |
+
# Generate video and audio feats
|
| 103 |
+
# ========== ==========
|
| 104 |
+
|
| 105 |
+
lastframe = min_length-5
|
| 106 |
+
im_feat = []
|
| 107 |
+
cc_feat = []
|
| 108 |
+
|
| 109 |
+
tS = time.time()
|
| 110 |
+
for i in range(0,lastframe,opt.batch_size):
|
| 111 |
+
|
| 112 |
+
im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
|
| 113 |
+
im_in = torch.cat(im_batch,0)
|
| 114 |
+
im_out = self.__S__.forward_lip(im_in.cuda());
|
| 115 |
+
im_feat.append(im_out.data.cpu())
|
| 116 |
+
|
| 117 |
+
cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
|
| 118 |
+
cc_in = torch.cat(cc_batch,0)
|
| 119 |
+
cc_out = self.__S__.forward_aud(cc_in.cuda())
|
| 120 |
+
cc_feat.append(cc_out.data.cpu())
|
| 121 |
+
|
| 122 |
+
im_feat = torch.cat(im_feat,0)
|
| 123 |
+
cc_feat = torch.cat(cc_feat,0)
|
| 124 |
+
|
| 125 |
+
# ========== ==========
|
| 126 |
+
# Compute offset
|
| 127 |
+
# ========== ==========
|
| 128 |
+
|
| 129 |
+
#print('Compute time %.3f sec.' % (time.time()-tS))
|
| 130 |
+
|
| 131 |
+
dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
|
| 132 |
+
mdist = torch.mean(torch.stack(dists,1),1)
|
| 133 |
+
|
| 134 |
+
minval, minidx = torch.min(mdist,0)
|
| 135 |
+
|
| 136 |
+
offset = opt.vshift-minidx
|
| 137 |
+
conf = torch.median(mdist) - minval
|
| 138 |
+
|
| 139 |
+
fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
|
| 140 |
+
# fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
|
| 141 |
+
fconf = torch.median(mdist).numpy() - fdist
|
| 142 |
+
fconfm = signal.medfilt(fconf,kernel_size=9)
|
| 143 |
+
|
| 144 |
+
numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
|
| 145 |
+
#print('Framewise conf: ')
|
| 146 |
+
#print(fconfm)
|
| 147 |
+
#print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))
|
| 148 |
+
|
| 149 |
+
dists_npy = numpy.array([ dist.numpy() for dist in dists ])
|
| 150 |
+
return offset.numpy(), conf.numpy(), minval.numpy()
|
| 151 |
+
|
| 152 |
+
def extract_feature(self, opt, videofile):
|
| 153 |
+
|
| 154 |
+
self.__S__.eval();
|
| 155 |
+
|
| 156 |
+
# ========== ==========
|
| 157 |
+
# Load video
|
| 158 |
+
# ========== ==========
|
| 159 |
+
cap = cv2.VideoCapture(videofile)
|
| 160 |
+
|
| 161 |
+
frame_num = 1;
|
| 162 |
+
images = []
|
| 163 |
+
while frame_num:
|
| 164 |
+
frame_num += 1
|
| 165 |
+
ret, image = cap.read()
|
| 166 |
+
if ret == 0:
|
| 167 |
+
break
|
| 168 |
+
|
| 169 |
+
images.append(image)
|
| 170 |
+
|
| 171 |
+
im = numpy.stack(images,axis=3)
|
| 172 |
+
im = numpy.expand_dims(im,axis=0)
|
| 173 |
+
im = numpy.transpose(im,(0,3,4,1,2))
|
| 174 |
+
|
| 175 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
| 176 |
+
|
| 177 |
+
# ========== ==========
|
| 178 |
+
# Generate video feats
|
| 179 |
+
# ========== ==========
|
| 180 |
+
|
| 181 |
+
lastframe = len(images)-4
|
| 182 |
+
im_feat = []
|
| 183 |
+
|
| 184 |
+
tS = time.time()
|
| 185 |
+
for i in range(0,lastframe,opt.batch_size):
|
| 186 |
+
|
| 187 |
+
im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
|
| 188 |
+
im_in = torch.cat(im_batch,0)
|
| 189 |
+
im_out = self.__S__.forward_lipfeat(im_in.cuda());
|
| 190 |
+
im_feat.append(im_out.data.cpu())
|
| 191 |
+
|
| 192 |
+
im_feat = torch.cat(im_feat,0)
|
| 193 |
+
|
| 194 |
+
# ========== ==========
|
| 195 |
+
# Compute offset
|
| 196 |
+
# ========== ==========
|
| 197 |
+
|
| 198 |
+
print('Compute time %.3f sec.' % (time.time()-tS))
|
| 199 |
+
|
| 200 |
+
return im_feat
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def loadParameters(self, path):
|
| 204 |
+
loaded_state = torch.load(path, map_location=lambda storage, loc: storage);
|
| 205 |
+
|
| 206 |
+
self_state = self.__S__.state_dict();
|
| 207 |
+
|
| 208 |
+
for name, param in loaded_state.items():
|
| 209 |
+
|
| 210 |
+
self_state[name].copy_(param);
|
evaluation/scores_LSE/calculate_scores_LRS.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
#-*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import time, pdb, argparse, subprocess
|
| 5 |
+
import glob
|
| 6 |
+
import os
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from SyncNetInstance_calc_scores import *
|
| 10 |
+
|
| 11 |
+
# ==================== LOAD PARAMS ====================
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser(description = "SyncNet");
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
|
| 17 |
+
parser.add_argument('--batch_size', type=int, default='20', help='');
|
| 18 |
+
parser.add_argument('--vshift', type=int, default='15', help='');
|
| 19 |
+
parser.add_argument('--data_root', type=str, required=True, help='');
|
| 20 |
+
parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help='');
|
| 21 |
+
parser.add_argument('--reference', type=str, default="demo", help='');
|
| 22 |
+
|
| 23 |
+
opt = parser.parse_args();
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ==================== RUN EVALUATION ====================
|
| 27 |
+
|
| 28 |
+
s = SyncNetInstance();
|
| 29 |
+
|
| 30 |
+
s.loadParameters(opt.initial_model);
|
| 31 |
+
#print("Model %s loaded."%opt.initial_model);
|
| 32 |
+
path = os.path.join(opt.data_root, "*.mp4")
|
| 33 |
+
|
| 34 |
+
all_videos = glob.glob(path)
|
| 35 |
+
|
| 36 |
+
prog_bar = tqdm(range(len(all_videos)))
|
| 37 |
+
avg_confidence = 0.
|
| 38 |
+
avg_min_distance = 0.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
for videofile_idx in prog_bar:
|
| 42 |
+
videofile = all_videos[videofile_idx]
|
| 43 |
+
offset, confidence, min_distance = s.evaluate(opt, videofile=videofile)
|
| 44 |
+
avg_confidence += confidence
|
| 45 |
+
avg_min_distance += min_distance
|
| 46 |
+
prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3)))
|
| 47 |
+
prog_bar.refresh()
|
| 48 |
+
|
| 49 |
+
print ('Average Confidence: {}'.format(avg_confidence/len(all_videos)))
|
| 50 |
+
print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos)))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
evaluation/scores_LSE/calculate_scores_real_videos.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
#-*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import time, pdb, argparse, subprocess, pickle, os, gzip, glob
|
| 5 |
+
|
| 6 |
+
from SyncNetInstance_calc_scores import *
|
| 7 |
+
|
| 8 |
+
# ==================== PARSE ARGUMENT ====================
|
| 9 |
+
|
| 10 |
+
parser = argparse.ArgumentParser(description = "SyncNet");
|
| 11 |
+
parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
|
| 12 |
+
parser.add_argument('--batch_size', type=int, default='20', help='');
|
| 13 |
+
parser.add_argument('--vshift', type=int, default='15', help='');
|
| 14 |
+
parser.add_argument('--data_dir', type=str, default='data/work', help='');
|
| 15 |
+
parser.add_argument('--videofile', type=str, default='', help='');
|
| 16 |
+
parser.add_argument('--reference', type=str, default='', help='');
|
| 17 |
+
opt = parser.parse_args();
|
| 18 |
+
|
| 19 |
+
setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi'))
|
| 20 |
+
setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp'))
|
| 21 |
+
setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork'))
|
| 22 |
+
setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop'))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ==================== LOAD MODEL AND FILE LIST ====================
|
| 26 |
+
|
| 27 |
+
s = SyncNetInstance();
|
| 28 |
+
|
| 29 |
+
s.loadParameters(opt.initial_model);
|
| 30 |
+
#print("Model %s loaded."%opt.initial_model);
|
| 31 |
+
|
| 32 |
+
flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi'))
|
| 33 |
+
flist.sort()
|
| 34 |
+
|
| 35 |
+
# ==================== GET OFFSETS ====================
|
| 36 |
+
|
| 37 |
+
dists = []
|
| 38 |
+
for idx, fname in enumerate(flist):
|
| 39 |
+
offset, conf, dist = s.evaluate(opt,videofile=fname)
|
| 40 |
+
print (str(dist)+" "+str(conf))
|
| 41 |
+
|
| 42 |
+
# ==================== PRINT RESULTS TO FILE ====================
|
| 43 |
+
|
| 44 |
+
#with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil:
|
| 45 |
+
# pickle.dump(dists, fil)
|
evaluation/scores_LSE/calculate_scores_real_videos.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
rm all_scores.txt
|
| 2 |
+
yourfilenames=`ls $1`
|
| 3 |
+
|
| 4 |
+
for eachfile in $yourfilenames
|
| 5 |
+
do
|
| 6 |
+
python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir
|
| 7 |
+
python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt
|
| 8 |
+
done
|
evaluation/test_filelists/ReSyncED/random_pairs.txt
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sachin.mp4 emma_cropped.mp4
|
| 2 |
+
sachin.mp4 mourinho.mp4
|
| 3 |
+
sachin.mp4 elon.mp4
|
| 4 |
+
sachin.mp4 messi2.mp4
|
| 5 |
+
sachin.mp4 cr1.mp4
|
| 6 |
+
sachin.mp4 sachin.mp4
|
| 7 |
+
sachin.mp4 sg.mp4
|
| 8 |
+
sachin.mp4 fergi.mp4
|
| 9 |
+
sachin.mp4 spanish_lec1.mp4
|
| 10 |
+
sachin.mp4 bush_small.mp4
|
| 11 |
+
sachin.mp4 macca_cut.mp4
|
| 12 |
+
sachin.mp4 ca_cropped.mp4
|
| 13 |
+
sachin.mp4 lecun.mp4
|
| 14 |
+
sachin.mp4 spanish_lec0.mp4
|
| 15 |
+
srk.mp4 emma_cropped.mp4
|
| 16 |
+
srk.mp4 mourinho.mp4
|
| 17 |
+
srk.mp4 elon.mp4
|
| 18 |
+
srk.mp4 messi2.mp4
|
| 19 |
+
srk.mp4 cr1.mp4
|
| 20 |
+
srk.mp4 srk.mp4
|
| 21 |
+
srk.mp4 sachin.mp4
|
| 22 |
+
srk.mp4 sg.mp4
|
| 23 |
+
srk.mp4 fergi.mp4
|
| 24 |
+
srk.mp4 spanish_lec1.mp4
|
| 25 |
+
srk.mp4 bush_small.mp4
|
| 26 |
+
srk.mp4 macca_cut.mp4
|
| 27 |
+
srk.mp4 ca_cropped.mp4
|
| 28 |
+
srk.mp4 guardiola.mp4
|
| 29 |
+
srk.mp4 lecun.mp4
|
| 30 |
+
srk.mp4 spanish_lec0.mp4
|
| 31 |
+
cr1.mp4 emma_cropped.mp4
|
| 32 |
+
cr1.mp4 elon.mp4
|
| 33 |
+
cr1.mp4 messi2.mp4
|
| 34 |
+
cr1.mp4 cr1.mp4
|
| 35 |
+
cr1.mp4 spanish_lec1.mp4
|
| 36 |
+
cr1.mp4 bush_small.mp4
|
| 37 |
+
cr1.mp4 macca_cut.mp4
|
| 38 |
+
cr1.mp4 ca_cropped.mp4
|
| 39 |
+
cr1.mp4 lecun.mp4
|
| 40 |
+
cr1.mp4 spanish_lec0.mp4
|
| 41 |
+
macca_cut.mp4 emma_cropped.mp4
|
| 42 |
+
macca_cut.mp4 elon.mp4
|
| 43 |
+
macca_cut.mp4 messi2.mp4
|
| 44 |
+
macca_cut.mp4 spanish_lec1.mp4
|
| 45 |
+
macca_cut.mp4 macca_cut.mp4
|
| 46 |
+
macca_cut.mp4 ca_cropped.mp4
|
| 47 |
+
macca_cut.mp4 spanish_lec0.mp4
|
| 48 |
+
lecun.mp4 emma_cropped.mp4
|
| 49 |
+
lecun.mp4 elon.mp4
|
| 50 |
+
lecun.mp4 messi2.mp4
|
| 51 |
+
lecun.mp4 spanish_lec1.mp4
|
| 52 |
+
lecun.mp4 macca_cut.mp4
|
| 53 |
+
lecun.mp4 ca_cropped.mp4
|
| 54 |
+
lecun.mp4 lecun.mp4
|
| 55 |
+
lecun.mp4 spanish_lec0.mp4
|
| 56 |
+
messi2.mp4 emma_cropped.mp4
|
| 57 |
+
messi2.mp4 elon.mp4
|
| 58 |
+
messi2.mp4 messi2.mp4
|
| 59 |
+
messi2.mp4 spanish_lec1.mp4
|
| 60 |
+
messi2.mp4 macca_cut.mp4
|
| 61 |
+
messi2.mp4 ca_cropped.mp4
|
| 62 |
+
messi2.mp4 spanish_lec0.mp4
|
| 63 |
+
ca_cropped.mp4 emma_cropped.mp4
|
| 64 |
+
ca_cropped.mp4 elon.mp4
|
| 65 |
+
ca_cropped.mp4 spanish_lec1.mp4
|
| 66 |
+
ca_cropped.mp4 ca_cropped.mp4
|
| 67 |
+
ca_cropped.mp4 spanish_lec0.mp4
|
| 68 |
+
spanish_lec1.mp4 spanish_lec1.mp4
|
| 69 |
+
spanish_lec1.mp4 spanish_lec0.mp4
|
| 70 |
+
elon.mp4 elon.mp4
|
| 71 |
+
elon.mp4 spanish_lec1.mp4
|
| 72 |
+
elon.mp4 spanish_lec0.mp4
|
| 73 |
+
guardiola.mp4 emma_cropped.mp4
|
| 74 |
+
guardiola.mp4 mourinho.mp4
|
| 75 |
+
guardiola.mp4 elon.mp4
|
| 76 |
+
guardiola.mp4 messi2.mp4
|
| 77 |
+
guardiola.mp4 cr1.mp4
|
| 78 |
+
guardiola.mp4 sachin.mp4
|
| 79 |
+
guardiola.mp4 sg.mp4
|
| 80 |
+
guardiola.mp4 fergi.mp4
|
| 81 |
+
guardiola.mp4 spanish_lec1.mp4
|
| 82 |
+
guardiola.mp4 bush_small.mp4
|
| 83 |
+
guardiola.mp4 macca_cut.mp4
|
| 84 |
+
guardiola.mp4 ca_cropped.mp4
|
| 85 |
+
guardiola.mp4 guardiola.mp4
|
| 86 |
+
guardiola.mp4 lecun.mp4
|
| 87 |
+
guardiola.mp4 spanish_lec0.mp4
|
| 88 |
+
fergi.mp4 emma_cropped.mp4
|
| 89 |
+
fergi.mp4 mourinho.mp4
|
| 90 |
+
fergi.mp4 elon.mp4
|
| 91 |
+
fergi.mp4 messi2.mp4
|
| 92 |
+
fergi.mp4 cr1.mp4
|
| 93 |
+
fergi.mp4 sachin.mp4
|
| 94 |
+
fergi.mp4 sg.mp4
|
| 95 |
+
fergi.mp4 fergi.mp4
|
| 96 |
+
fergi.mp4 spanish_lec1.mp4
|
| 97 |
+
fergi.mp4 bush_small.mp4
|
| 98 |
+
fergi.mp4 macca_cut.mp4
|
| 99 |
+
fergi.mp4 ca_cropped.mp4
|
| 100 |
+
fergi.mp4 lecun.mp4
|
| 101 |
+
fergi.mp4 spanish_lec0.mp4
|
| 102 |
+
spanish.mp4 emma_cropped.mp4
|
| 103 |
+
spanish.mp4 spanish.mp4
|
| 104 |
+
spanish.mp4 mourinho.mp4
|
| 105 |
+
spanish.mp4 elon.mp4
|
| 106 |
+
spanish.mp4 messi2.mp4
|
| 107 |
+
spanish.mp4 cr1.mp4
|
| 108 |
+
spanish.mp4 srk.mp4
|
| 109 |
+
spanish.mp4 sachin.mp4
|
| 110 |
+
spanish.mp4 sg.mp4
|
| 111 |
+
spanish.mp4 fergi.mp4
|
| 112 |
+
spanish.mp4 spanish_lec1.mp4
|
| 113 |
+
spanish.mp4 bush_small.mp4
|
| 114 |
+
spanish.mp4 macca_cut.mp4
|
| 115 |
+
spanish.mp4 ca_cropped.mp4
|
| 116 |
+
spanish.mp4 guardiola.mp4
|
| 117 |
+
spanish.mp4 lecun.mp4
|
| 118 |
+
spanish.mp4 spanish_lec0.mp4
|
| 119 |
+
bush_small.mp4 emma_cropped.mp4
|
| 120 |
+
bush_small.mp4 elon.mp4
|
| 121 |
+
bush_small.mp4 messi2.mp4
|
| 122 |
+
bush_small.mp4 spanish_lec1.mp4
|
| 123 |
+
bush_small.mp4 bush_small.mp4
|
| 124 |
+
bush_small.mp4 macca_cut.mp4
|
| 125 |
+
bush_small.mp4 ca_cropped.mp4
|
| 126 |
+
bush_small.mp4 lecun.mp4
|
| 127 |
+
bush_small.mp4 spanish_lec0.mp4
|
| 128 |
+
emma_cropped.mp4 emma_cropped.mp4
|
| 129 |
+
emma_cropped.mp4 elon.mp4
|
| 130 |
+
emma_cropped.mp4 spanish_lec1.mp4
|
| 131 |
+
emma_cropped.mp4 spanish_lec0.mp4
|
| 132 |
+
sg.mp4 emma_cropped.mp4
|
| 133 |
+
sg.mp4 mourinho.mp4
|
| 134 |
+
sg.mp4 elon.mp4
|
| 135 |
+
sg.mp4 messi2.mp4
|
| 136 |
+
sg.mp4 cr1.mp4
|
| 137 |
+
sg.mp4 sachin.mp4
|
| 138 |
+
sg.mp4 sg.mp4
|
| 139 |
+
sg.mp4 fergi.mp4
|
| 140 |
+
sg.mp4 spanish_lec1.mp4
|
| 141 |
+
sg.mp4 bush_small.mp4
|
| 142 |
+
sg.mp4 macca_cut.mp4
|
| 143 |
+
sg.mp4 ca_cropped.mp4
|
| 144 |
+
sg.mp4 lecun.mp4
|
| 145 |
+
sg.mp4 spanish_lec0.mp4
|
| 146 |
+
spanish_lec0.mp4 spanish_lec0.mp4
|
| 147 |
+
mourinho.mp4 emma_cropped.mp4
|
| 148 |
+
mourinho.mp4 mourinho.mp4
|
| 149 |
+
mourinho.mp4 elon.mp4
|
| 150 |
+
mourinho.mp4 messi2.mp4
|
| 151 |
+
mourinho.mp4 cr1.mp4
|
| 152 |
+
mourinho.mp4 sachin.mp4
|
| 153 |
+
mourinho.mp4 sg.mp4
|
| 154 |
+
mourinho.mp4 fergi.mp4
|
| 155 |
+
mourinho.mp4 spanish_lec1.mp4
|
| 156 |
+
mourinho.mp4 bush_small.mp4
|
| 157 |
+
mourinho.mp4 macca_cut.mp4
|
| 158 |
+
mourinho.mp4 ca_cropped.mp4
|
| 159 |
+
mourinho.mp4 lecun.mp4
|
| 160 |
+
mourinho.mp4 spanish_lec0.mp4
|
evaluation/test_filelists/ReSyncED/tts_pairs.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
adam_1.mp4 andreng_optimization.wav
|
| 2 |
+
agad_2.mp4 agad_2.wav
|
| 3 |
+
agad_1.mp4 agad_1.wav
|
| 4 |
+
agad_3.mp4 agad_3.wav
|
| 5 |
+
rms_prop_1.mp4 rms_prop_tts.wav
|
| 6 |
+
tf_1.mp4 tf_1.wav
|
| 7 |
+
tf_2.mp4 tf_2.wav
|
| 8 |
+
andrew_ng_ai_business.mp4 andrewng_business_tts.wav
|
| 9 |
+
covid_autopsy_1.mp4 autopsy_tts.wav
|
| 10 |
+
news_1.mp4 news_tts.wav
|
| 11 |
+
andrew_ng_fund_1.mp4 andrewng_ai_fund.wav
|
| 12 |
+
covid_treatments_1.mp4 covid_tts.wav
|
| 13 |
+
pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav
|
| 14 |
+
pytorch_1.mp4 pytorch.wav
|
| 15 |
+
pkb_1.mp4 pkb_1.wav
|
| 16 |
+
ss_1.mp4 ss_1.wav
|
| 17 |
+
carlsen_1.mp4 carlsen_eng.wav
|
| 18 |
+
french.mp4 french.wav
|
evaluation/test_filelists/lrs2.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/test_filelists/lrs3.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/test_filelists/lrw.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
face_detection/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
__author__ = """Adrian Bulat"""
|
| 4 |
+
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
| 5 |
+
__version__ = '1.0.1'
|
| 6 |
+
|
| 7 |
+
from .api import FaceAlignment, LandmarksType, NetworkSize
|
face_detection/api.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.model_zoo import load_url
|
| 5 |
+
from enum import Enum
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
try:
|
| 9 |
+
import urllib.request as request_file
|
| 10 |
+
except BaseException:
|
| 11 |
+
import urllib as request_file
|
| 12 |
+
|
| 13 |
+
from .models import FAN, ResNetDepth
|
| 14 |
+
from .utils import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LandmarksType(Enum):
|
| 18 |
+
"""Enum class defining the type of landmarks to detect.
|
| 19 |
+
|
| 20 |
+
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
| 21 |
+
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
| 22 |
+
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
_2D = 1
|
| 26 |
+
_2halfD = 2
|
| 27 |
+
_3D = 3
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NetworkSize(Enum):
|
| 31 |
+
# TINY = 1
|
| 32 |
+
# SMALL = 2
|
| 33 |
+
# MEDIUM = 3
|
| 34 |
+
LARGE = 4
|
| 35 |
+
|
| 36 |
+
def __new__(cls, value):
|
| 37 |
+
member = object.__new__(cls)
|
| 38 |
+
member._value_ = value
|
| 39 |
+
return member
|
| 40 |
+
|
| 41 |
+
def __int__(self):
|
| 42 |
+
return self.value
|
| 43 |
+
|
| 44 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 45 |
+
|
| 46 |
+
class FaceAlignment:
|
| 47 |
+
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
| 48 |
+
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
| 49 |
+
self.device = device
|
| 50 |
+
self.flip_input = flip_input
|
| 51 |
+
self.landmarks_type = landmarks_type
|
| 52 |
+
self.verbose = verbose
|
| 53 |
+
|
| 54 |
+
network_size = int(network_size)
|
| 55 |
+
|
| 56 |
+
if 'cuda' in device:
|
| 57 |
+
torch.backends.cudnn.benchmark = True
|
| 58 |
+
|
| 59 |
+
# Get the face detector
|
| 60 |
+
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
| 61 |
+
globals(), locals(), [face_detector], 0)
|
| 62 |
+
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
| 63 |
+
|
| 64 |
+
def get_detections_for_batch(self, images):
|
| 65 |
+
images = images[..., ::-1]
|
| 66 |
+
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
| 67 |
+
results = []
|
| 68 |
+
|
| 69 |
+
for i, d in enumerate(detected_faces):
|
| 70 |
+
if len(d) == 0:
|
| 71 |
+
results.append(None)
|
| 72 |
+
continue
|
| 73 |
+
d = d[0]
|
| 74 |
+
d = np.clip(d, 0, None)
|
| 75 |
+
|
| 76 |
+
x1, y1, x2, y2 = map(int, d[:-1])
|
| 77 |
+
results.append((x1, y1, x2, y2))
|
| 78 |
+
|
| 79 |
+
return results
|
face_detection/detection/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .core import FaceDetector
|
face_detection/detection/core.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FaceDetector(object):
|
| 10 |
+
"""An abstract class representing a face detector.
|
| 11 |
+
|
| 12 |
+
Any other face detection implementation must subclass it. All subclasses
|
| 13 |
+
must implement ``detect_from_image``, that return a list of detected
|
| 14 |
+
bounding boxes. Optionally, for speed considerations detect from path is
|
| 15 |
+
recommended.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, device, verbose):
|
| 19 |
+
self.device = device
|
| 20 |
+
self.verbose = verbose
|
| 21 |
+
|
| 22 |
+
if verbose:
|
| 23 |
+
if 'cpu' in device:
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
logger.warning("Detection running on CPU, this may be potentially slow.")
|
| 26 |
+
|
| 27 |
+
if 'cpu' not in device and 'cuda' not in device:
|
| 28 |
+
if verbose:
|
| 29 |
+
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
| 30 |
+
raise ValueError
|
| 31 |
+
|
| 32 |
+
def detect_from_image(self, tensor_or_path):
|
| 33 |
+
"""Detects faces in a given image.
|
| 34 |
+
|
| 35 |
+
This function detects the faces present in a provided BGR(usually)
|
| 36 |
+
image. The input can be either the image itself or the path to it.
|
| 37 |
+
|
| 38 |
+
Arguments:
|
| 39 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
| 40 |
+
to an image or the image itself.
|
| 41 |
+
|
| 42 |
+
Example::
|
| 43 |
+
|
| 44 |
+
>>> path_to_image = 'data/image_01.jpg'
|
| 45 |
+
... detected_faces = detect_from_image(path_to_image)
|
| 46 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
| 47 |
+
>>> image = cv2.imread(path_to_image)
|
| 48 |
+
... detected_faces = detect_from_image(image)
|
| 49 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
| 50 |
+
|
| 51 |
+
"""
|
| 52 |
+
raise NotImplementedError
|
| 53 |
+
|
| 54 |
+
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
| 55 |
+
"""Detects faces from all the images present in a given directory.
|
| 56 |
+
|
| 57 |
+
Arguments:
|
| 58 |
+
path {string} -- a string containing a path that points to the folder containing the images
|
| 59 |
+
|
| 60 |
+
Keyword Arguments:
|
| 61 |
+
extensions {list} -- list of string containing the extensions to be
|
| 62 |
+
consider in the following format: ``.extension_name`` (default:
|
| 63 |
+
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
| 64 |
+
folder recursively (default: {False}) show_progress_bar {bool} --
|
| 65 |
+
display a progressbar (default: {True})
|
| 66 |
+
|
| 67 |
+
Example:
|
| 68 |
+
>>> directory = 'data'
|
| 69 |
+
... detected_faces = detect_from_directory(directory)
|
| 70 |
+
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
| 71 |
+
|
| 72 |
+
"""
|
| 73 |
+
if self.verbose:
|
| 74 |
+
logger = logging.getLogger(__name__)
|
| 75 |
+
|
| 76 |
+
if len(extensions) == 0:
|
| 77 |
+
if self.verbose:
|
| 78 |
+
logger.error("Expected at list one extension, but none was received.")
|
| 79 |
+
raise ValueError
|
| 80 |
+
|
| 81 |
+
if self.verbose:
|
| 82 |
+
logger.info("Constructing the list of images.")
|
| 83 |
+
additional_pattern = '/**/*' if recursive else '/*'
|
| 84 |
+
files = []
|
| 85 |
+
for extension in extensions:
|
| 86 |
+
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
| 87 |
+
|
| 88 |
+
if self.verbose:
|
| 89 |
+
logger.info("Finished searching for images. %s images found", len(files))
|
| 90 |
+
logger.info("Preparing to run the detection.")
|
| 91 |
+
|
| 92 |
+
predictions = {}
|
| 93 |
+
for image_path in tqdm(files, disable=not show_progress_bar):
|
| 94 |
+
if self.verbose:
|
| 95 |
+
logger.info("Running the face detector on image: %s", image_path)
|
| 96 |
+
predictions[image_path] = self.detect_from_image(image_path)
|
| 97 |
+
|
| 98 |
+
if self.verbose:
|
| 99 |
+
logger.info("The detector was successfully run on all %s images", len(files))
|
| 100 |
+
|
| 101 |
+
return predictions
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def reference_scale(self):
|
| 105 |
+
raise NotImplementedError
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def reference_x_shift(self):
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def reference_y_shift(self):
|
| 113 |
+
raise NotImplementedError
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
| 117 |
+
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
| 118 |
+
|
| 119 |
+
Arguments:
|
| 120 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
| 121 |
+
"""
|
| 122 |
+
if isinstance(tensor_or_path, str):
|
| 123 |
+
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
| 124 |
+
elif torch.is_tensor(tensor_or_path):
|
| 125 |
+
# Call cpu in case its coming from cuda
|
| 126 |
+
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
| 127 |
+
elif isinstance(tensor_or_path, np.ndarray):
|
| 128 |
+
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
| 129 |
+
else:
|
| 130 |
+
raise TypeError
|
face_detection/detection/sfd/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sfd_detector import SFDDetector as FaceDetector
|
face_detection/detection/sfd/bbox.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import cv2
|
| 5 |
+
import random
|
| 6 |
+
import datetime
|
| 7 |
+
import time
|
| 8 |
+
import math
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from iou import IOU
|
| 15 |
+
except BaseException:
|
| 16 |
+
# IOU cython speedup 10x
|
| 17 |
+
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
| 18 |
+
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
| 19 |
+
sb = abs((bx2 - bx1) * (by2 - by1))
|
| 20 |
+
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
| 21 |
+
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
| 22 |
+
w = x2 - x1
|
| 23 |
+
h = y2 - y1
|
| 24 |
+
if w < 0 or h < 0:
|
| 25 |
+
return 0.0
|
| 26 |
+
else:
|
| 27 |
+
return 1.0 * w * h / (sa + sb - w * h)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
| 31 |
+
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
| 32 |
+
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
| 33 |
+
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
| 34 |
+
return dx, dy, dw, dh
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
| 38 |
+
xc, yc = dx * aww + axc, dy * ahh + ayc
|
| 39 |
+
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
| 40 |
+
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
| 41 |
+
return x1, y1, x2, y2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def nms(dets, thresh):
|
| 45 |
+
if 0 == len(dets):
|
| 46 |
+
return []
|
| 47 |
+
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
| 48 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
| 49 |
+
order = scores.argsort()[::-1]
|
| 50 |
+
|
| 51 |
+
keep = []
|
| 52 |
+
while order.size > 0:
|
| 53 |
+
i = order[0]
|
| 54 |
+
keep.append(i)
|
| 55 |
+
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
| 56 |
+
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
| 57 |
+
|
| 58 |
+
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
| 59 |
+
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
| 60 |
+
|
| 61 |
+
inds = np.where(ovr <= thresh)[0]
|
| 62 |
+
order = order[inds + 1]
|
| 63 |
+
|
| 64 |
+
return keep
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def encode(matched, priors, variances):
|
| 68 |
+
"""Encode the variances from the priorbox layers into the ground truth boxes
|
| 69 |
+
we have matched (based on jaccard overlap) with the prior boxes.
|
| 70 |
+
Args:
|
| 71 |
+
matched: (tensor) Coords of ground truth for each prior in point-form
|
| 72 |
+
Shape: [num_priors, 4].
|
| 73 |
+
priors: (tensor) Prior boxes in center-offset form
|
| 74 |
+
Shape: [num_priors,4].
|
| 75 |
+
variances: (list[float]) Variances of priorboxes
|
| 76 |
+
Return:
|
| 77 |
+
encoded boxes (tensor), Shape: [num_priors, 4]
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
# dist b/t match center and prior's center
|
| 81 |
+
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
| 82 |
+
# encode variance
|
| 83 |
+
g_cxcy /= (variances[0] * priors[:, 2:])
|
| 84 |
+
# match wh / prior wh
|
| 85 |
+
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
| 86 |
+
g_wh = torch.log(g_wh) / variances[1]
|
| 87 |
+
# return target for smooth_l1_loss
|
| 88 |
+
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def decode(loc, priors, variances):
|
| 92 |
+
"""Decode locations from predictions using priors to undo
|
| 93 |
+
the encoding we did for offset regression at train time.
|
| 94 |
+
Args:
|
| 95 |
+
loc (tensor): location predictions for loc layers,
|
| 96 |
+
Shape: [num_priors,4]
|
| 97 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 98 |
+
Shape: [num_priors,4].
|
| 99 |
+
variances: (list[float]) Variances of priorboxes
|
| 100 |
+
Return:
|
| 101 |
+
decoded bounding box predictions
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
boxes = torch.cat((
|
| 105 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
| 106 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
| 107 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
| 108 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 109 |
+
return boxes
|
| 110 |
+
|
| 111 |
+
def batch_decode(loc, priors, variances):
|
| 112 |
+
"""Decode locations from predictions using priors to undo
|
| 113 |
+
the encoding we did for offset regression at train time.
|
| 114 |
+
Args:
|
| 115 |
+
loc (tensor): location predictions for loc layers,
|
| 116 |
+
Shape: [num_priors,4]
|
| 117 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 118 |
+
Shape: [num_priors,4].
|
| 119 |
+
variances: (list[float]) Variances of priorboxes
|
| 120 |
+
Return:
|
| 121 |
+
decoded bounding box predictions
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
boxes = torch.cat((
|
| 125 |
+
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
| 126 |
+
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
| 127 |
+
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
| 128 |
+
boxes[:, :, 2:] += boxes[:, :, :2]
|
| 129 |
+
return boxes
|
face_detection/detection/sfd/detect.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import cv2
|
| 7 |
+
import random
|
| 8 |
+
import datetime
|
| 9 |
+
import math
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import scipy.io as sio
|
| 14 |
+
import zipfile
|
| 15 |
+
from .net_s3fd import s3fd
|
| 16 |
+
from .bbox import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def detect(net, img, device):
|
| 20 |
+
img = img - np.array([104, 117, 123])
|
| 21 |
+
img = img.transpose(2, 0, 1)
|
| 22 |
+
img = img.reshape((1,) + img.shape)
|
| 23 |
+
|
| 24 |
+
if 'cuda' in device:
|
| 25 |
+
torch.backends.cudnn.benchmark = True
|
| 26 |
+
|
| 27 |
+
img = torch.from_numpy(img).float().to(device)
|
| 28 |
+
BB, CC, HH, WW = img.size()
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
olist = net(img)
|
| 31 |
+
|
| 32 |
+
bboxlist = []
|
| 33 |
+
for i in range(len(olist) // 2):
|
| 34 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
| 35 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
| 36 |
+
for i in range(len(olist) // 2):
|
| 37 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
| 38 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
| 39 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
| 40 |
+
anchor = stride * 4
|
| 41 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
| 42 |
+
for Iindex, hindex, windex in poss:
|
| 43 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
| 44 |
+
score = ocls[0, 1, hindex, windex]
|
| 45 |
+
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
| 46 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
| 47 |
+
variances = [0.1, 0.2]
|
| 48 |
+
box = decode(loc, priors, variances)
|
| 49 |
+
x1, y1, x2, y2 = box[0] * 1.0
|
| 50 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
| 51 |
+
bboxlist.append([x1, y1, x2, y2, score])
|
| 52 |
+
bboxlist = np.array(bboxlist)
|
| 53 |
+
if 0 == len(bboxlist):
|
| 54 |
+
bboxlist = np.zeros((1, 5))
|
| 55 |
+
|
| 56 |
+
return bboxlist
|
| 57 |
+
|
| 58 |
+
def batch_detect(net, imgs, device):
|
| 59 |
+
imgs = imgs - np.array([104, 117, 123])
|
| 60 |
+
imgs = imgs.transpose(0, 3, 1, 2)
|
| 61 |
+
|
| 62 |
+
if 'cuda' in device:
|
| 63 |
+
torch.backends.cudnn.benchmark = True
|
| 64 |
+
|
| 65 |
+
imgs = torch.from_numpy(imgs).float().to(device)
|
| 66 |
+
BB, CC, HH, WW = imgs.size()
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
olist = net(imgs)
|
| 69 |
+
|
| 70 |
+
bboxlist = []
|
| 71 |
+
for i in range(len(olist) // 2):
|
| 72 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
| 73 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
| 74 |
+
for i in range(len(olist) // 2):
|
| 75 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
| 76 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
| 77 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
| 78 |
+
anchor = stride * 4
|
| 79 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
| 80 |
+
for Iindex, hindex, windex in poss:
|
| 81 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
| 82 |
+
score = ocls[:, 1, hindex, windex]
|
| 83 |
+
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
| 84 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
| 85 |
+
variances = [0.1, 0.2]
|
| 86 |
+
box = batch_decode(loc, priors, variances)
|
| 87 |
+
box = box[:, 0] * 1.0
|
| 88 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
| 89 |
+
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
| 90 |
+
bboxlist = np.array(bboxlist)
|
| 91 |
+
if 0 == len(bboxlist):
|
| 92 |
+
bboxlist = np.zeros((1, BB, 5))
|
| 93 |
+
|
| 94 |
+
return bboxlist
|
| 95 |
+
|
| 96 |
+
def flip_detect(net, img, device):
|
| 97 |
+
img = cv2.flip(img, 1)
|
| 98 |
+
b = detect(net, img, device)
|
| 99 |
+
|
| 100 |
+
bboxlist = np.zeros(b.shape)
|
| 101 |
+
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
| 102 |
+
bboxlist[:, 1] = b[:, 1]
|
| 103 |
+
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
| 104 |
+
bboxlist[:, 3] = b[:, 3]
|
| 105 |
+
bboxlist[:, 4] = b[:, 4]
|
| 106 |
+
return bboxlist
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def pts_to_bb(pts):
|
| 110 |
+
min_x, min_y = np.min(pts, axis=0)
|
| 111 |
+
max_x, max_y = np.max(pts, axis=0)
|
| 112 |
+
return np.array([min_x, min_y, max_x, max_y])
|
face_detection/detection/sfd/net_s3fd.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class L2Norm(nn.Module):
|
| 7 |
+
def __init__(self, n_channels, scale=1.0):
|
| 8 |
+
super(L2Norm, self).__init__()
|
| 9 |
+
self.n_channels = n_channels
|
| 10 |
+
self.scale = scale
|
| 11 |
+
self.eps = 1e-10
|
| 12 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
| 13 |
+
self.weight.data *= 0.0
|
| 14 |
+
self.weight.data += self.scale
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
| 18 |
+
x = x / norm * self.weight.view(1, -1, 1, 1)
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class s3fd(nn.Module):
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super(s3fd, self).__init__()
|
| 25 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
| 26 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
| 27 |
+
|
| 28 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
| 29 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
| 30 |
+
|
| 31 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
| 32 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 33 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 34 |
+
|
| 35 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
| 36 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 37 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 38 |
+
|
| 39 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 40 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 42 |
+
|
| 43 |
+
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
| 44 |
+
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
| 45 |
+
|
| 46 |
+
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
| 47 |
+
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
| 48 |
+
|
| 49 |
+
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
| 50 |
+
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
| 51 |
+
|
| 52 |
+
self.conv3_3_norm = L2Norm(256, scale=10)
|
| 53 |
+
self.conv4_3_norm = L2Norm(512, scale=8)
|
| 54 |
+
self.conv5_3_norm = L2Norm(512, scale=5)
|
| 55 |
+
|
| 56 |
+
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 57 |
+
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 58 |
+
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 59 |
+
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 60 |
+
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 61 |
+
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 62 |
+
|
| 63 |
+
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
| 64 |
+
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
| 65 |
+
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 66 |
+
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 67 |
+
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
| 68 |
+
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
h = F.relu(self.conv1_1(x))
|
| 72 |
+
h = F.relu(self.conv1_2(h))
|
| 73 |
+
h = F.max_pool2d(h, 2, 2)
|
| 74 |
+
|
| 75 |
+
h = F.relu(self.conv2_1(h))
|
| 76 |
+
h = F.relu(self.conv2_2(h))
|
| 77 |
+
h = F.max_pool2d(h, 2, 2)
|
| 78 |
+
|
| 79 |
+
h = F.relu(self.conv3_1(h))
|
| 80 |
+
h = F.relu(self.conv3_2(h))
|
| 81 |
+
h = F.relu(self.conv3_3(h))
|
| 82 |
+
f3_3 = h
|
| 83 |
+
h = F.max_pool2d(h, 2, 2)
|
| 84 |
+
|
| 85 |
+
h = F.relu(self.conv4_1(h))
|
| 86 |
+
h = F.relu(self.conv4_2(h))
|
| 87 |
+
h = F.relu(self.conv4_3(h))
|
| 88 |
+
f4_3 = h
|
| 89 |
+
h = F.max_pool2d(h, 2, 2)
|
| 90 |
+
|
| 91 |
+
h = F.relu(self.conv5_1(h))
|
| 92 |
+
h = F.relu(self.conv5_2(h))
|
| 93 |
+
h = F.relu(self.conv5_3(h))
|
| 94 |
+
f5_3 = h
|
| 95 |
+
h = F.max_pool2d(h, 2, 2)
|
| 96 |
+
|
| 97 |
+
h = F.relu(self.fc6(h))
|
| 98 |
+
h = F.relu(self.fc7(h))
|
| 99 |
+
ffc7 = h
|
| 100 |
+
h = F.relu(self.conv6_1(h))
|
| 101 |
+
h = F.relu(self.conv6_2(h))
|
| 102 |
+
f6_2 = h
|
| 103 |
+
h = F.relu(self.conv7_1(h))
|
| 104 |
+
h = F.relu(self.conv7_2(h))
|
| 105 |
+
f7_2 = h
|
| 106 |
+
|
| 107 |
+
f3_3 = self.conv3_3_norm(f3_3)
|
| 108 |
+
f4_3 = self.conv4_3_norm(f4_3)
|
| 109 |
+
f5_3 = self.conv5_3_norm(f5_3)
|
| 110 |
+
|
| 111 |
+
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
| 112 |
+
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
| 113 |
+
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
| 114 |
+
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
| 115 |
+
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
| 116 |
+
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
| 117 |
+
cls4 = self.fc7_mbox_conf(ffc7)
|
| 118 |
+
reg4 = self.fc7_mbox_loc(ffc7)
|
| 119 |
+
cls5 = self.conv6_2_mbox_conf(f6_2)
|
| 120 |
+
reg5 = self.conv6_2_mbox_loc(f6_2)
|
| 121 |
+
cls6 = self.conv7_2_mbox_conf(f7_2)
|
| 122 |
+
reg6 = self.conv7_2_mbox_loc(f7_2)
|
| 123 |
+
|
| 124 |
+
# max-out background label
|
| 125 |
+
chunk = torch.chunk(cls1, 4, 1)
|
| 126 |
+
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
| 127 |
+
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
| 128 |
+
|
| 129 |
+
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
face_detection/detection/sfd/s3fd.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7636d0c9d2a8f4759aef537cbcc25c5fa2eb2d5d80b1fada4dcc800e967cf381
|
| 3 |
+
size 133
|
face_detection/detection/sfd/sfd_detector.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
from torch.utils.model_zoo import load_url
|
| 4 |
+
|
| 5 |
+
from ..core import FaceDetector
|
| 6 |
+
|
| 7 |
+
from .net_s3fd import s3fd
|
| 8 |
+
from .bbox import *
|
| 9 |
+
from .detect import *
|
| 10 |
+
|
| 11 |
+
models_urls = {
|
| 12 |
+
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SFDDetector(FaceDetector):
|
| 17 |
+
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
| 18 |
+
super(SFDDetector, self).__init__(device, verbose)
|
| 19 |
+
|
| 20 |
+
# Initialise the face detector
|
| 21 |
+
if not os.path.isfile(path_to_detector):
|
| 22 |
+
model_weights = load_url(models_urls['s3fd'])
|
| 23 |
+
else:
|
| 24 |
+
model_weights = torch.load(path_to_detector)
|
| 25 |
+
|
| 26 |
+
self.face_detector = s3fd()
|
| 27 |
+
self.face_detector.load_state_dict(model_weights)
|
| 28 |
+
self.face_detector.to(device)
|
| 29 |
+
self.face_detector.eval()
|
| 30 |
+
|
| 31 |
+
def detect_from_image(self, tensor_or_path):
|
| 32 |
+
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
| 33 |
+
|
| 34 |
+
bboxlist = detect(self.face_detector, image, device=self.device)
|
| 35 |
+
keep = nms(bboxlist, 0.3)
|
| 36 |
+
bboxlist = bboxlist[keep, :]
|
| 37 |
+
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
| 38 |
+
|
| 39 |
+
return bboxlist
|
| 40 |
+
|
| 41 |
+
def detect_from_batch(self, images):
|
| 42 |
+
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
| 43 |
+
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
| 44 |
+
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
| 45 |
+
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
| 46 |
+
|
| 47 |
+
return bboxlists
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def reference_scale(self):
|
| 51 |
+
return 195
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def reference_x_shift(self):
|
| 55 |
+
return 0
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def reference_y_shift(self):
|
| 59 |
+
return 0
|
face_detection/models.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
| 8 |
+
"3x3 convolution with padding"
|
| 9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
| 10 |
+
stride=strd, padding=padding, bias=bias)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConvBlock(nn.Module):
|
| 14 |
+
def __init__(self, in_planes, out_planes):
|
| 15 |
+
super(ConvBlock, self).__init__()
|
| 16 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
| 17 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
| 18 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
| 19 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
| 20 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
| 21 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
| 22 |
+
|
| 23 |
+
if in_planes != out_planes:
|
| 24 |
+
self.downsample = nn.Sequential(
|
| 25 |
+
nn.BatchNorm2d(in_planes),
|
| 26 |
+
nn.ReLU(True),
|
| 27 |
+
nn.Conv2d(in_planes, out_planes,
|
| 28 |
+
kernel_size=1, stride=1, bias=False),
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
self.downsample = None
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
residual = x
|
| 35 |
+
|
| 36 |
+
out1 = self.bn1(x)
|
| 37 |
+
out1 = F.relu(out1, True)
|
| 38 |
+
out1 = self.conv1(out1)
|
| 39 |
+
|
| 40 |
+
out2 = self.bn2(out1)
|
| 41 |
+
out2 = F.relu(out2, True)
|
| 42 |
+
out2 = self.conv2(out2)
|
| 43 |
+
|
| 44 |
+
out3 = self.bn3(out2)
|
| 45 |
+
out3 = F.relu(out3, True)
|
| 46 |
+
out3 = self.conv3(out3)
|
| 47 |
+
|
| 48 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
| 49 |
+
|
| 50 |
+
if self.downsample is not None:
|
| 51 |
+
residual = self.downsample(residual)
|
| 52 |
+
|
| 53 |
+
out3 += residual
|
| 54 |
+
|
| 55 |
+
return out3
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Bottleneck(nn.Module):
|
| 59 |
+
|
| 60 |
+
expansion = 4
|
| 61 |
+
|
| 62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 63 |
+
super(Bottleneck, self).__init__()
|
| 64 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 65 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 66 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 67 |
+
padding=1, bias=False)
|
| 68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 69 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
| 70 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 71 |
+
self.relu = nn.ReLU(inplace=True)
|
| 72 |
+
self.downsample = downsample
|
| 73 |
+
self.stride = stride
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
residual = x
|
| 77 |
+
|
| 78 |
+
out = self.conv1(x)
|
| 79 |
+
out = self.bn1(out)
|
| 80 |
+
out = self.relu(out)
|
| 81 |
+
|
| 82 |
+
out = self.conv2(out)
|
| 83 |
+
out = self.bn2(out)
|
| 84 |
+
out = self.relu(out)
|
| 85 |
+
|
| 86 |
+
out = self.conv3(out)
|
| 87 |
+
out = self.bn3(out)
|
| 88 |
+
|
| 89 |
+
if self.downsample is not None:
|
| 90 |
+
residual = self.downsample(x)
|
| 91 |
+
|
| 92 |
+
out += residual
|
| 93 |
+
out = self.relu(out)
|
| 94 |
+
|
| 95 |
+
return out
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class HourGlass(nn.Module):
|
| 99 |
+
def __init__(self, num_modules, depth, num_features):
|
| 100 |
+
super(HourGlass, self).__init__()
|
| 101 |
+
self.num_modules = num_modules
|
| 102 |
+
self.depth = depth
|
| 103 |
+
self.features = num_features
|
| 104 |
+
|
| 105 |
+
self._generate_network(self.depth)
|
| 106 |
+
|
| 107 |
+
def _generate_network(self, level):
|
| 108 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
| 109 |
+
|
| 110 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
| 111 |
+
|
| 112 |
+
if level > 1:
|
| 113 |
+
self._generate_network(level - 1)
|
| 114 |
+
else:
|
| 115 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
| 116 |
+
|
| 117 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
| 118 |
+
|
| 119 |
+
def _forward(self, level, inp):
|
| 120 |
+
# Upper branch
|
| 121 |
+
up1 = inp
|
| 122 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
| 123 |
+
|
| 124 |
+
# Lower branch
|
| 125 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
| 126 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
| 127 |
+
|
| 128 |
+
if level > 1:
|
| 129 |
+
low2 = self._forward(level - 1, low1)
|
| 130 |
+
else:
|
| 131 |
+
low2 = low1
|
| 132 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
| 133 |
+
|
| 134 |
+
low3 = low2
|
| 135 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
| 136 |
+
|
| 137 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
| 138 |
+
|
| 139 |
+
return up1 + up2
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
return self._forward(self.depth, x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class FAN(nn.Module):
|
| 146 |
+
|
| 147 |
+
def __init__(self, num_modules=1):
|
| 148 |
+
super(FAN, self).__init__()
|
| 149 |
+
self.num_modules = num_modules
|
| 150 |
+
|
| 151 |
+
# Base part
|
| 152 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
| 153 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 154 |
+
self.conv2 = ConvBlock(64, 128)
|
| 155 |
+
self.conv3 = ConvBlock(128, 128)
|
| 156 |
+
self.conv4 = ConvBlock(128, 256)
|
| 157 |
+
|
| 158 |
+
# Stacking part
|
| 159 |
+
for hg_module in range(self.num_modules):
|
| 160 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
| 161 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
| 162 |
+
self.add_module('conv_last' + str(hg_module),
|
| 163 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
| 164 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
| 165 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
| 166 |
+
68, kernel_size=1, stride=1, padding=0))
|
| 167 |
+
|
| 168 |
+
if hg_module < self.num_modules - 1:
|
| 169 |
+
self.add_module(
|
| 170 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
| 171 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
| 172 |
+
256, kernel_size=1, stride=1, padding=0))
|
| 173 |
+
|
| 174 |
+
def forward(self, x):
|
| 175 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
| 176 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
| 177 |
+
x = self.conv3(x)
|
| 178 |
+
x = self.conv4(x)
|
| 179 |
+
|
| 180 |
+
previous = x
|
| 181 |
+
|
| 182 |
+
outputs = []
|
| 183 |
+
for i in range(self.num_modules):
|
| 184 |
+
hg = self._modules['m' + str(i)](previous)
|
| 185 |
+
|
| 186 |
+
ll = hg
|
| 187 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
| 188 |
+
|
| 189 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
| 190 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
| 191 |
+
|
| 192 |
+
# Predict heatmaps
|
| 193 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
| 194 |
+
outputs.append(tmp_out)
|
| 195 |
+
|
| 196 |
+
if i < self.num_modules - 1:
|
| 197 |
+
ll = self._modules['bl' + str(i)](ll)
|
| 198 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
| 199 |
+
previous = previous + ll + tmp_out_
|
| 200 |
+
|
| 201 |
+
return outputs
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class ResNetDepth(nn.Module):
|
| 205 |
+
|
| 206 |
+
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
| 207 |
+
self.inplanes = 64
|
| 208 |
+
super(ResNetDepth, self).__init__()
|
| 209 |
+
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
| 210 |
+
bias=False)
|
| 211 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 212 |
+
self.relu = nn.ReLU(inplace=True)
|
| 213 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 214 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 215 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 216 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 217 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 218 |
+
self.avgpool = nn.AvgPool2d(7)
|
| 219 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 220 |
+
|
| 221 |
+
for m in self.modules():
|
| 222 |
+
if isinstance(m, nn.Conv2d):
|
| 223 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 224 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 225 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 226 |
+
m.weight.data.fill_(1)
|
| 227 |
+
m.bias.data.zero_()
|
| 228 |
+
|
| 229 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 230 |
+
downsample = None
|
| 231 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 232 |
+
downsample = nn.Sequential(
|
| 233 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 234 |
+
kernel_size=1, stride=stride, bias=False),
|
| 235 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
layers = []
|
| 239 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 240 |
+
self.inplanes = planes * block.expansion
|
| 241 |
+
for i in range(1, blocks):
|
| 242 |
+
layers.append(block(self.inplanes, planes))
|
| 243 |
+
|
| 244 |
+
return nn.Sequential(*layers)
|
| 245 |
+
|
| 246 |
+
def forward(self, x):
|
| 247 |
+
x = self.conv1(x)
|
| 248 |
+
x = self.bn1(x)
|
| 249 |
+
x = self.relu(x)
|
| 250 |
+
x = self.maxpool(x)
|
| 251 |
+
|
| 252 |
+
x = self.layer1(x)
|
| 253 |
+
x = self.layer2(x)
|
| 254 |
+
x = self.layer3(x)
|
| 255 |
+
x = self.layer4(x)
|
| 256 |
+
|
| 257 |
+
x = self.avgpool(x)
|
| 258 |
+
x = x.view(x.size(0), -1)
|
| 259 |
+
x = self.fc(x)
|
| 260 |
+
|
| 261 |
+
return x
|
face_detection/utils.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _gaussian(
|
| 12 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
| 13 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
| 14 |
+
mean_vert=0.5):
|
| 15 |
+
# handle some defaults
|
| 16 |
+
if width is None:
|
| 17 |
+
width = size
|
| 18 |
+
if height is None:
|
| 19 |
+
height = size
|
| 20 |
+
if sigma_horz is None:
|
| 21 |
+
sigma_horz = sigma
|
| 22 |
+
if sigma_vert is None:
|
| 23 |
+
sigma_vert = sigma
|
| 24 |
+
center_x = mean_horz * width + 0.5
|
| 25 |
+
center_y = mean_vert * height + 0.5
|
| 26 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
| 27 |
+
# generate kernel
|
| 28 |
+
for i in range(height):
|
| 29 |
+
for j in range(width):
|
| 30 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
| 31 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
| 32 |
+
if normalize:
|
| 33 |
+
gauss = gauss / np.sum(gauss)
|
| 34 |
+
return gauss
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def draw_gaussian(image, point, sigma):
|
| 38 |
+
# Check if the gaussian is inside
|
| 39 |
+
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
| 40 |
+
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
| 41 |
+
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
| 42 |
+
return image
|
| 43 |
+
size = 6 * sigma + 1
|
| 44 |
+
g = _gaussian(size)
|
| 45 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
| 46 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
| 47 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
| 48 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
| 49 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
| 50 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
| 51 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
| 52 |
+
image[image > 1] = 1
|
| 53 |
+
return image
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def transform(point, center, scale, resolution, invert=False):
|
| 57 |
+
"""Generate and affine transformation matrix.
|
| 58 |
+
|
| 59 |
+
Given a set of points, a center, a scale and a targer resolution, the
|
| 60 |
+
function generates and affine transformation matrix. If invert is ``True``
|
| 61 |
+
it will produce the inverse transformation.
|
| 62 |
+
|
| 63 |
+
Arguments:
|
| 64 |
+
point {torch.tensor} -- the input 2D point
|
| 65 |
+
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
| 66 |
+
scale {float} -- the scale of the face/object
|
| 67 |
+
resolution {float} -- the output resolution
|
| 68 |
+
|
| 69 |
+
Keyword Arguments:
|
| 70 |
+
invert {bool} -- define wherever the function should produce the direct or the
|
| 71 |
+
inverse transformation matrix (default: {False})
|
| 72 |
+
"""
|
| 73 |
+
_pt = torch.ones(3)
|
| 74 |
+
_pt[0] = point[0]
|
| 75 |
+
_pt[1] = point[1]
|
| 76 |
+
|
| 77 |
+
h = 200.0 * scale
|
| 78 |
+
t = torch.eye(3)
|
| 79 |
+
t[0, 0] = resolution / h
|
| 80 |
+
t[1, 1] = resolution / h
|
| 81 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
| 82 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
| 83 |
+
|
| 84 |
+
if invert:
|
| 85 |
+
t = torch.inverse(t)
|
| 86 |
+
|
| 87 |
+
new_point = (torch.matmul(t, _pt))[0:2]
|
| 88 |
+
|
| 89 |
+
return new_point.int()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def crop(image, center, scale, resolution=256.0):
|
| 93 |
+
"""Center crops an image or set of heatmaps
|
| 94 |
+
|
| 95 |
+
Arguments:
|
| 96 |
+
image {numpy.array} -- an rgb image
|
| 97 |
+
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
| 98 |
+
scale {float} -- scale of the face
|
| 99 |
+
|
| 100 |
+
Keyword Arguments:
|
| 101 |
+
resolution {float} -- the size of the output cropped image (default: {256.0})
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
[type] -- [description]
|
| 105 |
+
""" # Crop around the center point
|
| 106 |
+
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
| 107 |
+
ul = transform([1, 1], center, scale, resolution, True)
|
| 108 |
+
br = transform([resolution, resolution], center, scale, resolution, True)
|
| 109 |
+
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
| 110 |
+
if image.ndim > 2:
|
| 111 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
| 112 |
+
image.shape[2]], dtype=np.int32)
|
| 113 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
| 114 |
+
else:
|
| 115 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
| 116 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
| 117 |
+
ht = image.shape[0]
|
| 118 |
+
wd = image.shape[1]
|
| 119 |
+
newX = np.array(
|
| 120 |
+
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
| 121 |
+
newY = np.array(
|
| 122 |
+
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
| 123 |
+
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
| 124 |
+
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
| 125 |
+
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
| 126 |
+
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
| 127 |
+
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
| 128 |
+
interpolation=cv2.INTER_LINEAR)
|
| 129 |
+
return newImg
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_preds_fromhm(hm, center=None, scale=None):
|
| 133 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
| 134 |
+
and the scale is provided the function will return the points also in
|
| 135 |
+
the original coordinate frame.
|
| 136 |
+
|
| 137 |
+
Arguments:
|
| 138 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
| 139 |
+
|
| 140 |
+
Keyword Arguments:
|
| 141 |
+
center {torch.tensor} -- the center of the bounding box (default: {None})
|
| 142 |
+
scale {float} -- face scale (default: {None})
|
| 143 |
+
"""
|
| 144 |
+
max, idx = torch.max(
|
| 145 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
| 146 |
+
idx += 1
|
| 147 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
| 148 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
| 149 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
| 150 |
+
|
| 151 |
+
for i in range(preds.size(0)):
|
| 152 |
+
for j in range(preds.size(1)):
|
| 153 |
+
hm_ = hm[i, j, :]
|
| 154 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
| 155 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
| 156 |
+
diff = torch.FloatTensor(
|
| 157 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
| 158 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
| 159 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
| 160 |
+
|
| 161 |
+
preds.add_(-.5)
|
| 162 |
+
|
| 163 |
+
preds_orig = torch.zeros(preds.size())
|
| 164 |
+
if center is not None and scale is not None:
|
| 165 |
+
for i in range(hm.size(0)):
|
| 166 |
+
for j in range(hm.size(1)):
|
| 167 |
+
preds_orig[i, j] = transform(
|
| 168 |
+
preds[i, j], center, scale, hm.size(2), True)
|
| 169 |
+
|
| 170 |
+
return preds, preds_orig
|
| 171 |
+
|
| 172 |
+
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
| 173 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
| 174 |
+
and the scales is provided the function will return the points also in
|
| 175 |
+
the original coordinate frame.
|
| 176 |
+
|
| 177 |
+
Arguments:
|
| 178 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
| 179 |
+
|
| 180 |
+
Keyword Arguments:
|
| 181 |
+
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
| 182 |
+
scales {float} -- face scales (default: {None})
|
| 183 |
+
"""
|
| 184 |
+
max, idx = torch.max(
|
| 185 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
| 186 |
+
idx += 1
|
| 187 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
| 188 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
| 189 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
| 190 |
+
|
| 191 |
+
for i in range(preds.size(0)):
|
| 192 |
+
for j in range(preds.size(1)):
|
| 193 |
+
hm_ = hm[i, j, :]
|
| 194 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
| 195 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
| 196 |
+
diff = torch.FloatTensor(
|
| 197 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
| 198 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
| 199 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
| 200 |
+
|
| 201 |
+
preds.add_(-.5)
|
| 202 |
+
|
| 203 |
+
preds_orig = torch.zeros(preds.size())
|
| 204 |
+
if centers is not None and scales is not None:
|
| 205 |
+
for i in range(hm.size(0)):
|
| 206 |
+
for j in range(hm.size(1)):
|
| 207 |
+
preds_orig[i, j] = transform(
|
| 208 |
+
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
| 209 |
+
|
| 210 |
+
return preds, preds_orig
|
| 211 |
+
|
| 212 |
+
def shuffle_lr(parts, pairs=None):
|
| 213 |
+
"""Shuffle the points left-right according to the axis of symmetry
|
| 214 |
+
of the object.
|
| 215 |
+
|
| 216 |
+
Arguments:
|
| 217 |
+
parts {torch.tensor} -- a 3D or 4D object containing the
|
| 218 |
+
heatmaps.
|
| 219 |
+
|
| 220 |
+
Keyword Arguments:
|
| 221 |
+
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
| 222 |
+
"""
|
| 223 |
+
if pairs is None:
|
| 224 |
+
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
| 225 |
+
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
| 226 |
+
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
| 227 |
+
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
| 228 |
+
62, 61, 60, 67, 66, 65]
|
| 229 |
+
if parts.ndimension() == 3:
|
| 230 |
+
parts = parts[pairs, ...]
|
| 231 |
+
else:
|
| 232 |
+
parts = parts[:, pairs, ...]
|
| 233 |
+
|
| 234 |
+
return parts
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def flip(tensor, is_label=False):
|
| 238 |
+
"""Flip an image or a set of heatmaps left-right
|
| 239 |
+
|
| 240 |
+
Arguments:
|
| 241 |
+
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
| 242 |
+
|
| 243 |
+
Keyword Arguments:
|
| 244 |
+
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
| 245 |
+
"""
|
| 246 |
+
if not torch.is_tensor(tensor):
|
| 247 |
+
tensor = torch.from_numpy(tensor)
|
| 248 |
+
|
| 249 |
+
if is_label:
|
| 250 |
+
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
| 251 |
+
else:
|
| 252 |
+
tensor = tensor.flip(tensor.ndimension() - 1)
|
| 253 |
+
|
| 254 |
+
return tensor
|
| 255 |
+
|
| 256 |
+
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def appdata_dir(appname=None, roaming=False):
|
| 260 |
+
""" appdata_dir(appname=None, roaming=False)
|
| 261 |
+
|
| 262 |
+
Get the path to the application directory, where applications are allowed
|
| 263 |
+
to write user specific files (e.g. configurations). For non-user specific
|
| 264 |
+
data, consider using common_appdata_dir().
|
| 265 |
+
If appname is given, a subdir is appended (and created if necessary).
|
| 266 |
+
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
# Define default user directory
|
| 270 |
+
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
| 271 |
+
if userDir is None:
|
| 272 |
+
userDir = os.path.expanduser('~')
|
| 273 |
+
if not os.path.isdir(userDir): # pragma: no cover
|
| 274 |
+
userDir = '/var/tmp' # issue #54
|
| 275 |
+
|
| 276 |
+
# Get system app data dir
|
| 277 |
+
path = None
|
| 278 |
+
if sys.platform.startswith('win'):
|
| 279 |
+
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
| 280 |
+
path = (path2 or path1) if roaming else (path1 or path2)
|
| 281 |
+
elif sys.platform.startswith('darwin'):
|
| 282 |
+
path = os.path.join(userDir, 'Library', 'Application Support')
|
| 283 |
+
# On Linux and as fallback
|
| 284 |
+
if not (path and os.path.isdir(path)):
|
| 285 |
+
path = userDir
|
| 286 |
+
|
| 287 |
+
# Maybe we should store things local to the executable (in case of a
|
| 288 |
+
# portable distro or a frozen application that wants to be portable)
|
| 289 |
+
prefix = sys.prefix
|
| 290 |
+
if getattr(sys, 'frozen', None):
|
| 291 |
+
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
| 292 |
+
for reldir in ('settings', '../settings'):
|
| 293 |
+
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
| 294 |
+
if os.path.isdir(localpath): # pragma: no cover
|
| 295 |
+
try:
|
| 296 |
+
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
| 297 |
+
os.remove(os.path.join(localpath, 'test.write'))
|
| 298 |
+
except IOError:
|
| 299 |
+
pass # We cannot write in this directory
|
| 300 |
+
else:
|
| 301 |
+
path = localpath
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
# Get path specific for this app
|
| 305 |
+
if appname:
|
| 306 |
+
if path == userDir:
|
| 307 |
+
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
| 308 |
+
path = os.path.join(path, appname)
|
| 309 |
+
if not os.path.isdir(path): # pragma: no cover
|
| 310 |
+
os.mkdir(path)
|
| 311 |
+
|
| 312 |
+
# Done
|
| 313 |
+
return path
|
hparams.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from glob import glob
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def get_image_list(data_root, split):
|
| 5 |
+
filelist = []
|
| 6 |
+
|
| 7 |
+
with open('filelists/{}.txt'.format(split)) as f:
|
| 8 |
+
for line in f:
|
| 9 |
+
line = line.strip()
|
| 10 |
+
if ' ' in line: line = line.split()[0]
|
| 11 |
+
filelist.append(os.path.join(data_root, line))
|
| 12 |
+
|
| 13 |
+
return filelist
|
| 14 |
+
|
| 15 |
+
class HParams:
|
| 16 |
+
def __init__(self, **kwargs):
|
| 17 |
+
self.data = {}
|
| 18 |
+
|
| 19 |
+
for key, value in kwargs.items():
|
| 20 |
+
self.data[key] = value
|
| 21 |
+
|
| 22 |
+
def __getattr__(self, key):
|
| 23 |
+
if key not in self.data:
|
| 24 |
+
raise AttributeError("'HParams' object has no attribute %s" % key)
|
| 25 |
+
return self.data[key]
|
| 26 |
+
|
| 27 |
+
def set_hparam(self, key, value):
|
| 28 |
+
self.data[key] = value
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Default hyperparameters
|
| 32 |
+
hparams = HParams(
|
| 33 |
+
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
|
| 34 |
+
# network
|
| 35 |
+
rescale=True, # Whether to rescale audio prior to preprocessing
|
| 36 |
+
rescaling_max=0.9, # Rescaling value
|
| 37 |
+
|
| 38 |
+
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
| 39 |
+
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
| 40 |
+
# Does not work if n_ffit is not multiple of hop_size!!
|
| 41 |
+
use_lws=False,
|
| 42 |
+
|
| 43 |
+
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
|
| 44 |
+
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
| 45 |
+
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
| 46 |
+
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
| 47 |
+
|
| 48 |
+
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
|
| 49 |
+
|
| 50 |
+
# Mel and Linear spectrograms normalization/scaling and clipping
|
| 51 |
+
signal_normalization=True,
|
| 52 |
+
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
|
| 53 |
+
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
|
| 54 |
+
symmetric_mels=True,
|
| 55 |
+
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
|
| 56 |
+
# faster and cleaner convergence)
|
| 57 |
+
max_abs_value=4.,
|
| 58 |
+
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
|
| 59 |
+
# be too big to avoid gradient explosion,
|
| 60 |
+
# not too small for fast convergence)
|
| 61 |
+
# Contribution by @begeekmyfriend
|
| 62 |
+
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
|
| 63 |
+
# levels. Also allows for better G&L phase reconstruction)
|
| 64 |
+
preemphasize=True, # whether to apply filter
|
| 65 |
+
preemphasis=0.97, # filter coefficient.
|
| 66 |
+
|
| 67 |
+
# Limits
|
| 68 |
+
min_level_db=-100,
|
| 69 |
+
ref_level_db=20,
|
| 70 |
+
fmin=55,
|
| 71 |
+
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
|
| 72 |
+
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
| 73 |
+
fmax=7600, # To be increased/reduced depending on data.
|
| 74 |
+
|
| 75 |
+
###################### Our training parameters #################################
|
| 76 |
+
img_size=96,
|
| 77 |
+
fps=25,
|
| 78 |
+
|
| 79 |
+
batch_size=16,
|
| 80 |
+
initial_learning_rate=1e-4,
|
| 81 |
+
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
|
| 82 |
+
num_workers=16,
|
| 83 |
+
checkpoint_interval=3000,
|
| 84 |
+
eval_interval=3000,
|
| 85 |
+
save_optimizer_state=True,
|
| 86 |
+
|
| 87 |
+
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
|
| 88 |
+
syncnet_batch_size=64,
|
| 89 |
+
syncnet_lr=1e-4,
|
| 90 |
+
syncnet_eval_interval=10000,
|
| 91 |
+
syncnet_checkpoint_interval=10000,
|
| 92 |
+
|
| 93 |
+
disc_wt=0.07,
|
| 94 |
+
disc_initial_learning_rate=1e-4,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def hparams_debug_string():
|
| 99 |
+
values = hparams.values()
|
| 100 |
+
hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
|
| 101 |
+
return "Hyperparameters:\n" + "\n".join(hp)
|
hq_wav2lip_train.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os.path import dirname, join, basename, isfile
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
from models import SyncNet_color as SyncNet
|
| 5 |
+
from models import Wav2Lip, Wav2Lip_disc_qual
|
| 6 |
+
import audio
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from torch import optim
|
| 12 |
+
import torch.backends.cudnn as cudnn
|
| 13 |
+
from torch.utils import data as data_utils
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from glob import glob
|
| 17 |
+
|
| 18 |
+
import os, random, cv2, argparse
|
| 19 |
+
from hparams import hparams, get_image_list
|
| 20 |
+
|
| 21 |
+
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
|
| 22 |
+
|
| 23 |
+
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
|
| 24 |
+
|
| 25 |
+
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
|
| 26 |
+
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
|
| 27 |
+
|
| 28 |
+
parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
|
| 29 |
+
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
|
| 30 |
+
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
global_step = 0
|
| 35 |
+
global_epoch = 0
|
| 36 |
+
use_cuda = torch.cuda.is_available()
|
| 37 |
+
print('use_cuda: {}'.format(use_cuda))
|
| 38 |
+
|
| 39 |
+
syncnet_T = 5
|
| 40 |
+
syncnet_mel_step_size = 16
|
| 41 |
+
|
| 42 |
+
class Dataset(object):
|
| 43 |
+
def __init__(self, split):
|
| 44 |
+
self.all_videos = get_image_list(args.data_root, split)
|
| 45 |
+
|
| 46 |
+
def get_frame_id(self, frame):
|
| 47 |
+
return int(basename(frame).split('.')[0])
|
| 48 |
+
|
| 49 |
+
def get_window(self, start_frame):
|
| 50 |
+
start_id = self.get_frame_id(start_frame)
|
| 51 |
+
vidname = dirname(start_frame)
|
| 52 |
+
|
| 53 |
+
window_fnames = []
|
| 54 |
+
for frame_id in range(start_id, start_id + syncnet_T):
|
| 55 |
+
frame = join(vidname, '{}.jpg'.format(frame_id))
|
| 56 |
+
if not isfile(frame):
|
| 57 |
+
return None
|
| 58 |
+
window_fnames.append(frame)
|
| 59 |
+
return window_fnames
|
| 60 |
+
|
| 61 |
+
def read_window(self, window_fnames):
|
| 62 |
+
if window_fnames is None: return None
|
| 63 |
+
window = []
|
| 64 |
+
for fname in window_fnames:
|
| 65 |
+
img = cv2.imread(fname)
|
| 66 |
+
if img is None:
|
| 67 |
+
return None
|
| 68 |
+
try:
|
| 69 |
+
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
|
| 70 |
+
except Exception as e:
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
window.append(img)
|
| 74 |
+
|
| 75 |
+
return window
|
| 76 |
+
|
| 77 |
+
def crop_audio_window(self, spec, start_frame):
|
| 78 |
+
if type(start_frame) == int:
|
| 79 |
+
start_frame_num = start_frame
|
| 80 |
+
else:
|
| 81 |
+
start_frame_num = self.get_frame_id(start_frame)
|
| 82 |
+
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
|
| 83 |
+
|
| 84 |
+
end_idx = start_idx + syncnet_mel_step_size
|
| 85 |
+
|
| 86 |
+
return spec[start_idx : end_idx, :]
|
| 87 |
+
|
| 88 |
+
def get_segmented_mels(self, spec, start_frame):
|
| 89 |
+
mels = []
|
| 90 |
+
assert syncnet_T == 5
|
| 91 |
+
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
|
| 92 |
+
if start_frame_num - 2 < 0: return None
|
| 93 |
+
for i in range(start_frame_num, start_frame_num + syncnet_T):
|
| 94 |
+
m = self.crop_audio_window(spec, i - 2)
|
| 95 |
+
if m.shape[0] != syncnet_mel_step_size:
|
| 96 |
+
return None
|
| 97 |
+
mels.append(m.T)
|
| 98 |
+
|
| 99 |
+
mels = np.asarray(mels)
|
| 100 |
+
|
| 101 |
+
return mels
|
| 102 |
+
|
| 103 |
+
def prepare_window(self, window):
|
| 104 |
+
# 3 x T x H x W
|
| 105 |
+
x = np.asarray(window) / 255.
|
| 106 |
+
x = np.transpose(x, (3, 0, 1, 2))
|
| 107 |
+
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
def __len__(self):
|
| 111 |
+
return len(self.all_videos)
|
| 112 |
+
|
| 113 |
+
def __getitem__(self, idx):
|
| 114 |
+
while 1:
|
| 115 |
+
idx = random.randint(0, len(self.all_videos) - 1)
|
| 116 |
+
vidname = self.all_videos[idx]
|
| 117 |
+
img_names = list(glob(join(vidname, '*.jpg')))
|
| 118 |
+
if len(img_names) <= 3 * syncnet_T:
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
img_name = random.choice(img_names)
|
| 122 |
+
wrong_img_name = random.choice(img_names)
|
| 123 |
+
while wrong_img_name == img_name:
|
| 124 |
+
wrong_img_name = random.choice(img_names)
|
| 125 |
+
|
| 126 |
+
window_fnames = self.get_window(img_name)
|
| 127 |
+
wrong_window_fnames = self.get_window(wrong_img_name)
|
| 128 |
+
if window_fnames is None or wrong_window_fnames is None:
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
window = self.read_window(window_fnames)
|
| 132 |
+
if window is None:
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
wrong_window = self.read_window(wrong_window_fnames)
|
| 136 |
+
if wrong_window is None:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
wavpath = join(vidname, "audio.wav")
|
| 141 |
+
wav = audio.load_wav(wavpath, hparams.sample_rate)
|
| 142 |
+
|
| 143 |
+
orig_mel = audio.melspectrogram(wav).T
|
| 144 |
+
except Exception as e:
|
| 145 |
+
continue
|
| 146 |
+
|
| 147 |
+
mel = self.crop_audio_window(orig_mel.copy(), img_name)
|
| 148 |
+
|
| 149 |
+
if (mel.shape[0] != syncnet_mel_step_size):
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
|
| 153 |
+
if indiv_mels is None: continue
|
| 154 |
+
|
| 155 |
+
window = self.prepare_window(window)
|
| 156 |
+
y = window.copy()
|
| 157 |
+
window[:, :, window.shape[2]//2:] = 0.
|
| 158 |
+
|
| 159 |
+
wrong_window = self.prepare_window(wrong_window)
|
| 160 |
+
x = np.concatenate([window, wrong_window], axis=0)
|
| 161 |
+
|
| 162 |
+
x = torch.FloatTensor(x)
|
| 163 |
+
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
| 164 |
+
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
|
| 165 |
+
y = torch.FloatTensor(y)
|
| 166 |
+
return x, indiv_mels, mel, y
|
| 167 |
+
|
| 168 |
+
def save_sample_images(x, g, gt, global_step, checkpoint_dir):
|
| 169 |
+
x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 170 |
+
g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 171 |
+
gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 172 |
+
|
| 173 |
+
refs, inps = x[..., 3:], x[..., :3]
|
| 174 |
+
folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
|
| 175 |
+
if not os.path.exists(folder): os.mkdir(folder)
|
| 176 |
+
collage = np.concatenate((refs, inps, g, gt), axis=-2)
|
| 177 |
+
for batch_idx, c in enumerate(collage):
|
| 178 |
+
for t in range(len(c)):
|
| 179 |
+
cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
|
| 180 |
+
|
| 181 |
+
logloss = nn.BCELoss()
|
| 182 |
+
def cosine_loss(a, v, y):
|
| 183 |
+
d = nn.functional.cosine_similarity(a, v)
|
| 184 |
+
loss = logloss(d.unsqueeze(1), y)
|
| 185 |
+
|
| 186 |
+
return loss
|
| 187 |
+
|
| 188 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
| 189 |
+
syncnet = SyncNet().to(device)
|
| 190 |
+
for p in syncnet.parameters():
|
| 191 |
+
p.requires_grad = False
|
| 192 |
+
|
| 193 |
+
recon_loss = nn.L1Loss()
|
| 194 |
+
def get_sync_loss(mel, g):
|
| 195 |
+
g = g[:, :, :, g.size(3)//2:]
|
| 196 |
+
g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
|
| 197 |
+
# B, 3 * T, H//2, W
|
| 198 |
+
a, v = syncnet(mel, g)
|
| 199 |
+
y = torch.ones(g.size(0), 1).float().to(device)
|
| 200 |
+
return cosine_loss(a, v, y)
|
| 201 |
+
|
| 202 |
+
def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
|
| 203 |
+
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
| 204 |
+
global global_step, global_epoch
|
| 205 |
+
resumed_step = global_step
|
| 206 |
+
|
| 207 |
+
while global_epoch < nepochs:
|
| 208 |
+
print('Starting Epoch: {}'.format(global_epoch))
|
| 209 |
+
running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
|
| 210 |
+
running_disc_real_loss, running_disc_fake_loss = 0., 0.
|
| 211 |
+
prog_bar = tqdm(enumerate(train_data_loader))
|
| 212 |
+
for step, (x, indiv_mels, mel, gt) in prog_bar:
|
| 213 |
+
disc.train()
|
| 214 |
+
model.train()
|
| 215 |
+
|
| 216 |
+
x = x.to(device)
|
| 217 |
+
mel = mel.to(device)
|
| 218 |
+
indiv_mels = indiv_mels.to(device)
|
| 219 |
+
gt = gt.to(device)
|
| 220 |
+
|
| 221 |
+
### Train generator now. Remove ALL grads.
|
| 222 |
+
optimizer.zero_grad()
|
| 223 |
+
disc_optimizer.zero_grad()
|
| 224 |
+
|
| 225 |
+
g = model(indiv_mels, x)
|
| 226 |
+
|
| 227 |
+
if hparams.syncnet_wt > 0.:
|
| 228 |
+
sync_loss = get_sync_loss(mel, g)
|
| 229 |
+
else:
|
| 230 |
+
sync_loss = 0.
|
| 231 |
+
|
| 232 |
+
if hparams.disc_wt > 0.:
|
| 233 |
+
perceptual_loss = disc.perceptual_forward(g)
|
| 234 |
+
else:
|
| 235 |
+
perceptual_loss = 0.
|
| 236 |
+
|
| 237 |
+
l1loss = recon_loss(g, gt)
|
| 238 |
+
|
| 239 |
+
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
|
| 240 |
+
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
|
| 241 |
+
|
| 242 |
+
loss.backward()
|
| 243 |
+
optimizer.step()
|
| 244 |
+
|
| 245 |
+
### Remove all gradients before Training disc
|
| 246 |
+
disc_optimizer.zero_grad()
|
| 247 |
+
|
| 248 |
+
pred = disc(gt)
|
| 249 |
+
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
|
| 250 |
+
disc_real_loss.backward()
|
| 251 |
+
|
| 252 |
+
pred = disc(g.detach())
|
| 253 |
+
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
|
| 254 |
+
disc_fake_loss.backward()
|
| 255 |
+
|
| 256 |
+
disc_optimizer.step()
|
| 257 |
+
|
| 258 |
+
running_disc_real_loss += disc_real_loss.item()
|
| 259 |
+
running_disc_fake_loss += disc_fake_loss.item()
|
| 260 |
+
|
| 261 |
+
if global_step % checkpoint_interval == 0:
|
| 262 |
+
save_sample_images(x, g, gt, global_step, checkpoint_dir)
|
| 263 |
+
|
| 264 |
+
# Logs
|
| 265 |
+
global_step += 1
|
| 266 |
+
cur_session_steps = global_step - resumed_step
|
| 267 |
+
|
| 268 |
+
running_l1_loss += l1loss.item()
|
| 269 |
+
if hparams.syncnet_wt > 0.:
|
| 270 |
+
running_sync_loss += sync_loss.item()
|
| 271 |
+
else:
|
| 272 |
+
running_sync_loss += 0.
|
| 273 |
+
|
| 274 |
+
if hparams.disc_wt > 0.:
|
| 275 |
+
running_perceptual_loss += perceptual_loss.item()
|
| 276 |
+
else:
|
| 277 |
+
running_perceptual_loss += 0.
|
| 278 |
+
|
| 279 |
+
if global_step == 1 or global_step % checkpoint_interval == 0:
|
| 280 |
+
save_checkpoint(
|
| 281 |
+
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
| 282 |
+
save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
if global_step % hparams.eval_interval == 0:
|
| 286 |
+
with torch.no_grad():
|
| 287 |
+
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
|
| 288 |
+
|
| 289 |
+
if average_sync_loss < .75:
|
| 290 |
+
hparams.set_hparam('syncnet_wt', 0.03)
|
| 291 |
+
|
| 292 |
+
prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
|
| 293 |
+
running_sync_loss / (step + 1),
|
| 294 |
+
running_perceptual_loss / (step + 1),
|
| 295 |
+
running_disc_fake_loss / (step + 1),
|
| 296 |
+
running_disc_real_loss / (step + 1)))
|
| 297 |
+
|
| 298 |
+
global_epoch += 1
|
| 299 |
+
|
| 300 |
+
def eval_model(test_data_loader, global_step, device, model, disc):
|
| 301 |
+
eval_steps = 300
|
| 302 |
+
print('Evaluating for {} steps'.format(eval_steps))
|
| 303 |
+
running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
|
| 304 |
+
while 1:
|
| 305 |
+
for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
|
| 306 |
+
model.eval()
|
| 307 |
+
disc.eval()
|
| 308 |
+
|
| 309 |
+
x = x.to(device)
|
| 310 |
+
mel = mel.to(device)
|
| 311 |
+
indiv_mels = indiv_mels.to(device)
|
| 312 |
+
gt = gt.to(device)
|
| 313 |
+
|
| 314 |
+
pred = disc(gt)
|
| 315 |
+
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
|
| 316 |
+
|
| 317 |
+
g = model(indiv_mels, x)
|
| 318 |
+
pred = disc(g)
|
| 319 |
+
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
|
| 320 |
+
|
| 321 |
+
running_disc_real_loss.append(disc_real_loss.item())
|
| 322 |
+
running_disc_fake_loss.append(disc_fake_loss.item())
|
| 323 |
+
|
| 324 |
+
sync_loss = get_sync_loss(mel, g)
|
| 325 |
+
|
| 326 |
+
if hparams.disc_wt > 0.:
|
| 327 |
+
perceptual_loss = disc.perceptual_forward(g)
|
| 328 |
+
else:
|
| 329 |
+
perceptual_loss = 0.
|
| 330 |
+
|
| 331 |
+
l1loss = recon_loss(g, gt)
|
| 332 |
+
|
| 333 |
+
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
|
| 334 |
+
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
|
| 335 |
+
|
| 336 |
+
running_l1_loss.append(l1loss.item())
|
| 337 |
+
running_sync_loss.append(sync_loss.item())
|
| 338 |
+
|
| 339 |
+
if hparams.disc_wt > 0.:
|
| 340 |
+
running_perceptual_loss.append(perceptual_loss.item())
|
| 341 |
+
else:
|
| 342 |
+
running_perceptual_loss.append(0.)
|
| 343 |
+
|
| 344 |
+
if step > eval_steps: break
|
| 345 |
+
|
| 346 |
+
print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
|
| 347 |
+
sum(running_sync_loss) / len(running_sync_loss),
|
| 348 |
+
sum(running_perceptual_loss) / len(running_perceptual_loss),
|
| 349 |
+
sum(running_disc_fake_loss) / len(running_disc_fake_loss),
|
| 350 |
+
sum(running_disc_real_loss) / len(running_disc_real_loss)))
|
| 351 |
+
return sum(running_sync_loss) / len(running_sync_loss)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
|
| 355 |
+
checkpoint_path = join(
|
| 356 |
+
checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
|
| 357 |
+
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
|
| 358 |
+
torch.save({
|
| 359 |
+
"state_dict": model.state_dict(),
|
| 360 |
+
"optimizer": optimizer_state,
|
| 361 |
+
"global_step": step,
|
| 362 |
+
"global_epoch": epoch,
|
| 363 |
+
}, checkpoint_path)
|
| 364 |
+
print("Saved checkpoint:", checkpoint_path)
|
| 365 |
+
|
| 366 |
+
def _load(checkpoint_path):
|
| 367 |
+
if use_cuda:
|
| 368 |
+
checkpoint = torch.load(checkpoint_path)
|
| 369 |
+
else:
|
| 370 |
+
checkpoint = torch.load(checkpoint_path,
|
| 371 |
+
map_location=lambda storage, loc: storage)
|
| 372 |
+
return checkpoint
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
|
| 376 |
+
global global_step
|
| 377 |
+
global global_epoch
|
| 378 |
+
|
| 379 |
+
print("Load checkpoint from: {}".format(path))
|
| 380 |
+
checkpoint = _load(path)
|
| 381 |
+
s = checkpoint["state_dict"]
|
| 382 |
+
new_s = {}
|
| 383 |
+
for k, v in s.items():
|
| 384 |
+
new_s[k.replace('module.', '')] = v
|
| 385 |
+
model.load_state_dict(new_s)
|
| 386 |
+
if not reset_optimizer:
|
| 387 |
+
optimizer_state = checkpoint["optimizer"]
|
| 388 |
+
if optimizer_state is not None:
|
| 389 |
+
print("Load optimizer state from {}".format(path))
|
| 390 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 391 |
+
if overwrite_global_states:
|
| 392 |
+
global_step = checkpoint["global_step"]
|
| 393 |
+
global_epoch = checkpoint["global_epoch"]
|
| 394 |
+
|
| 395 |
+
return model
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
checkpoint_dir = args.checkpoint_dir
|
| 399 |
+
|
| 400 |
+
# Dataset and Dataloader setup
|
| 401 |
+
train_dataset = Dataset('train')
|
| 402 |
+
test_dataset = Dataset('val')
|
| 403 |
+
|
| 404 |
+
train_data_loader = data_utils.DataLoader(
|
| 405 |
+
train_dataset, batch_size=hparams.batch_size, shuffle=True,
|
| 406 |
+
num_workers=hparams.num_workers)
|
| 407 |
+
|
| 408 |
+
test_data_loader = data_utils.DataLoader(
|
| 409 |
+
test_dataset, batch_size=hparams.batch_size,
|
| 410 |
+
num_workers=4)
|
| 411 |
+
|
| 412 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
| 413 |
+
|
| 414 |
+
# Model
|
| 415 |
+
model = Wav2Lip().to(device)
|
| 416 |
+
disc = Wav2Lip_disc_qual().to(device)
|
| 417 |
+
|
| 418 |
+
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
| 419 |
+
print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
|
| 420 |
+
|
| 421 |
+
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
|
| 422 |
+
lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
|
| 423 |
+
disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
|
| 424 |
+
lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
|
| 425 |
+
|
| 426 |
+
if args.checkpoint_path is not None:
|
| 427 |
+
load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
|
| 428 |
+
|
| 429 |
+
if args.disc_checkpoint_path is not None:
|
| 430 |
+
load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
|
| 431 |
+
reset_optimizer=False, overwrite_global_states=False)
|
| 432 |
+
|
| 433 |
+
load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
|
| 434 |
+
overwrite_global_states=False)
|
| 435 |
+
|
| 436 |
+
if not os.path.exists(checkpoint_dir):
|
| 437 |
+
os.mkdir(checkpoint_dir)
|
| 438 |
+
|
| 439 |
+
# Train!
|
| 440 |
+
train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
|
| 441 |
+
checkpoint_dir=checkpoint_dir,
|
| 442 |
+
checkpoint_interval=hparams.checkpoint_interval,
|
| 443 |
+
nepochs=hparams.nepochs)
|
inference.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir, path
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy, cv2, os, sys, argparse, audio
|
| 4 |
+
import json, subprocess, random, string
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from glob import glob
|
| 7 |
+
import torch, face_detection
|
| 8 |
+
from models import Wav2Lip
|
| 9 |
+
import platform
|
| 10 |
+
|
| 11 |
+
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
|
| 12 |
+
|
| 13 |
+
parser.add_argument('--checkpoint_path', type=str,
|
| 14 |
+
help='Name of saved checkpoint to load weights from', required=True)
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--face', type=str,
|
| 17 |
+
help='Filepath of video/image that contains faces to use', required=True)
|
| 18 |
+
parser.add_argument('--audio', type=str,
|
| 19 |
+
help='Filepath of video/audio file to use as raw audio source', required=True)
|
| 20 |
+
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
|
| 21 |
+
default='results/result_voice.mp4')
|
| 22 |
+
|
| 23 |
+
parser.add_argument('--static', type=bool,
|
| 24 |
+
help='If True, then use only first video frame for inference', default=False)
|
| 25 |
+
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
|
| 26 |
+
default=25., required=False)
|
| 27 |
+
|
| 28 |
+
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
|
| 29 |
+
help='Padding (top, bottom, left, right). Please adjust to include chin at least')
|
| 30 |
+
|
| 31 |
+
parser.add_argument('--face_det_batch_size', type=int,
|
| 32 |
+
help='Batch size for face detection', default=16)
|
| 33 |
+
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
|
| 34 |
+
|
| 35 |
+
parser.add_argument('--resize_factor', default=1, type=int,
|
| 36 |
+
help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
|
| 37 |
+
|
| 38 |
+
parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
|
| 39 |
+
help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
|
| 40 |
+
'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
|
| 41 |
+
|
| 42 |
+
parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
|
| 43 |
+
help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
|
| 44 |
+
'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
|
| 45 |
+
|
| 46 |
+
parser.add_argument('--rotate', default=False, action='store_true',
|
| 47 |
+
help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
|
| 48 |
+
'Use if you get a flipped result, despite feeding a normal looking video')
|
| 49 |
+
|
| 50 |
+
parser.add_argument('--nosmooth', default=False, action='store_true',
|
| 51 |
+
help='Prevent smoothing face detections over a short temporal window')
|
| 52 |
+
|
| 53 |
+
args = parser.parse_args()
|
| 54 |
+
args.img_size = 96
|
| 55 |
+
|
| 56 |
+
if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
|
| 57 |
+
args.static = True
|
| 58 |
+
|
| 59 |
+
def get_smoothened_boxes(boxes, T):
|
| 60 |
+
for i in range(len(boxes)):
|
| 61 |
+
if i + T > len(boxes):
|
| 62 |
+
window = boxes[len(boxes) - T:]
|
| 63 |
+
else:
|
| 64 |
+
window = boxes[i : i + T]
|
| 65 |
+
boxes[i] = np.mean(window, axis=0)
|
| 66 |
+
return boxes
|
| 67 |
+
|
| 68 |
+
def face_detect(images):
|
| 69 |
+
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
|
| 70 |
+
flip_input=False, device=device)
|
| 71 |
+
|
| 72 |
+
batch_size = args.face_det_batch_size
|
| 73 |
+
|
| 74 |
+
while 1:
|
| 75 |
+
predictions = []
|
| 76 |
+
try:
|
| 77 |
+
for i in tqdm(range(0, len(images), batch_size)):
|
| 78 |
+
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
| 79 |
+
except RuntimeError:
|
| 80 |
+
if batch_size == 1:
|
| 81 |
+
raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
|
| 82 |
+
batch_size //= 2
|
| 83 |
+
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
|
| 84 |
+
continue
|
| 85 |
+
break
|
| 86 |
+
|
| 87 |
+
results = []
|
| 88 |
+
pady1, pady2, padx1, padx2 = args.pads
|
| 89 |
+
for rect, image in zip(predictions, images):
|
| 90 |
+
if rect is None:
|
| 91 |
+
cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
|
| 92 |
+
raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
|
| 93 |
+
|
| 94 |
+
y1 = max(0, rect[1] - pady1)
|
| 95 |
+
y2 = min(image.shape[0], rect[3] + pady2)
|
| 96 |
+
x1 = max(0, rect[0] - padx1)
|
| 97 |
+
x2 = min(image.shape[1], rect[2] + padx2)
|
| 98 |
+
|
| 99 |
+
results.append([x1, y1, x2, y2])
|
| 100 |
+
|
| 101 |
+
boxes = np.array(results)
|
| 102 |
+
if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
|
| 103 |
+
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
| 104 |
+
|
| 105 |
+
del detector
|
| 106 |
+
return results
|
| 107 |
+
|
| 108 |
+
def datagen(frames, mels):
|
| 109 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 110 |
+
|
| 111 |
+
if args.box[0] == -1:
|
| 112 |
+
if not args.static:
|
| 113 |
+
face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
|
| 114 |
+
else:
|
| 115 |
+
face_det_results = face_detect([frames[0]])
|
| 116 |
+
else:
|
| 117 |
+
print('Using the specified bounding box instead of face detection...')
|
| 118 |
+
y1, y2, x1, x2 = args.box
|
| 119 |
+
face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
|
| 120 |
+
|
| 121 |
+
for i, m in enumerate(mels):
|
| 122 |
+
idx = 0 if args.static else i%len(frames)
|
| 123 |
+
frame_to_save = frames[idx].copy()
|
| 124 |
+
face, coords = face_det_results[idx].copy()
|
| 125 |
+
|
| 126 |
+
face = cv2.resize(face, (args.img_size, args.img_size))
|
| 127 |
+
|
| 128 |
+
img_batch.append(face)
|
| 129 |
+
mel_batch.append(m)
|
| 130 |
+
frame_batch.append(frame_to_save)
|
| 131 |
+
coords_batch.append(coords)
|
| 132 |
+
|
| 133 |
+
if len(img_batch) >= args.wav2lip_batch_size:
|
| 134 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 135 |
+
|
| 136 |
+
img_masked = img_batch.copy()
|
| 137 |
+
img_masked[:, args.img_size//2:] = 0
|
| 138 |
+
|
| 139 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 140 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 141 |
+
|
| 142 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 143 |
+
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
| 144 |
+
|
| 145 |
+
if len(img_batch) > 0:
|
| 146 |
+
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
| 147 |
+
|
| 148 |
+
img_masked = img_batch.copy()
|
| 149 |
+
img_masked[:, args.img_size//2:] = 0
|
| 150 |
+
|
| 151 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
| 152 |
+
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
| 153 |
+
|
| 154 |
+
yield img_batch, mel_batch, frame_batch, coords_batch
|
| 155 |
+
|
| 156 |
+
mel_step_size = 16
|
| 157 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 158 |
+
print('Using {} for inference.'.format(device))
|
| 159 |
+
|
| 160 |
+
def _load(checkpoint_path):
|
| 161 |
+
if device == 'cuda':
|
| 162 |
+
checkpoint = torch.load(checkpoint_path)
|
| 163 |
+
else:
|
| 164 |
+
checkpoint = torch.load(checkpoint_path,
|
| 165 |
+
map_location=lambda storage, loc: storage)
|
| 166 |
+
return checkpoint
|
| 167 |
+
|
| 168 |
+
def load_model(path):
|
| 169 |
+
model = Wav2Lip()
|
| 170 |
+
print("Load checkpoint from: {}".format(path))
|
| 171 |
+
checkpoint = _load(path)
|
| 172 |
+
s = checkpoint["state_dict"]
|
| 173 |
+
new_s = {}
|
| 174 |
+
for k, v in s.items():
|
| 175 |
+
new_s[k.replace('module.', '')] = v
|
| 176 |
+
model.load_state_dict(new_s)
|
| 177 |
+
|
| 178 |
+
model = model.to(device)
|
| 179 |
+
return model.eval()
|
| 180 |
+
|
| 181 |
+
def main():
|
| 182 |
+
if not os.path.isfile(args.face):
|
| 183 |
+
raise ValueError('--face argument must be a valid path to video/image file')
|
| 184 |
+
|
| 185 |
+
elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
|
| 186 |
+
full_frames = [cv2.imread(args.face)]
|
| 187 |
+
fps = args.fps
|
| 188 |
+
|
| 189 |
+
else:
|
| 190 |
+
video_stream = cv2.VideoCapture(args.face)
|
| 191 |
+
fps = video_stream.get(cv2.CAP_PROP_FPS)
|
| 192 |
+
|
| 193 |
+
print('Reading video frames...')
|
| 194 |
+
|
| 195 |
+
full_frames = []
|
| 196 |
+
while 1:
|
| 197 |
+
still_reading, frame = video_stream.read()
|
| 198 |
+
if not still_reading:
|
| 199 |
+
video_stream.release()
|
| 200 |
+
break
|
| 201 |
+
if args.resize_factor > 1:
|
| 202 |
+
frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
|
| 203 |
+
|
| 204 |
+
if args.rotate:
|
| 205 |
+
frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
|
| 206 |
+
|
| 207 |
+
y1, y2, x1, x2 = args.crop
|
| 208 |
+
if x2 == -1: x2 = frame.shape[1]
|
| 209 |
+
if y2 == -1: y2 = frame.shape[0]
|
| 210 |
+
|
| 211 |
+
frame = frame[y1:y2, x1:x2]
|
| 212 |
+
|
| 213 |
+
full_frames.append(frame)
|
| 214 |
+
|
| 215 |
+
print ("Number of frames available for inference: "+str(len(full_frames)))
|
| 216 |
+
|
| 217 |
+
if not args.audio.endswith('.wav'):
|
| 218 |
+
print('Extracting raw audio...')
|
| 219 |
+
command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
|
| 220 |
+
|
| 221 |
+
subprocess.call(command, shell=True)
|
| 222 |
+
args.audio = 'temp/temp.wav'
|
| 223 |
+
|
| 224 |
+
wav = audio.load_wav(args.audio, 16000)
|
| 225 |
+
mel = audio.melspectrogram(wav)
|
| 226 |
+
print(mel.shape)
|
| 227 |
+
|
| 228 |
+
if np.isnan(mel.reshape(-1)).sum() > 0:
|
| 229 |
+
raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
|
| 230 |
+
|
| 231 |
+
mel_chunks = []
|
| 232 |
+
mel_idx_multiplier = 80./fps
|
| 233 |
+
i = 0
|
| 234 |
+
while 1:
|
| 235 |
+
start_idx = int(i * mel_idx_multiplier)
|
| 236 |
+
if start_idx + mel_step_size > len(mel[0]):
|
| 237 |
+
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
|
| 238 |
+
break
|
| 239 |
+
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
|
| 240 |
+
i += 1
|
| 241 |
+
|
| 242 |
+
print("Length of mel chunks: {}".format(len(mel_chunks)))
|
| 243 |
+
|
| 244 |
+
full_frames = full_frames[:len(mel_chunks)]
|
| 245 |
+
|
| 246 |
+
batch_size = args.wav2lip_batch_size
|
| 247 |
+
gen = datagen(full_frames.copy(), mel_chunks)
|
| 248 |
+
|
| 249 |
+
for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
|
| 250 |
+
total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
|
| 251 |
+
if i == 0:
|
| 252 |
+
model = load_model(args.checkpoint_path)
|
| 253 |
+
print ("Model loaded")
|
| 254 |
+
|
| 255 |
+
frame_h, frame_w = full_frames[0].shape[:-1]
|
| 256 |
+
out = cv2.VideoWriter('temp/result.avi',
|
| 257 |
+
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
|
| 258 |
+
|
| 259 |
+
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
| 260 |
+
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
| 261 |
+
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
pred = model(mel_batch, img_batch)
|
| 264 |
+
|
| 265 |
+
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
|
| 266 |
+
|
| 267 |
+
for p, f, c in zip(pred, frames, coords):
|
| 268 |
+
y1, y2, x1, x2 = c
|
| 269 |
+
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
|
| 270 |
+
|
| 271 |
+
f[y1:y2, x1:x2] = p
|
| 272 |
+
out.write(f)
|
| 273 |
+
|
| 274 |
+
out.release()
|
| 275 |
+
|
| 276 |
+
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
|
| 277 |
+
subprocess.call(command, shell=platform.system() != 'Windows')
|
| 278 |
+
|
| 279 |
+
if __name__ == '__main__':
|
| 280 |
+
main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
|
| 2 |
+
from .syncnet import SyncNet_color
|
models/conv.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
class Conv2d(nn.Module):
|
| 6 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
| 7 |
+
super().__init__(*args, **kwargs)
|
| 8 |
+
self.conv_block = nn.Sequential(
|
| 9 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
| 10 |
+
nn.BatchNorm2d(cout)
|
| 11 |
+
)
|
| 12 |
+
self.act = nn.ReLU()
|
| 13 |
+
self.residual = residual
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
out = self.conv_block(x)
|
| 17 |
+
if self.residual:
|
| 18 |
+
out += x
|
| 19 |
+
return self.act(out)
|
| 20 |
+
|
| 21 |
+
class nonorm_Conv2d(nn.Module):
|
| 22 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
self.conv_block = nn.Sequential(
|
| 25 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
| 26 |
+
)
|
| 27 |
+
self.act = nn.LeakyReLU(0.01, inplace=True)
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
out = self.conv_block(x)
|
| 31 |
+
return self.act(out)
|
| 32 |
+
|
| 33 |
+
class Conv2dTranspose(nn.Module):
|
| 34 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
| 35 |
+
super().__init__(*args, **kwargs)
|
| 36 |
+
self.conv_block = nn.Sequential(
|
| 37 |
+
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
| 38 |
+
nn.BatchNorm2d(cout)
|
| 39 |
+
)
|
| 40 |
+
self.act = nn.ReLU()
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
out = self.conv_block(x)
|
| 44 |
+
return self.act(out)
|
models/syncnet.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from .conv import Conv2d
|
| 6 |
+
|
| 7 |
+
class SyncNet_color(nn.Module):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super(SyncNet_color, self).__init__()
|
| 10 |
+
|
| 11 |
+
self.face_encoder = nn.Sequential(
|
| 12 |
+
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
|
| 13 |
+
|
| 14 |
+
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
|
| 15 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 16 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 17 |
+
|
| 18 |
+
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
| 19 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 20 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 21 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 22 |
+
|
| 23 |
+
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
| 24 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 25 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 26 |
+
|
| 27 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
| 28 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
| 29 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
| 30 |
+
|
| 31 |
+
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
| 32 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
|
| 33 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
| 34 |
+
|
| 35 |
+
self.audio_encoder = nn.Sequential(
|
| 36 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
| 37 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
| 38 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
| 39 |
+
|
| 40 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
| 41 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 42 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 43 |
+
|
| 44 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
| 45 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 46 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 47 |
+
|
| 48 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
| 49 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 50 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 51 |
+
|
| 52 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
| 53 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
| 54 |
+
|
| 55 |
+
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
|
| 56 |
+
face_embedding = self.face_encoder(face_sequences)
|
| 57 |
+
audio_embedding = self.audio_encoder(audio_sequences)
|
| 58 |
+
|
| 59 |
+
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
|
| 60 |
+
face_embedding = face_embedding.view(face_embedding.size(0), -1)
|
| 61 |
+
|
| 62 |
+
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
|
| 63 |
+
face_embedding = F.normalize(face_embedding, p=2, dim=1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
return audio_embedding, face_embedding
|
models/wav2lip.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
|
| 7 |
+
|
| 8 |
+
class Wav2Lip(nn.Module):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super(Wav2Lip, self).__init__()
|
| 11 |
+
|
| 12 |
+
self.face_encoder_blocks = nn.ModuleList([
|
| 13 |
+
nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
|
| 14 |
+
|
| 15 |
+
nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
|
| 16 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
| 17 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
|
| 18 |
+
|
| 19 |
+
nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
|
| 20 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 21 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 22 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
|
| 23 |
+
|
| 24 |
+
nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
|
| 25 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 26 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
|
| 27 |
+
|
| 28 |
+
nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
|
| 29 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 30 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
|
| 31 |
+
|
| 32 |
+
nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
|
| 33 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
|
| 34 |
+
|
| 35 |
+
nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
|
| 36 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
|
| 37 |
+
|
| 38 |
+
self.audio_encoder = nn.Sequential(
|
| 39 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
| 40 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
| 41 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
| 42 |
+
|
| 43 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
| 44 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 45 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 46 |
+
|
| 47 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
| 48 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 49 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 50 |
+
|
| 51 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
| 52 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 53 |
+
|
| 54 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
| 55 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
| 56 |
+
|
| 57 |
+
self.face_decoder_blocks = nn.ModuleList([
|
| 58 |
+
nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
|
| 59 |
+
|
| 60 |
+
nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
|
| 61 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
|
| 62 |
+
|
| 63 |
+
nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 64 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
| 65 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
|
| 66 |
+
|
| 67 |
+
nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 68 |
+
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
|
| 69 |
+
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
|
| 70 |
+
|
| 71 |
+
nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 72 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
| 73 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
|
| 74 |
+
|
| 75 |
+
nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 76 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
| 77 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
|
| 78 |
+
|
| 79 |
+
nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 80 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
| 81 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
|
| 82 |
+
|
| 83 |
+
self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
|
| 84 |
+
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
|
| 85 |
+
nn.Sigmoid())
|
| 86 |
+
|
| 87 |
+
def forward(self, audio_sequences, face_sequences):
|
| 88 |
+
# audio_sequences = (B, T, 1, 80, 16)
|
| 89 |
+
B = audio_sequences.size(0)
|
| 90 |
+
|
| 91 |
+
input_dim_size = len(face_sequences.size())
|
| 92 |
+
if input_dim_size > 4:
|
| 93 |
+
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
|
| 94 |
+
face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
|
| 95 |
+
|
| 96 |
+
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
|
| 97 |
+
|
| 98 |
+
feats = []
|
| 99 |
+
x = face_sequences
|
| 100 |
+
for f in self.face_encoder_blocks:
|
| 101 |
+
x = f(x)
|
| 102 |
+
feats.append(x)
|
| 103 |
+
|
| 104 |
+
x = audio_embedding
|
| 105 |
+
for f in self.face_decoder_blocks:
|
| 106 |
+
x = f(x)
|
| 107 |
+
try:
|
| 108 |
+
x = torch.cat((x, feats[-1]), dim=1)
|
| 109 |
+
except Exception as e:
|
| 110 |
+
print(x.size())
|
| 111 |
+
print(feats[-1].size())
|
| 112 |
+
raise e
|
| 113 |
+
|
| 114 |
+
feats.pop()
|
| 115 |
+
|
| 116 |
+
x = self.output_block(x)
|
| 117 |
+
|
| 118 |
+
if input_dim_size > 4:
|
| 119 |
+
x = torch.split(x, B, dim=0) # [(B, C, H, W)]
|
| 120 |
+
outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
|
| 121 |
+
|
| 122 |
+
else:
|
| 123 |
+
outputs = x
|
| 124 |
+
|
| 125 |
+
return outputs
|
| 126 |
+
|
| 127 |
+
class Wav2Lip_disc_qual(nn.Module):
|
| 128 |
+
def __init__(self):
|
| 129 |
+
super(Wav2Lip_disc_qual, self).__init__()
|
| 130 |
+
|
| 131 |
+
self.face_encoder_blocks = nn.ModuleList([
|
| 132 |
+
nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
|
| 133 |
+
|
| 134 |
+
nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
|
| 135 |
+
nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
|
| 136 |
+
|
| 137 |
+
nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
|
| 138 |
+
nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
|
| 139 |
+
|
| 140 |
+
nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
|
| 141 |
+
nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
|
| 142 |
+
|
| 143 |
+
nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
|
| 144 |
+
nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
|
| 145 |
+
|
| 146 |
+
nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
|
| 147 |
+
nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
|
| 148 |
+
|
| 149 |
+
nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
|
| 150 |
+
nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
|
| 151 |
+
|
| 152 |
+
self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
|
| 153 |
+
self.label_noise = .0
|
| 154 |
+
|
| 155 |
+
def get_lower_half(self, face_sequences):
|
| 156 |
+
return face_sequences[:, :, face_sequences.size(2)//2:]
|
| 157 |
+
|
| 158 |
+
def to_2d(self, face_sequences):
|
| 159 |
+
B = face_sequences.size(0)
|
| 160 |
+
face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
|
| 161 |
+
return face_sequences
|
| 162 |
+
|
| 163 |
+
def perceptual_forward(self, false_face_sequences):
|
| 164 |
+
false_face_sequences = self.to_2d(false_face_sequences)
|
| 165 |
+
false_face_sequences = self.get_lower_half(false_face_sequences)
|
| 166 |
+
|
| 167 |
+
false_feats = false_face_sequences
|
| 168 |
+
for f in self.face_encoder_blocks:
|
| 169 |
+
false_feats = f(false_feats)
|
| 170 |
+
|
| 171 |
+
false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
|
| 172 |
+
torch.ones((len(false_feats), 1)).cuda())
|
| 173 |
+
|
| 174 |
+
return false_pred_loss
|
| 175 |
+
|
| 176 |
+
def forward(self, face_sequences):
|
| 177 |
+
face_sequences = self.to_2d(face_sequences)
|
| 178 |
+
face_sequences = self.get_lower_half(face_sequences)
|
| 179 |
+
|
| 180 |
+
x = face_sequences
|
| 181 |
+
for f in self.face_encoder_blocks:
|
| 182 |
+
x = f(x)
|
| 183 |
+
|
| 184 |
+
return self.binary_pred(x).view(len(x), -1)
|
preprocess.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
if sys.version_info[0] < 3 and sys.version_info[1] < 2:
|
| 4 |
+
raise Exception("Must be using >= Python 3.2")
|
| 5 |
+
|
| 6 |
+
from os import listdir, path
|
| 7 |
+
|
| 8 |
+
if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
|
| 9 |
+
raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \
|
| 10 |
+
before running this script!')
|
| 11 |
+
|
| 12 |
+
import multiprocessing as mp
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 14 |
+
import numpy as np
|
| 15 |
+
import argparse, os, cv2, traceback, subprocess
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from glob import glob
|
| 18 |
+
import audio
|
| 19 |
+
from hparams import hparams as hp
|
| 20 |
+
|
| 21 |
+
import face_detection
|
| 22 |
+
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
|
| 25 |
+
parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
|
| 26 |
+
parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
|
| 27 |
+
parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
|
| 28 |
+
parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
|
| 29 |
+
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
|
| 33 |
+
device='cuda:{}'.format(id)) for id in range(args.ngpu)]
|
| 34 |
+
|
| 35 |
+
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
|
| 36 |
+
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
|
| 37 |
+
|
| 38 |
+
def process_video_file(vfile, args, gpu_id):
|
| 39 |
+
video_stream = cv2.VideoCapture(vfile)
|
| 40 |
+
|
| 41 |
+
frames = []
|
| 42 |
+
while 1:
|
| 43 |
+
still_reading, frame = video_stream.read()
|
| 44 |
+
if not still_reading:
|
| 45 |
+
video_stream.release()
|
| 46 |
+
break
|
| 47 |
+
frames.append(frame)
|
| 48 |
+
|
| 49 |
+
vidname = os.path.basename(vfile).split('.')[0]
|
| 50 |
+
dirname = vfile.split('/')[-2]
|
| 51 |
+
|
| 52 |
+
fulldir = path.join(args.preprocessed_root, dirname, vidname)
|
| 53 |
+
os.makedirs(fulldir, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
|
| 56 |
+
|
| 57 |
+
i = -1
|
| 58 |
+
for fb in batches:
|
| 59 |
+
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
|
| 60 |
+
|
| 61 |
+
for j, f in enumerate(preds):
|
| 62 |
+
i += 1
|
| 63 |
+
if f is None:
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
x1, y1, x2, y2 = f
|
| 67 |
+
cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
|
| 68 |
+
|
| 69 |
+
def process_audio_file(vfile, args):
|
| 70 |
+
vidname = os.path.basename(vfile).split('.')[0]
|
| 71 |
+
dirname = vfile.split('/')[-2]
|
| 72 |
+
|
| 73 |
+
fulldir = path.join(args.preprocessed_root, dirname, vidname)
|
| 74 |
+
os.makedirs(fulldir, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
wavpath = path.join(fulldir, 'audio.wav')
|
| 77 |
+
|
| 78 |
+
command = template.format(vfile, wavpath)
|
| 79 |
+
subprocess.call(command, shell=True)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def mp_handler(job):
|
| 83 |
+
vfile, args, gpu_id = job
|
| 84 |
+
try:
|
| 85 |
+
process_video_file(vfile, args, gpu_id)
|
| 86 |
+
except KeyboardInterrupt:
|
| 87 |
+
exit(0)
|
| 88 |
+
except:
|
| 89 |
+
traceback.print_exc()
|
| 90 |
+
|
| 91 |
+
def main(args):
|
| 92 |
+
print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
|
| 93 |
+
|
| 94 |
+
filelist = glob(path.join(args.data_root, '*/*.mp4'))
|
| 95 |
+
|
| 96 |
+
jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
|
| 97 |
+
p = ThreadPoolExecutor(args.ngpu)
|
| 98 |
+
futures = [p.submit(mp_handler, j) for j in jobs]
|
| 99 |
+
_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
|
| 100 |
+
|
| 101 |
+
print('Dumping audios...')
|
| 102 |
+
|
| 103 |
+
for vfile in tqdm(filelist):
|
| 104 |
+
try:
|
| 105 |
+
process_audio_file(vfile, args)
|
| 106 |
+
except KeyboardInterrupt:
|
| 107 |
+
exit(0)
|
| 108 |
+
except:
|
| 109 |
+
traceback.print_exc()
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
if __name__ == '__main__':
|
| 113 |
+
main(args)
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
librosa
|
| 2 |
+
numpy
|
| 3 |
+
opencv-contrib-python
|
| 4 |
+
opencv-python
|
| 5 |
+
torch
|
| 6 |
+
torchvision
|
| 7 |
+
tqdm
|
| 8 |
+
numba
|
requirementsCPU.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
librosa
|
| 2 |
+
numpy
|
| 3 |
+
opencv-contrib-python
|
| 4 |
+
opencv-python
|
| 5 |
+
-f https://download.pytorch.org/whl/torch_stable.html
|
| 6 |
+
torch
|
| 7 |
+
torchvision
|
| 8 |
+
tqdm
|
| 9 |
+
numba
|
wav2lip_train.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os.path import dirname, join, basename, isfile
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
from models import SyncNet_color as SyncNet
|
| 5 |
+
from models import Wav2Lip as Wav2Lip
|
| 6 |
+
import audio
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch import optim
|
| 11 |
+
import torch.backends.cudnn as cudnn
|
| 12 |
+
from torch.utils import data as data_utils
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
from glob import glob
|
| 16 |
+
|
| 17 |
+
import os, random, cv2, argparse
|
| 18 |
+
from hparams import hparams, get_image_list
|
| 19 |
+
|
| 20 |
+
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
|
| 21 |
+
|
| 22 |
+
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
|
| 23 |
+
|
| 24 |
+
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
|
| 25 |
+
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
|
| 26 |
+
|
| 27 |
+
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
|
| 28 |
+
|
| 29 |
+
args = parser.parse_args()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
global_step = 0
|
| 33 |
+
global_epoch = 0
|
| 34 |
+
use_cuda = torch.cuda.is_available()
|
| 35 |
+
print('use_cuda: {}'.format(use_cuda))
|
| 36 |
+
|
| 37 |
+
syncnet_T = 5
|
| 38 |
+
syncnet_mel_step_size = 16
|
| 39 |
+
|
| 40 |
+
class Dataset(object):
|
| 41 |
+
def __init__(self, split):
|
| 42 |
+
self.all_videos = get_image_list(args.data_root, split)
|
| 43 |
+
|
| 44 |
+
def get_frame_id(self, frame):
|
| 45 |
+
return int(basename(frame).split('.')[0])
|
| 46 |
+
|
| 47 |
+
def get_window(self, start_frame):
|
| 48 |
+
start_id = self.get_frame_id(start_frame)
|
| 49 |
+
vidname = dirname(start_frame)
|
| 50 |
+
|
| 51 |
+
window_fnames = []
|
| 52 |
+
for frame_id in range(start_id, start_id + syncnet_T):
|
| 53 |
+
frame = join(vidname, '{}.jpg'.format(frame_id))
|
| 54 |
+
if not isfile(frame):
|
| 55 |
+
return None
|
| 56 |
+
window_fnames.append(frame)
|
| 57 |
+
return window_fnames
|
| 58 |
+
|
| 59 |
+
def read_window(self, window_fnames):
|
| 60 |
+
if window_fnames is None: return None
|
| 61 |
+
window = []
|
| 62 |
+
for fname in window_fnames:
|
| 63 |
+
img = cv2.imread(fname)
|
| 64 |
+
if img is None:
|
| 65 |
+
return None
|
| 66 |
+
try:
|
| 67 |
+
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
|
| 68 |
+
except Exception as e:
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
window.append(img)
|
| 72 |
+
|
| 73 |
+
return window
|
| 74 |
+
|
| 75 |
+
def crop_audio_window(self, spec, start_frame):
|
| 76 |
+
if type(start_frame) == int:
|
| 77 |
+
start_frame_num = start_frame
|
| 78 |
+
else:
|
| 79 |
+
start_frame_num = self.get_frame_id(start_frame) # 0-indexing ---> 1-indexing
|
| 80 |
+
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
|
| 81 |
+
|
| 82 |
+
end_idx = start_idx + syncnet_mel_step_size
|
| 83 |
+
|
| 84 |
+
return spec[start_idx : end_idx, :]
|
| 85 |
+
|
| 86 |
+
def get_segmented_mels(self, spec, start_frame):
|
| 87 |
+
mels = []
|
| 88 |
+
assert syncnet_T == 5
|
| 89 |
+
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
|
| 90 |
+
if start_frame_num - 2 < 0: return None
|
| 91 |
+
for i in range(start_frame_num, start_frame_num + syncnet_T):
|
| 92 |
+
m = self.crop_audio_window(spec, i - 2)
|
| 93 |
+
if m.shape[0] != syncnet_mel_step_size:
|
| 94 |
+
return None
|
| 95 |
+
mels.append(m.T)
|
| 96 |
+
|
| 97 |
+
mels = np.asarray(mels)
|
| 98 |
+
|
| 99 |
+
return mels
|
| 100 |
+
|
| 101 |
+
def prepare_window(self, window):
|
| 102 |
+
# 3 x T x H x W
|
| 103 |
+
x = np.asarray(window) / 255.
|
| 104 |
+
x = np.transpose(x, (3, 0, 1, 2))
|
| 105 |
+
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
def __len__(self):
|
| 109 |
+
return len(self.all_videos)
|
| 110 |
+
|
| 111 |
+
def __getitem__(self, idx):
|
| 112 |
+
while 1:
|
| 113 |
+
idx = random.randint(0, len(self.all_videos) - 1)
|
| 114 |
+
vidname = self.all_videos[idx]
|
| 115 |
+
img_names = list(glob(join(vidname, '*.jpg')))
|
| 116 |
+
if len(img_names) <= 3 * syncnet_T:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
img_name = random.choice(img_names)
|
| 120 |
+
wrong_img_name = random.choice(img_names)
|
| 121 |
+
while wrong_img_name == img_name:
|
| 122 |
+
wrong_img_name = random.choice(img_names)
|
| 123 |
+
|
| 124 |
+
window_fnames = self.get_window(img_name)
|
| 125 |
+
wrong_window_fnames = self.get_window(wrong_img_name)
|
| 126 |
+
if window_fnames is None or wrong_window_fnames is None:
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
window = self.read_window(window_fnames)
|
| 130 |
+
if window is None:
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
wrong_window = self.read_window(wrong_window_fnames)
|
| 134 |
+
if wrong_window is None:
|
| 135 |
+
continue
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
wavpath = join(vidname, "audio.wav")
|
| 139 |
+
wav = audio.load_wav(wavpath, hparams.sample_rate)
|
| 140 |
+
|
| 141 |
+
orig_mel = audio.melspectrogram(wav).T
|
| 142 |
+
except Exception as e:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
mel = self.crop_audio_window(orig_mel.copy(), img_name)
|
| 146 |
+
|
| 147 |
+
if (mel.shape[0] != syncnet_mel_step_size):
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
|
| 151 |
+
if indiv_mels is None: continue
|
| 152 |
+
|
| 153 |
+
window = self.prepare_window(window)
|
| 154 |
+
y = window.copy()
|
| 155 |
+
window[:, :, window.shape[2]//2:] = 0.
|
| 156 |
+
|
| 157 |
+
wrong_window = self.prepare_window(wrong_window)
|
| 158 |
+
x = np.concatenate([window, wrong_window], axis=0)
|
| 159 |
+
|
| 160 |
+
x = torch.FloatTensor(x)
|
| 161 |
+
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
| 162 |
+
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
|
| 163 |
+
y = torch.FloatTensor(y)
|
| 164 |
+
return x, indiv_mels, mel, y
|
| 165 |
+
|
| 166 |
+
def save_sample_images(x, g, gt, global_step, checkpoint_dir):
|
| 167 |
+
x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 168 |
+
g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 169 |
+
gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
|
| 170 |
+
|
| 171 |
+
refs, inps = x[..., 3:], x[..., :3]
|
| 172 |
+
folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
|
| 173 |
+
if not os.path.exists(folder): os.mkdir(folder)
|
| 174 |
+
collage = np.concatenate((refs, inps, g, gt), axis=-2)
|
| 175 |
+
for batch_idx, c in enumerate(collage):
|
| 176 |
+
for t in range(len(c)):
|
| 177 |
+
cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
|
| 178 |
+
|
| 179 |
+
logloss = nn.BCELoss()
|
| 180 |
+
def cosine_loss(a, v, y):
|
| 181 |
+
d = nn.functional.cosine_similarity(a, v)
|
| 182 |
+
loss = logloss(d.unsqueeze(1), y)
|
| 183 |
+
|
| 184 |
+
return loss
|
| 185 |
+
|
| 186 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
| 187 |
+
syncnet = SyncNet().to(device)
|
| 188 |
+
for p in syncnet.parameters():
|
| 189 |
+
p.requires_grad = False
|
| 190 |
+
|
| 191 |
+
recon_loss = nn.L1Loss()
|
| 192 |
+
def get_sync_loss(mel, g):
|
| 193 |
+
g = g[:, :, :, g.size(3)//2:]
|
| 194 |
+
g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
|
| 195 |
+
# B, 3 * T, H//2, W
|
| 196 |
+
a, v = syncnet(mel, g)
|
| 197 |
+
y = torch.ones(g.size(0), 1).float().to(device)
|
| 198 |
+
return cosine_loss(a, v, y)
|
| 199 |
+
|
| 200 |
+
def train(device, model, train_data_loader, test_data_loader, optimizer,
|
| 201 |
+
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
| 202 |
+
|
| 203 |
+
global global_step, global_epoch
|
| 204 |
+
resumed_step = global_step
|
| 205 |
+
|
| 206 |
+
while global_epoch < nepochs:
|
| 207 |
+
print('Starting Epoch: {}'.format(global_epoch))
|
| 208 |
+
running_sync_loss, running_l1_loss = 0., 0.
|
| 209 |
+
prog_bar = tqdm(enumerate(train_data_loader))
|
| 210 |
+
for step, (x, indiv_mels, mel, gt) in prog_bar:
|
| 211 |
+
model.train()
|
| 212 |
+
optimizer.zero_grad()
|
| 213 |
+
|
| 214 |
+
# Move data to CUDA device
|
| 215 |
+
x = x.to(device)
|
| 216 |
+
mel = mel.to(device)
|
| 217 |
+
indiv_mels = indiv_mels.to(device)
|
| 218 |
+
gt = gt.to(device)
|
| 219 |
+
|
| 220 |
+
g = model(indiv_mels, x)
|
| 221 |
+
|
| 222 |
+
if hparams.syncnet_wt > 0.:
|
| 223 |
+
sync_loss = get_sync_loss(mel, g)
|
| 224 |
+
else:
|
| 225 |
+
sync_loss = 0.
|
| 226 |
+
|
| 227 |
+
l1loss = recon_loss(g, gt)
|
| 228 |
+
|
| 229 |
+
loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
|
| 230 |
+
loss.backward()
|
| 231 |
+
optimizer.step()
|
| 232 |
+
|
| 233 |
+
if global_step % checkpoint_interval == 0:
|
| 234 |
+
save_sample_images(x, g, gt, global_step, checkpoint_dir)
|
| 235 |
+
|
| 236 |
+
global_step += 1
|
| 237 |
+
cur_session_steps = global_step - resumed_step
|
| 238 |
+
|
| 239 |
+
running_l1_loss += l1loss.item()
|
| 240 |
+
if hparams.syncnet_wt > 0.:
|
| 241 |
+
running_sync_loss += sync_loss.item()
|
| 242 |
+
else:
|
| 243 |
+
running_sync_loss += 0.
|
| 244 |
+
|
| 245 |
+
if global_step == 1 or global_step % checkpoint_interval == 0:
|
| 246 |
+
save_checkpoint(
|
| 247 |
+
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
| 248 |
+
|
| 249 |
+
if global_step == 1 or global_step % hparams.eval_interval == 0:
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
|
| 252 |
+
|
| 253 |
+
if average_sync_loss < .75:
|
| 254 |
+
hparams.set_hparam('syncnet_wt', 0.01) # without image GAN a lesser weight is sufficient
|
| 255 |
+
|
| 256 |
+
prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
|
| 257 |
+
running_sync_loss / (step + 1)))
|
| 258 |
+
|
| 259 |
+
global_epoch += 1
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
|
| 263 |
+
eval_steps = 700
|
| 264 |
+
print('Evaluating for {} steps'.format(eval_steps))
|
| 265 |
+
sync_losses, recon_losses = [], []
|
| 266 |
+
step = 0
|
| 267 |
+
while 1:
|
| 268 |
+
for x, indiv_mels, mel, gt in test_data_loader:
|
| 269 |
+
step += 1
|
| 270 |
+
model.eval()
|
| 271 |
+
|
| 272 |
+
# Move data to CUDA device
|
| 273 |
+
x = x.to(device)
|
| 274 |
+
gt = gt.to(device)
|
| 275 |
+
indiv_mels = indiv_mels.to(device)
|
| 276 |
+
mel = mel.to(device)
|
| 277 |
+
|
| 278 |
+
g = model(indiv_mels, x)
|
| 279 |
+
|
| 280 |
+
sync_loss = get_sync_loss(mel, g)
|
| 281 |
+
l1loss = recon_loss(g, gt)
|
| 282 |
+
|
| 283 |
+
sync_losses.append(sync_loss.item())
|
| 284 |
+
recon_losses.append(l1loss.item())
|
| 285 |
+
|
| 286 |
+
if step > eval_steps:
|
| 287 |
+
averaged_sync_loss = sum(sync_losses) / len(sync_losses)
|
| 288 |
+
averaged_recon_loss = sum(recon_losses) / len(recon_losses)
|
| 289 |
+
|
| 290 |
+
print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
|
| 291 |
+
|
| 292 |
+
return averaged_sync_loss
|
| 293 |
+
|
| 294 |
+
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
|
| 295 |
+
|
| 296 |
+
checkpoint_path = join(
|
| 297 |
+
checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
|
| 298 |
+
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
|
| 299 |
+
torch.save({
|
| 300 |
+
"state_dict": model.state_dict(),
|
| 301 |
+
"optimizer": optimizer_state,
|
| 302 |
+
"global_step": step,
|
| 303 |
+
"global_epoch": epoch,
|
| 304 |
+
}, checkpoint_path)
|
| 305 |
+
print("Saved checkpoint:", checkpoint_path)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _load(checkpoint_path):
|
| 309 |
+
if use_cuda:
|
| 310 |
+
checkpoint = torch.load(checkpoint_path)
|
| 311 |
+
else:
|
| 312 |
+
checkpoint = torch.load(checkpoint_path,
|
| 313 |
+
map_location=lambda storage, loc: storage)
|
| 314 |
+
return checkpoint
|
| 315 |
+
|
| 316 |
+
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
|
| 317 |
+
global global_step
|
| 318 |
+
global global_epoch
|
| 319 |
+
|
| 320 |
+
print("Load checkpoint from: {}".format(path))
|
| 321 |
+
checkpoint = _load(path)
|
| 322 |
+
s = checkpoint["state_dict"]
|
| 323 |
+
new_s = {}
|
| 324 |
+
for k, v in s.items():
|
| 325 |
+
new_s[k.replace('module.', '')] = v
|
| 326 |
+
model.load_state_dict(new_s)
|
| 327 |
+
if not reset_optimizer:
|
| 328 |
+
optimizer_state = checkpoint["optimizer"]
|
| 329 |
+
if optimizer_state is not None:
|
| 330 |
+
print("Load optimizer state from {}".format(path))
|
| 331 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 332 |
+
if overwrite_global_states:
|
| 333 |
+
global_step = checkpoint["global_step"]
|
| 334 |
+
global_epoch = checkpoint["global_epoch"]
|
| 335 |
+
|
| 336 |
+
return model
|
| 337 |
+
|
| 338 |
+
if __name__ == "__main__":
|
| 339 |
+
checkpoint_dir = args.checkpoint_dir
|
| 340 |
+
|
| 341 |
+
# Dataset and Dataloader setup
|
| 342 |
+
train_dataset = Dataset('train')
|
| 343 |
+
test_dataset = Dataset('val')
|
| 344 |
+
|
| 345 |
+
train_data_loader = data_utils.DataLoader(
|
| 346 |
+
train_dataset, batch_size=hparams.batch_size, shuffle=True,
|
| 347 |
+
num_workers=hparams.num_workers)
|
| 348 |
+
|
| 349 |
+
test_data_loader = data_utils.DataLoader(
|
| 350 |
+
test_dataset, batch_size=hparams.batch_size,
|
| 351 |
+
num_workers=4)
|
| 352 |
+
|
| 353 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
| 354 |
+
|
| 355 |
+
# Model
|
| 356 |
+
model = Wav2Lip().to(device)
|
| 357 |
+
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
| 358 |
+
|
| 359 |
+
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
|
| 360 |
+
lr=hparams.initial_learning_rate)
|
| 361 |
+
|
| 362 |
+
if args.checkpoint_path is not None:
|
| 363 |
+
load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
|
| 364 |
+
|
| 365 |
+
load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
|
| 366 |
+
|
| 367 |
+
if not os.path.exists(checkpoint_dir):
|
| 368 |
+
os.mkdir(checkpoint_dir)
|
| 369 |
+
|
| 370 |
+
# Train!
|
| 371 |
+
train(device, model, train_data_loader, test_data_loader, optimizer,
|
| 372 |
+
checkpoint_dir=checkpoint_dir,
|
| 373 |
+
checkpoint_interval=hparams.checkpoint_interval,
|
| 374 |
+
nepochs=hparams.nepochs)
|