darshankr committed on
Commit
02ea507
·
verified ·
1 Parent(s): 4eb5700

Upload 41 files

Files changed (41)
  1. audio.py +136 -0
  2. checkpoints/lipsync_expert.pth +3 -0
  3. checkpoints/s3fd-619a316812.pth +3 -0
  4. checkpoints/visual_quality_disc.pth +3 -0
  5. checkpoints/wav2lip.pth +3 -0
  6. checkpoints/wav2lip_gan.pth +3 -0
  7. color_syncnet_train.py +279 -0
  8. evaluation/gen_videos_from_filelist.py +238 -0
  9. evaluation/real_videos_inference.py +305 -0
  10. evaluation/scores_LSE/SyncNetInstance_calc_scores.py +210 -0
  11. evaluation/scores_LSE/calculate_scores_LRS.py +53 -0
  12. evaluation/scores_LSE/calculate_scores_real_videos.py +45 -0
  13. evaluation/scores_LSE/calculate_scores_real_videos.sh +8 -0
  14. evaluation/test_filelists/ReSyncED/random_pairs.txt +160 -0
  15. evaluation/test_filelists/ReSyncED/tts_pairs.txt +18 -0
  16. evaluation/test_filelists/lrs2.txt +0 -0
  17. evaluation/test_filelists/lrs3.txt +0 -0
  18. evaluation/test_filelists/lrw.txt +0 -0
  19. face_detection/__init__.py +7 -0
  20. face_detection/api.py +79 -0
  21. face_detection/detection/__init__.py +1 -0
  22. face_detection/detection/core.py +130 -0
  23. face_detection/detection/sfd/__init__.py +1 -0
  24. face_detection/detection/sfd/bbox.py +129 -0
  25. face_detection/detection/sfd/detect.py +112 -0
  26. face_detection/detection/sfd/net_s3fd.py +129 -0
  27. face_detection/detection/sfd/s3fd.pth +3 -0
  28. face_detection/detection/sfd/sfd_detector.py +59 -0
  29. face_detection/models.py +261 -0
  30. face_detection/utils.py +313 -0
  31. hparams.py +101 -0
  32. hq_wav2lip_train.py +443 -0
  33. inference.py +280 -0
  34. models/__init__.py +2 -0
  35. models/conv.py +44 -0
  36. models/syncnet.py +66 -0
  37. models/wav2lip.py +184 -0
  38. preprocess.py +113 -0
  39. requirements.txt +8 -0
  40. requirementsCPU.txt +9 -0
  41. wav2lip_train.py +374 -0
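The five .pth files above are stored as Git LFS pointers, so only their pointer metadata appears in this diff. Once the actual weights are fetched, they load with the same pattern the scripts in this commit use: torch.load with a CPU map_location fallback and stripping of any 'module.' prefix left behind by DataParallel. A minimal sketch along those lines (the checkpoint path is only an example):

import torch
from models import Wav2Lip  # models/wav2lip.py in this commit

def load_wav2lip(checkpoint_path, device='cpu'):
    # map_location keeps this usable on CPU-only machines, mirroring _load() in the scripts below
    ckpt = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    state = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
    model = Wav2Lip()
    model.load_state_dict(state)
    return model.to(device).eval()

model = load_wav2lip('checkpoints/wav2lip_gan.pth')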
audio.py ADDED
@@ -0,0 +1,136 @@
+ import librosa
+ import librosa.filters
+ import numpy as np
+ # import tensorflow as tf
+ from scipy import signal
+ from scipy.io import wavfile
+ from hparams import hparams as hp
+
+ def load_wav(path, sr):
+     return librosa.core.load(path, sr=sr)[0]
+
+ def save_wav(wav, path, sr):
+     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+     #proposed by @dsmiller
+     wavfile.write(path, sr, wav.astype(np.int16))
+
+ def save_wavenet_wav(wav, path, sr):
+     librosa.output.write_wav(path, wav, sr=sr)
+
+ def preemphasis(wav, k, preemphasize=True):
+     if preemphasize:
+         return signal.lfilter([1, -k], [1], wav)
+     return wav
+
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
+     if inv_preemphasize:
+         return signal.lfilter([1], [1, -k], wav)
+     return wav
+
+ def get_hop_size():
+     hop_size = hp.hop_size
+     if hop_size is None:
+         assert hp.frame_shift_ms is not None
+         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+     return hop_size
+
+ def linearspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+ def melspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+ def _lws_processor():
+     import lws
+     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+ def _stft(y):
+     if hp.use_lws:
+         return _lws_processor(hp).stft(y).T
+     else:
+         return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
+
+ ##########################################################
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+ def num_frames(length, fsize, fshift):
+     """Compute number of time frames of spectrogram
+     """
+     pad = (fsize - fshift)
+     if length % fshift == 0:
+         M = (length + pad * 2 - fsize) // fshift + 1
+     else:
+         M = (length + pad * 2 - fsize) // fshift + 2
+     return M
+
+
+ def pad_lr(x, fsize, fshift):
+     """Compute left and right padding
+     """
+     M = num_frames(len(x), fsize, fshift)
+     pad = (fsize - fshift)
+     T = len(x) + 2 * pad
+     r = (M - 1) * fshift + fsize - T
+     return pad, pad + r
+ ##########################################################
+ #Librosa correct padding
+ def librosa_pad_lr(x, fsize, fshift):
+     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+ # Conversions
+ _mel_basis = None
+
+ def _linear_to_mel(spectogram):
+     global _mel_basis
+     if _mel_basis is None:
+         _mel_basis = _build_mel_basis()
+     return np.dot(_mel_basis, spectogram)
+
+ def _build_mel_basis():
+     assert hp.fmax <= hp.sample_rate // 2
+     return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
+                                fmin=hp.fmin, fmax=hp.fmax)
+
+ def _amp_to_db(x):
+     min_level = np.exp(hp.min_level_db / 20 * np.log(10))
+     return 20 * np.log10(np.maximum(min_level, x))
+
+ def _db_to_amp(x):
+     return np.power(10.0, (x) * 0.05)
+
+ def _normalize(S):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
+                            -hp.max_abs_value, hp.max_abs_value)
+         else:
+             return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
+
+     assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+     if hp.symmetric_mels:
+         return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+     else:
+         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+ def _denormalize(D):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return (((np.clip(D, -hp.max_abs_value,
+                               hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+                     + hp.min_level_db)
+         else:
+             return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+     if hp.symmetric_mels:
+         return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+     else:
+         return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
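audio.py is the mel front end shared by preprocessing, training, and inference: waveforms are loaded at hparams.sample_rate, pre-emphasized, STFT'd, mapped onto num_mels mel bands, converted to dB, and normalized. A short usage sketch ('sample.wav' is a placeholder path):

import audio
from hparams import hparams

wav = audio.load_wav('sample.wav', hparams.sample_rate)   # mono waveform at the configured rate
mel = audio.melspectrogram(wav)                           # (num_mels, num_frames), normalized dB
print(mel.shape)
orig_mel = mel.T   # the dataset loaders below transpose to (num_frames, num_mels) before windowing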
checkpoints/lipsync_expert.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa1f06d61ae86c47074ff9bc1bb7d0c40ab2d840724dc9258e255a8fab4b3559
+ size 134
checkpoints/s3fd-619a316812.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7636d0c9d2a8f4759aef537cbcc25c5fa2eb2d5d80b1fada4dcc800e967cf381
+ size 133
checkpoints/visual_quality_disc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bdb40d624f6a1e9a07beec0f6bb5a19a91a9fac46ce6bcfd282fd9ccf1c3d3fc
+ size 134
checkpoints/wav2lip.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8e58726ef72ac961e2fea864e93e10fd64076222e5bd98394736684aa63dd2d
+ size 131
checkpoints/wav2lip_gan.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:483f94a71bfd57ff73a2464a661b9af5766ce54c2ad1f06def1a2e1d8b8cd78a
+ size 134
color_syncnet_train.py ADDED
@@ -0,0 +1,279 @@
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ import audio
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.backends.cudnn as cudnn
11
+ from torch.utils import data as data_utils
12
+ import numpy as np
13
+
14
+ from glob import glob
15
+
16
+ import os, random, cv2, argparse
17
+ from hparams import hparams, get_image_list
18
+
19
+ parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
20
+
21
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
22
+
23
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
24
+ parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
25
+
26
+ args = parser.parse_args()
27
+
28
+
29
+ global_step = 0
30
+ global_epoch = 0
31
+ use_cuda = torch.cuda.is_available()
32
+ print('use_cuda: {}'.format(use_cuda))
33
+
34
+ syncnet_T = 5
35
+ syncnet_mel_step_size = 16
36
+
37
+ class Dataset(object):
38
+ def __init__(self, split):
39
+ self.all_videos = get_image_list(args.data_root, split)
40
+
41
+ def get_frame_id(self, frame):
42
+ return int(basename(frame).split('.')[0])
43
+
44
+ def get_window(self, start_frame):
45
+ start_id = self.get_frame_id(start_frame)
46
+ vidname = dirname(start_frame)
47
+
48
+ window_fnames = []
49
+ for frame_id in range(start_id, start_id + syncnet_T):
50
+ frame = join(vidname, '{}.jpg'.format(frame_id))
51
+ if not isfile(frame):
52
+ return None
53
+ window_fnames.append(frame)
54
+ return window_fnames
55
+
56
+ def crop_audio_window(self, spec, start_frame):
57
+ # num_frames = (T x hop_size * fps) / sample_rate
58
+ start_frame_num = self.get_frame_id(start_frame)
59
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
60
+
61
+ end_idx = start_idx + syncnet_mel_step_size
62
+
63
+ return spec[start_idx : end_idx, :]
64
+
65
+
66
+ def __len__(self):
67
+ return len(self.all_videos)
68
+
69
+ def __getitem__(self, idx):
70
+ while 1:
71
+ idx = random.randint(0, len(self.all_videos) - 1)
72
+ vidname = self.all_videos[idx]
73
+
74
+ img_names = list(glob(join(vidname, '*.jpg')))
75
+ if len(img_names) <= 3 * syncnet_T:
76
+ continue
77
+ img_name = random.choice(img_names)
78
+ wrong_img_name = random.choice(img_names)
79
+ while wrong_img_name == img_name:
80
+ wrong_img_name = random.choice(img_names)
81
+
82
+ if random.choice([True, False]):
83
+ y = torch.ones(1).float()
84
+ chosen = img_name
85
+ else:
86
+ y = torch.zeros(1).float()
87
+ chosen = wrong_img_name
88
+
89
+ window_fnames = self.get_window(chosen)
90
+ if window_fnames is None:
91
+ continue
92
+
93
+ window = []
94
+ all_read = True
95
+ for fname in window_fnames:
96
+ img = cv2.imread(fname)
97
+ if img is None:
98
+ all_read = False
99
+ break
100
+ try:
101
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
102
+ except Exception as e:
103
+ all_read = False
104
+ break
105
+
106
+ window.append(img)
107
+
108
+ if not all_read: continue
109
+
110
+ try:
111
+ wavpath = join(vidname, "audio.wav")
112
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
113
+
114
+ orig_mel = audio.melspectrogram(wav).T
115
+ except Exception as e:
116
+ continue
117
+
118
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
119
+
120
+ if (mel.shape[0] != syncnet_mel_step_size):
121
+ continue
122
+
123
+ # H x W x 3 * T
124
+ x = np.concatenate(window, axis=2) / 255.
125
+ x = x.transpose(2, 0, 1)
126
+ x = x[:, x.shape[1]//2:]
127
+
128
+ x = torch.FloatTensor(x)
129
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
130
+
131
+ return x, mel, y
132
+
133
+ logloss = nn.BCELoss()
134
+ def cosine_loss(a, v, y):
135
+ d = nn.functional.cosine_similarity(a, v)
136
+ loss = logloss(d.unsqueeze(1), y)
137
+
138
+ return loss
139
+
140
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
141
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
142
+
143
+ global global_step, global_epoch
144
+ resumed_step = global_step
145
+
146
+ while global_epoch < nepochs:
147
+ running_loss = 0.
148
+ prog_bar = tqdm(enumerate(train_data_loader))
149
+ for step, (x, mel, y) in prog_bar:
150
+ model.train()
151
+ optimizer.zero_grad()
152
+
153
+ # Transform data to CUDA device
154
+ x = x.to(device)
155
+
156
+ mel = mel.to(device)
157
+
158
+ a, v = model(mel, x)
159
+ y = y.to(device)
160
+
161
+ loss = cosine_loss(a, v, y)
162
+ loss.backward()
163
+ optimizer.step()
164
+
165
+ global_step += 1
166
+ cur_session_steps = global_step - resumed_step
167
+ running_loss += loss.item()
168
+
169
+ if global_step == 1 or global_step % checkpoint_interval == 0:
170
+ save_checkpoint(
171
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
172
+
173
+ if global_step % hparams.syncnet_eval_interval == 0:
174
+ with torch.no_grad():
175
+ eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
176
+
177
+ prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
178
+
179
+ global_epoch += 1
180
+
181
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
182
+ eval_steps = 1400
183
+ print('Evaluating for {} steps'.format(eval_steps))
184
+ losses = []
185
+ while 1:
186
+ for step, (x, mel, y) in enumerate(test_data_loader):
187
+
188
+ model.eval()
189
+
190
+ # Transform data to CUDA device
191
+ x = x.to(device)
192
+
193
+ mel = mel.to(device)
194
+
195
+ a, v = model(mel, x)
196
+ y = y.to(device)
197
+
198
+ loss = cosine_loss(a, v, y)
199
+ losses.append(loss.item())
200
+
201
+ if step > eval_steps: break
202
+
203
+ averaged_loss = sum(losses) / len(losses)
204
+ print(averaged_loss)
205
+
206
+ return
207
+
208
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
209
+
210
+ checkpoint_path = join(
211
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
212
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
213
+ torch.save({
214
+ "state_dict": model.state_dict(),
215
+ "optimizer": optimizer_state,
216
+ "global_step": step,
217
+ "global_epoch": epoch,
218
+ }, checkpoint_path)
219
+ print("Saved checkpoint:", checkpoint_path)
220
+
221
+ def _load(checkpoint_path):
222
+ if use_cuda:
223
+ checkpoint = torch.load(checkpoint_path)
224
+ else:
225
+ checkpoint = torch.load(checkpoint_path,
226
+ map_location=lambda storage, loc: storage)
227
+ return checkpoint
228
+
229
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False):
230
+ global global_step
231
+ global global_epoch
232
+
233
+ print("Load checkpoint from: {}".format(path))
234
+ checkpoint = _load(path)
235
+ model.load_state_dict(checkpoint["state_dict"])
236
+ if not reset_optimizer:
237
+ optimizer_state = checkpoint["optimizer"]
238
+ if optimizer_state is not None:
239
+ print("Load optimizer state from {}".format(path))
240
+ optimizer.load_state_dict(checkpoint["optimizer"])
241
+ global_step = checkpoint["global_step"]
242
+ global_epoch = checkpoint["global_epoch"]
243
+
244
+ return model
245
+
246
+ if __name__ == "__main__":
247
+ checkpoint_dir = args.checkpoint_dir
248
+ checkpoint_path = args.checkpoint_path
249
+
250
+ if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
251
+
252
+ # Dataset and Dataloader setup
253
+ train_dataset = Dataset('train')
254
+ test_dataset = Dataset('val')
255
+
256
+ train_data_loader = data_utils.DataLoader(
257
+ train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
258
+ num_workers=hparams.num_workers)
259
+
260
+ test_data_loader = data_utils.DataLoader(
261
+ test_dataset, batch_size=hparams.syncnet_batch_size,
262
+ num_workers=8)
263
+
264
+ device = torch.device("cuda" if use_cuda else "cpu")
265
+
266
+ # Model
267
+ model = SyncNet().to(device)
268
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
269
+
270
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
271
+ lr=hparams.syncnet_lr)
272
+
273
+ if checkpoint_path is not None:
274
+ load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
275
+
276
+ train(device, model, train_data_loader, test_data_loader, optimizer,
277
+ checkpoint_dir=checkpoint_dir,
278
+ checkpoint_interval=hparams.syncnet_checkpoint_interval,
279
+ nepochs=hparams.nepochs)
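The script above trains the expert lip-sync discriminator: each sample is a window of syncnet_T = 5 face frames paired with a 16-step mel crop that either matches the window (y = 1) or is taken from a different timestamp of the same video (y = 0), and the audio/video embeddings are scored with cosine similarity under a BCE loss. A toy recomputation of that loss with placeholder embeddings:

import torch
from torch import nn

logloss = nn.BCELoss()

def cosine_loss(a, v, y):
    # a, v: audio / video embeddings from SyncNet; y: 1 for in-sync pairs, 0 for mismatched ones
    d = nn.functional.cosine_similarity(a, v)
    return logloss(d.unsqueeze(1), y)

a = torch.rand(4, 512)   # placeholder embeddings (non-negative, so the similarity is a valid probability)
v = torch.rand(4, 512)
y = torch.ones(4, 1)
print(cosine_loss(a, v, y).item())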
evaluation/gen_videos_from_filelist.py ADDED
@@ -0,0 +1,238 @@
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse
4
+ import dlib, json, subprocess
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch
8
+
9
+ sys.path.append('../')
10
+ import audio
11
+ import face_detection
12
+ from models import Wav2Lip
13
+
14
+ parser = argparse.ArgumentParser(description='Code to generate results for test filelists')
15
+
16
+ parser.add_argument('--filelist', type=str,
17
+ help='Filepath of filelist file to read', required=True)
18
+ parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
19
+ required=True)
20
+ parser.add_argument('--data_root', type=str, required=True)
21
+ parser.add_argument('--checkpoint_path', type=str,
22
+ help='Name of saved checkpoint to load weights from', required=True)
23
+
24
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0],
25
+ help='Padding (top, bottom, left, right)')
26
+ parser.add_argument('--face_det_batch_size', type=int,
27
+ help='Single GPU batch size for face detection', default=64)
28
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
29
+
30
+ # parser.add_argument('--resize_factor', default=1, type=int)
31
+
32
+ args = parser.parse_args()
33
+ args.img_size = 96
34
+
35
+ def get_smoothened_boxes(boxes, T):
36
+ for i in range(len(boxes)):
37
+ if i + T > len(boxes):
38
+ window = boxes[len(boxes) - T:]
39
+ else:
40
+ window = boxes[i : i + T]
41
+ boxes[i] = np.mean(window, axis=0)
42
+ return boxes
43
+
44
+ def face_detect(images):
45
+ batch_size = args.face_det_batch_size
46
+
47
+ while 1:
48
+ predictions = []
49
+ try:
50
+ for i in range(0, len(images), batch_size):
51
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
52
+ except RuntimeError:
53
+ if batch_size == 1:
54
+ raise RuntimeError('Image too big to run face detection on GPU')
55
+ batch_size //= 2
56
+ args.face_det_batch_size = batch_size
57
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
58
+ continue
59
+ break
60
+
61
+ results = []
62
+ pady1, pady2, padx1, padx2 = args.pads
63
+ for rect, image in zip(predictions, images):
64
+ if rect is None:
65
+ raise ValueError('Face not detected!')
66
+
67
+ y1 = max(0, rect[1] - pady1)
68
+ y2 = min(image.shape[0], rect[3] + pady2)
69
+ x1 = max(0, rect[0] - padx1)
70
+ x2 = min(image.shape[1], rect[2] + padx2)
71
+
72
+ results.append([x1, y1, x2, y2])
73
+
74
+ boxes = get_smoothened_boxes(np.array(results), T=5)
75
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
76
+
77
+ return results
78
+
79
+ def datagen(frames, face_det_results, mels):
80
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
81
+
82
+ for i, m in enumerate(mels):
83
+ if i >= len(frames): raise ValueError('Equal or less lengths only')
84
+
85
+ frame_to_save = frames[i].copy()
86
+ face, coords, valid_frame = face_det_results[i].copy()
87
+ if not valid_frame:
88
+ continue
89
+
90
+ face = cv2.resize(face, (args.img_size, args.img_size))
91
+
92
+ img_batch.append(face)
93
+ mel_batch.append(m)
94
+ frame_batch.append(frame_to_save)
95
+ coords_batch.append(coords)
96
+
97
+ if len(img_batch) >= args.wav2lip_batch_size:
98
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
99
+
100
+ img_masked = img_batch.copy()
101
+ img_masked[:, args.img_size//2:] = 0
102
+
103
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
104
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
105
+
106
+ yield img_batch, mel_batch, frame_batch, coords_batch
107
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
108
+
109
+ if len(img_batch) > 0:
110
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
111
+
112
+ img_masked = img_batch.copy()
113
+ img_masked[:, args.img_size//2:] = 0
114
+
115
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
116
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
117
+
118
+ yield img_batch, mel_batch, frame_batch, coords_batch
119
+
120
+ fps = 25
121
+ mel_step_size = 16
122
+ mel_idx_multiplier = 80./fps
123
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
124
+ print('Using {} for inference.'.format(device))
125
+
126
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
127
+ flip_input=False, device=device)
128
+
129
+ def _load(checkpoint_path):
130
+ if device == 'cuda':
131
+ checkpoint = torch.load(checkpoint_path)
132
+ else:
133
+ checkpoint = torch.load(checkpoint_path,
134
+ map_location=lambda storage, loc: storage)
135
+ return checkpoint
136
+
137
+ def load_model(path):
138
+ model = Wav2Lip()
139
+ print("Load checkpoint from: {}".format(path))
140
+ checkpoint = _load(path)
141
+ s = checkpoint["state_dict"]
142
+ new_s = {}
143
+ for k, v in s.items():
144
+ new_s[k.replace('module.', '')] = v
145
+ model.load_state_dict(new_s)
146
+
147
+ model = model.to(device)
148
+ return model.eval()
149
+
150
+ model = load_model(args.checkpoint_path)
151
+
152
+ def main():
153
+ assert args.data_root is not None
154
+ data_root = args.data_root
155
+
156
+ if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
157
+
158
+ with open(args.filelist, 'r') as filelist:
159
+ lines = filelist.readlines()
160
+
161
+ for idx, line in enumerate(tqdm(lines)):
162
+ audio_src, video = line.strip().split()
163
+
164
+ audio_src = os.path.join(data_root, audio_src) + '.mp4'
165
+ video = os.path.join(data_root, video) + '.mp4'
166
+
167
+ command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
168
+ subprocess.call(command, shell=True)
169
+ temp_audio = '../temp/temp.wav'
170
+
171
+ wav = audio.load_wav(temp_audio, 16000)
172
+ mel = audio.melspectrogram(wav)
173
+ if np.isnan(mel.reshape(-1)).sum() > 0:
174
+ continue
175
+
176
+ mel_chunks = []
177
+ i = 0
178
+ while 1:
179
+ start_idx = int(i * mel_idx_multiplier)
180
+ if start_idx + mel_step_size > len(mel[0]):
181
+ break
182
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
183
+ i += 1
184
+
185
+ video_stream = cv2.VideoCapture(video)
186
+
187
+ full_frames = []
188
+ while 1:
189
+ still_reading, frame = video_stream.read()
190
+ if not still_reading or len(full_frames) > len(mel_chunks):
191
+ video_stream.release()
192
+ break
193
+ full_frames.append(frame)
194
+
195
+ if len(full_frames) < len(mel_chunks):
196
+ continue
197
+
198
+ full_frames = full_frames[:len(mel_chunks)]
199
+
200
+ try:
201
+ face_det_results = face_detect(full_frames.copy())
202
+ except ValueError as e:
203
+ continue
204
+
205
+ batch_size = args.wav2lip_batch_size
206
+ gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
207
+
208
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
209
+ if i == 0:
210
+ frame_h, frame_w = full_frames[0].shape[:-1]
211
+ out = cv2.VideoWriter('../temp/result.avi',
212
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
213
+
214
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
215
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
216
+
217
+ with torch.no_grad():
218
+ pred = model(mel_batch, img_batch)
219
+
220
+
221
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
222
+
223
+ for pl, f, c in zip(pred, frames, coords):
224
+ y1, y2, x1, x2 = c
225
+ pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
226
+ f[y1:y2, x1:x2] = pl
227
+ out.write(f)
228
+
229
+ out.release()
230
+
231
+ vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
232
+
233
+ command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
234
+ '../temp/result.avi', vid)
235
+ subprocess.call(command, shell=True)
236
+
237
+ if __name__ == '__main__':
238
+ main()
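A detail worth noting in the loop above is the mel chunking: with video at 25 fps and the scripts assuming 80 mel frames per second of audio, each output frame advances the mel index by 80/25 = 3.2 and consumes a 16-step (mel_step_size) window. A standalone sketch of that indexing with a placeholder spectrogram:

import numpy as np

fps = 25
mel_step_size = 16
mel_idx_multiplier = 80. / fps            # mel frames per video frame

mel = np.zeros((80, 800))                 # placeholder: (num_mels, num_mel_frames)
mel_chunks, i = [], 0
while True:
    start_idx = int(i * mel_idx_multiplier)
    if start_idx + mel_step_size > mel.shape[1]:
        break
    mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
    i += 1
print(len(mel_chunks))                    # one 16-step chunk per generated video frame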
evaluation/real_videos_inference.py ADDED
@@ -0,0 +1,305 @@
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse
4
+ import dlib, json, subprocess
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch
8
+
9
+ sys.path.append('../')
10
+ import audio
11
+ import face_detection
12
+ from models import Wav2Lip
13
+
14
+ parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set')
15
+
16
+ parser.add_argument('--mode', type=str,
17
+ help='random | dubbed | tts', required=True)
18
+
19
+ parser.add_argument('--filelist', type=str,
20
+ help='Filepath of filelist file to read', default=None)
21
+
22
+ parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
23
+ required=True)
24
+ parser.add_argument('--data_root', type=str, required=True)
25
+ parser.add_argument('--checkpoint_path', type=str,
26
+ help='Name of saved checkpoint to load weights from', required=True)
27
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
28
+ help='Padding (top, bottom, left, right)')
29
+
30
+ parser.add_argument('--face_det_batch_size', type=int,
31
+ help='Single GPU batch size for face detection', default=16)
32
+
33
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
34
+ parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180)
35
+ parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480)
36
+ parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720)
37
+ # parser.add_argument('--resize_factor', default=1, type=int)
38
+
39
+ args = parser.parse_args()
40
+ args.img_size = 96
41
+
42
+ def get_smoothened_boxes(boxes, T):
43
+ for i in range(len(boxes)):
44
+ if i + T > len(boxes):
45
+ window = boxes[len(boxes) - T:]
46
+ else:
47
+ window = boxes[i : i + T]
48
+ boxes[i] = np.mean(window, axis=0)
49
+ return boxes
50
+
51
+ def rescale_frames(images):
52
+ rect = detector.get_detections_for_batch(np.array([images[0]]))[0]
53
+ if rect is None:
54
+ raise ValueError('Face not detected!')
55
+ h, w = images[0].shape[:-1]
56
+
57
+ x1, y1, x2, y2 = rect
58
+
59
+ face_size = max(np.abs(y1 - y2), np.abs(x1 - x2))
60
+
61
+ diff = np.abs(face_size - args.face_res)
62
+ for factor in range(2, 16):
63
+ downsampled_res = face_size // factor
64
+ if min(h//factor, w//factor) < args.min_frame_res: break
65
+ if np.abs(downsampled_res - args.face_res) >= diff: break
66
+
67
+ factor -= 1
68
+ if factor == 1: return images
69
+
70
+ return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images]
71
+
72
+
73
+ def face_detect(images):
74
+ batch_size = args.face_det_batch_size
75
+ images = rescale_frames(images)
76
+
77
+ while 1:
78
+ predictions = []
79
+ try:
80
+ for i in range(0, len(images), batch_size):
81
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
82
+ except RuntimeError:
83
+ if batch_size == 1:
84
+ raise RuntimeError('Image too big to run face detection on GPU')
85
+ batch_size //= 2
86
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
87
+ continue
88
+ break
89
+
90
+ results = []
91
+ pady1, pady2, padx1, padx2 = args.pads
92
+ for rect, image in zip(predictions, images):
93
+ if rect is None:
94
+ raise ValueError('Face not detected!')
95
+
96
+ y1 = max(0, rect[1] - pady1)
97
+ y2 = min(image.shape[0], rect[3] + pady2)
98
+ x1 = max(0, rect[0] - padx1)
99
+ x2 = min(image.shape[1], rect[2] + padx2)
100
+
101
+ results.append([x1, y1, x2, y2])
102
+
103
+ boxes = get_smoothened_boxes(np.array(results), T=5)
104
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
105
+
106
+ return results, images
107
+
108
+ def datagen(frames, face_det_results, mels):
109
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
110
+
111
+ for i, m in enumerate(mels):
112
+ if i >= len(frames): raise ValueError('Equal or less lengths only')
113
+
114
+ frame_to_save = frames[i].copy()
115
+ face, coords, valid_frame = face_det_results[i].copy()
116
+ if not valid_frame:
117
+ continue
118
+
119
+ face = cv2.resize(face, (args.img_size, args.img_size))
120
+
121
+ img_batch.append(face)
122
+ mel_batch.append(m)
123
+ frame_batch.append(frame_to_save)
124
+ coords_batch.append(coords)
125
+
126
+ if len(img_batch) >= args.wav2lip_batch_size:
127
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
128
+
129
+ img_masked = img_batch.copy()
130
+ img_masked[:, args.img_size//2:] = 0
131
+
132
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
133
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
134
+
135
+ yield img_batch, mel_batch, frame_batch, coords_batch
136
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
137
+
138
+ if len(img_batch) > 0:
139
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
140
+
141
+ img_masked = img_batch.copy()
142
+ img_masked[:, args.img_size//2:] = 0
143
+
144
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
145
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
146
+
147
+ yield img_batch, mel_batch, frame_batch, coords_batch
148
+
149
+ def increase_frames(frames, l):
150
+ ## evenly duplicating frames to increase length of video
151
+ while len(frames) < l:
152
+ dup_every = float(l) / len(frames)
153
+
154
+ final_frames = []
155
+ next_duplicate = 0.
156
+
157
+ for i, f in enumerate(frames):
158
+ final_frames.append(f)
159
+
160
+ if int(np.ceil(next_duplicate)) == i:
161
+ final_frames.append(f)
162
+
163
+ next_duplicate += dup_every
164
+
165
+ frames = final_frames
166
+
167
+ return frames[:l]
168
+
169
+ mel_step_size = 16
170
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
171
+ print('Using {} for inference.'.format(device))
172
+
173
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
174
+ flip_input=False, device=device)
175
+
176
+ def _load(checkpoint_path):
177
+ if device == 'cuda':
178
+ checkpoint = torch.load(checkpoint_path)
179
+ else:
180
+ checkpoint = torch.load(checkpoint_path,
181
+ map_location=lambda storage, loc: storage)
182
+ return checkpoint
183
+
184
+ def load_model(path):
185
+ model = Wav2Lip()
186
+ print("Load checkpoint from: {}".format(path))
187
+ checkpoint = _load(path)
188
+ s = checkpoint["state_dict"]
189
+ new_s = {}
190
+ for k, v in s.items():
191
+ new_s[k.replace('module.', '')] = v
192
+ model.load_state_dict(new_s)
193
+
194
+ model = model.to(device)
195
+ return model.eval()
196
+
197
+ model = load_model(args.checkpoint_path)
198
+
199
+ def main():
200
+ if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
201
+
202
+ if args.mode == 'dubbed':
203
+ files = listdir(args.data_root)
204
+ lines = ['{} {}'.format(f, f) for f in files]
205
+
206
+ else:
207
+ assert args.filelist is not None
208
+ with open(args.filelist, 'r') as filelist:
209
+ lines = filelist.readlines()
210
+
211
+ for idx, line in enumerate(tqdm(lines)):
212
+ video, audio_src = line.strip().split()
213
+
214
+ audio_src = os.path.join(args.data_root, audio_src)
215
+ video = os.path.join(args.data_root, video)
216
+
217
+ command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
218
+ subprocess.call(command, shell=True)
219
+ temp_audio = '../temp/temp.wav'
220
+
221
+ wav = audio.load_wav(temp_audio, 16000)
222
+ mel = audio.melspectrogram(wav)
223
+
224
+ if np.isnan(mel.reshape(-1)).sum() > 0:
225
+ raise ValueError('Mel contains nan!')
226
+
227
+ video_stream = cv2.VideoCapture(video)
228
+
229
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
230
+ mel_idx_multiplier = 80./fps
231
+
232
+ full_frames = []
233
+ while 1:
234
+ still_reading, frame = video_stream.read()
235
+ if not still_reading:
236
+ video_stream.release()
237
+ break
238
+
239
+ if min(frame.shape[:-1]) > args.max_frame_res:
240
+ h, w = frame.shape[:-1]
241
+ scale_factor = min(h, w) / float(args.max_frame_res)
242
+ h = int(h/scale_factor)
243
+ w = int(w/scale_factor)
244
+
245
+ frame = cv2.resize(frame, (w, h))
246
+ full_frames.append(frame)
247
+
248
+ mel_chunks = []
249
+ i = 0
250
+ while 1:
251
+ start_idx = int(i * mel_idx_multiplier)
252
+ if start_idx + mel_step_size > len(mel[0]):
253
+ break
254
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
255
+ i += 1
256
+
257
+ if len(full_frames) < len(mel_chunks):
258
+ if args.mode == 'tts':
259
+ full_frames = increase_frames(full_frames, len(mel_chunks))
260
+ else:
261
+ raise ValueError('#Frames, audio length mismatch')
262
+
263
+ else:
264
+ full_frames = full_frames[:len(mel_chunks)]
265
+
266
+ try:
267
+ face_det_results, full_frames = face_detect(full_frames.copy())
268
+ except ValueError as e:
269
+ continue
270
+
271
+ batch_size = args.wav2lip_batch_size
272
+ gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
273
+
274
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
275
+ if i == 0:
276
+ frame_h, frame_w = full_frames[0].shape[:-1]
277
+
278
+ out = cv2.VideoWriter('../temp/result.avi',
279
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
280
+
281
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
282
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
283
+
284
+ with torch.no_grad():
285
+ pred = model(mel_batch, img_batch)
286
+
287
+
288
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
289
+
290
+ for pl, f, c in zip(pred, frames, coords):
291
+ y1, y2, x1, x2 = c
292
+ pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
293
+ f[y1:y2, x1:x2] = pl
294
+ out.write(f)
295
+
296
+ out.release()
297
+
298
+ vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
299
+ command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav',
300
+ '../temp/result.avi', vid)
301
+ subprocess.call(command, shell=True)
302
+
303
+
304
+ if __name__ == '__main__':
305
+ main()
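Both evaluation scripts smooth the per-frame face boxes with get_smoothened_boxes before cropping, averaging each detection over a sliding window of T = 5 boxes to suppress detector jitter. The same smoothing on a toy box array (values are made up):

import numpy as np

def get_smoothened_boxes(boxes, T):
    # sliding-window mean over T consecutive detections, as in face_detect above
    for i in range(len(boxes)):
        window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

boxes = np.array([[10, 20, 110, 120],
                  [14, 22, 114, 124],
                  [ 9, 19, 109, 119],
                  [12, 21, 112, 122],
                  [11, 20, 111, 121],
                  [13, 23, 113, 123]])    # int array, so the averaged boxes are truncated back to int
print(get_smoothened_boxes(boxes, T=5))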
evaluation/scores_LSE/SyncNetInstance_calc_scores.py ADDED
@@ -0,0 +1,210 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+ # Video 25 FPS, Audio 16000HZ
4
+
5
+ import torch
6
+ import numpy
7
+ import time, pdb, argparse, subprocess, os, math, glob
8
+ import cv2
9
+ import python_speech_features
10
+
11
+ from scipy import signal
12
+ from scipy.io import wavfile
13
+ from SyncNetModel import *
14
+ from shutil import rmtree
15
+
16
+
17
+ # ==================== Get OFFSET ====================
18
+
19
+ def calc_pdist(feat1, feat2, vshift=10):
20
+
21
+ win_size = vshift*2+1
22
+
23
+ feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))
24
+
25
+ dists = []
26
+
27
+ for i in range(0,len(feat1)):
28
+
29
+ dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))
30
+
31
+ return dists
32
+
33
+ # ==================== MAIN DEF ====================
34
+
35
+ class SyncNetInstance(torch.nn.Module):
36
+
37
+ def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
38
+ super(SyncNetInstance, self).__init__();
39
+
40
+ self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
41
+
42
+ def evaluate(self, opt, videofile):
43
+
44
+ self.__S__.eval();
45
+
46
+ # ========== ==========
47
+ # Convert files
48
+ # ========== ==========
49
+
50
+ if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)):
51
+ rmtree(os.path.join(opt.tmp_dir,opt.reference))
52
+
53
+ os.makedirs(os.path.join(opt.tmp_dir,opt.reference))
54
+
55
+ command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
56
+ output = subprocess.call(command, shell=True, stdout=None)
57
+
58
+ command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
59
+ output = subprocess.call(command, shell=True, stdout=None)
60
+
61
+ # ========== ==========
62
+ # Load video
63
+ # ========== ==========
64
+
65
+ images = []
66
+
67
+ flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
68
+ flist.sort()
69
+
70
+ for fname in flist:
71
+ img_input = cv2.imread(fname)
72
+ img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE
73
+ images.append(img_input)
74
+
75
+ im = numpy.stack(images,axis=3)
76
+ im = numpy.expand_dims(im,axis=0)
77
+ im = numpy.transpose(im,(0,3,4,1,2))
78
+
79
+ imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
80
+
81
+ # ========== ==========
82
+ # Load audio
83
+ # ========== ==========
84
+
85
+ sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))
86
+ mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
87
+ mfcc = numpy.stack([numpy.array(i) for i in mfcc])
88
+
89
+ cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
90
+ cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
91
+
92
+ # ========== ==========
93
+ # Check audio and video input length
94
+ # ========== ==========
95
+
96
+ #if (float(len(audio))/16000) != (float(len(images))/25) :
97
+ # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))
98
+
99
+ min_length = min(len(images),math.floor(len(audio)/640))
100
+
101
+ # ========== ==========
102
+ # Generate video and audio feats
103
+ # ========== ==========
104
+
105
+ lastframe = min_length-5
106
+ im_feat = []
107
+ cc_feat = []
108
+
109
+ tS = time.time()
110
+ for i in range(0,lastframe,opt.batch_size):
111
+
112
+ im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
113
+ im_in = torch.cat(im_batch,0)
114
+ im_out = self.__S__.forward_lip(im_in.cuda());
115
+ im_feat.append(im_out.data.cpu())
116
+
117
+ cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
118
+ cc_in = torch.cat(cc_batch,0)
119
+ cc_out = self.__S__.forward_aud(cc_in.cuda())
120
+ cc_feat.append(cc_out.data.cpu())
121
+
122
+ im_feat = torch.cat(im_feat,0)
123
+ cc_feat = torch.cat(cc_feat,0)
124
+
125
+ # ========== ==========
126
+ # Compute offset
127
+ # ========== ==========
128
+
129
+ #print('Compute time %.3f sec.' % (time.time()-tS))
130
+
131
+ dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
132
+ mdist = torch.mean(torch.stack(dists,1),1)
133
+
134
+ minval, minidx = torch.min(mdist,0)
135
+
136
+ offset = opt.vshift-minidx
137
+ conf = torch.median(mdist) - minval
138
+
139
+ fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
140
+ # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
141
+ fconf = torch.median(mdist).numpy() - fdist
142
+ fconfm = signal.medfilt(fconf,kernel_size=9)
143
+
144
+ numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
145
+ #print('Framewise conf: ')
146
+ #print(fconfm)
147
+ #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))
148
+
149
+ dists_npy = numpy.array([ dist.numpy() for dist in dists ])
150
+ return offset.numpy(), conf.numpy(), minval.numpy()
151
+
152
+ def extract_feature(self, opt, videofile):
153
+
154
+ self.__S__.eval();
155
+
156
+ # ========== ==========
157
+ # Load video
158
+ # ========== ==========
159
+ cap = cv2.VideoCapture(videofile)
160
+
161
+ frame_num = 1;
162
+ images = []
163
+ while frame_num:
164
+ frame_num += 1
165
+ ret, image = cap.read()
166
+ if ret == 0:
167
+ break
168
+
169
+ images.append(image)
170
+
171
+ im = numpy.stack(images,axis=3)
172
+ im = numpy.expand_dims(im,axis=0)
173
+ im = numpy.transpose(im,(0,3,4,1,2))
174
+
175
+ imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
176
+
177
+ # ========== ==========
178
+ # Generate video feats
179
+ # ========== ==========
180
+
181
+ lastframe = len(images)-4
182
+ im_feat = []
183
+
184
+ tS = time.time()
185
+ for i in range(0,lastframe,opt.batch_size):
186
+
187
+ im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
188
+ im_in = torch.cat(im_batch,0)
189
+ im_out = self.__S__.forward_lipfeat(im_in.cuda());
190
+ im_feat.append(im_out.data.cpu())
191
+
192
+ im_feat = torch.cat(im_feat,0)
193
+
194
+ # ========== ==========
195
+ # Compute offset
196
+ # ========== ==========
197
+
198
+ print('Compute time %.3f sec.' % (time.time()-tS))
199
+
200
+ return im_feat
201
+
202
+
203
+ def loadParameters(self, path):
204
+ loaded_state = torch.load(path, map_location=lambda storage, loc: storage);
205
+
206
+ self_state = self.__S__.state_dict();
207
+
208
+ for name, param in loaded_state.items():
209
+
210
+ self_state[name].copy_(param);
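The offset and confidence returned by evaluate() (the basis of the LSE-D / LSE-C scores reported by the scripts below) come from the mean framewise distance across the 2*vshift+1 candidate shifts: the shift with the lowest mean distance gives the AV offset, and the margin between the median and the minimum distance gives the confidence. A toy recomputation of that final step with made-up distances:

import torch

def offset_and_conf(mdist, vshift):
    # mdist: mean distance over frames for each of the 2*vshift+1 temporal shifts
    minval, minidx = torch.min(mdist, 0)
    offset = vshift - minidx                 # AV offset in video frames
    conf = torch.median(mdist) - minval      # margin of the best shift over the median distance
    return int(offset), float(conf), float(minval)

mdist = torch.tensor([9.1, 8.7, 8.2, 6.0, 8.4, 8.9, 9.3])   # made-up distances, vshift = 3
print(offset_and_conf(mdist, vshift=3))      # best match at the centre shift -> offset 0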
evaluation/scores_LSE/calculate_scores_LRS.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import time, pdb, argparse, subprocess
5
+ import glob
6
+ import os
7
+ from tqdm import tqdm
8
+
9
+ from SyncNetInstance_calc_scores import *
10
+
11
+ # ==================== LOAD PARAMS ====================
12
+
13
+
14
+ parser = argparse.ArgumentParser(description = "SyncNet");
15
+
16
+ parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
17
+ parser.add_argument('--batch_size', type=int, default='20', help='');
18
+ parser.add_argument('--vshift', type=int, default='15', help='');
19
+ parser.add_argument('--data_root', type=str, required=True, help='');
20
+ parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help='');
21
+ parser.add_argument('--reference', type=str, default="demo", help='');
22
+
23
+ opt = parser.parse_args();
24
+
25
+
26
+ # ==================== RUN EVALUATION ====================
27
+
28
+ s = SyncNetInstance();
29
+
30
+ s.loadParameters(opt.initial_model);
31
+ #print("Model %s loaded."%opt.initial_model);
32
+ path = os.path.join(opt.data_root, "*.mp4")
33
+
34
+ all_videos = glob.glob(path)
35
+
36
+ prog_bar = tqdm(range(len(all_videos)))
37
+ avg_confidence = 0.
38
+ avg_min_distance = 0.
39
+
40
+
41
+ for videofile_idx in prog_bar:
42
+ videofile = all_videos[videofile_idx]
43
+ offset, confidence, min_distance = s.evaluate(opt, videofile=videofile)
44
+ avg_confidence += confidence
45
+ avg_min_distance += min_distance
46
+ prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3)))
47
+ prog_bar.refresh()
48
+
49
+ print ('Average Confidence: {}'.format(avg_confidence/len(all_videos)))
50
+ print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos)))
51
+
52
+
53
+
evaluation/scores_LSE/calculate_scores_real_videos.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import time, pdb, argparse, subprocess, pickle, os, gzip, glob
5
+
6
+ from SyncNetInstance_calc_scores import *
7
+
8
+ # ==================== PARSE ARGUMENT ====================
9
+
10
+ parser = argparse.ArgumentParser(description = "SyncNet");
11
+ parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
12
+ parser.add_argument('--batch_size', type=int, default='20', help='');
13
+ parser.add_argument('--vshift', type=int, default='15', help='');
14
+ parser.add_argument('--data_dir', type=str, default='data/work', help='');
15
+ parser.add_argument('--videofile', type=str, default='', help='');
16
+ parser.add_argument('--reference', type=str, default='', help='');
17
+ opt = parser.parse_args();
18
+
19
+ setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi'))
20
+ setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp'))
21
+ setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork'))
22
+ setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop'))
23
+
24
+
25
+ # ==================== LOAD MODEL AND FILE LIST ====================
26
+
27
+ s = SyncNetInstance();
28
+
29
+ s.loadParameters(opt.initial_model);
30
+ #print("Model %s loaded."%opt.initial_model);
31
+
32
+ flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi'))
33
+ flist.sort()
34
+
35
+ # ==================== GET OFFSETS ====================
36
+
37
+ dists = []
38
+ for idx, fname in enumerate(flist):
39
+ offset, conf, dist = s.evaluate(opt,videofile=fname)
40
+ print (str(dist)+" "+str(conf))
41
+
42
+ # ==================== PRINT RESULTS TO FILE ====================
43
+
44
+ #with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil:
45
+ # pickle.dump(dists, fil)
evaluation/scores_LSE/calculate_scores_real_videos.sh ADDED
@@ -0,0 +1,8 @@
+ rm all_scores.txt
+ yourfilenames=`ls $1`
+
+ for eachfile in $yourfilenames
+ do
+ python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir
+ python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt
+ done
evaluation/test_filelists/ReSyncED/random_pairs.txt ADDED
@@ -0,0 +1,160 @@
1
+ sachin.mp4 emma_cropped.mp4
2
+ sachin.mp4 mourinho.mp4
3
+ sachin.mp4 elon.mp4
4
+ sachin.mp4 messi2.mp4
5
+ sachin.mp4 cr1.mp4
6
+ sachin.mp4 sachin.mp4
7
+ sachin.mp4 sg.mp4
8
+ sachin.mp4 fergi.mp4
9
+ sachin.mp4 spanish_lec1.mp4
10
+ sachin.mp4 bush_small.mp4
11
+ sachin.mp4 macca_cut.mp4
12
+ sachin.mp4 ca_cropped.mp4
13
+ sachin.mp4 lecun.mp4
14
+ sachin.mp4 spanish_lec0.mp4
15
+ srk.mp4 emma_cropped.mp4
16
+ srk.mp4 mourinho.mp4
17
+ srk.mp4 elon.mp4
18
+ srk.mp4 messi2.mp4
19
+ srk.mp4 cr1.mp4
20
+ srk.mp4 srk.mp4
21
+ srk.mp4 sachin.mp4
22
+ srk.mp4 sg.mp4
23
+ srk.mp4 fergi.mp4
24
+ srk.mp4 spanish_lec1.mp4
25
+ srk.mp4 bush_small.mp4
26
+ srk.mp4 macca_cut.mp4
27
+ srk.mp4 ca_cropped.mp4
28
+ srk.mp4 guardiola.mp4
29
+ srk.mp4 lecun.mp4
30
+ srk.mp4 spanish_lec0.mp4
31
+ cr1.mp4 emma_cropped.mp4
32
+ cr1.mp4 elon.mp4
33
+ cr1.mp4 messi2.mp4
34
+ cr1.mp4 cr1.mp4
35
+ cr1.mp4 spanish_lec1.mp4
36
+ cr1.mp4 bush_small.mp4
37
+ cr1.mp4 macca_cut.mp4
38
+ cr1.mp4 ca_cropped.mp4
39
+ cr1.mp4 lecun.mp4
40
+ cr1.mp4 spanish_lec0.mp4
41
+ macca_cut.mp4 emma_cropped.mp4
42
+ macca_cut.mp4 elon.mp4
43
+ macca_cut.mp4 messi2.mp4
44
+ macca_cut.mp4 spanish_lec1.mp4
45
+ macca_cut.mp4 macca_cut.mp4
46
+ macca_cut.mp4 ca_cropped.mp4
47
+ macca_cut.mp4 spanish_lec0.mp4
48
+ lecun.mp4 emma_cropped.mp4
49
+ lecun.mp4 elon.mp4
50
+ lecun.mp4 messi2.mp4
51
+ lecun.mp4 spanish_lec1.mp4
52
+ lecun.mp4 macca_cut.mp4
53
+ lecun.mp4 ca_cropped.mp4
54
+ lecun.mp4 lecun.mp4
55
+ lecun.mp4 spanish_lec0.mp4
56
+ messi2.mp4 emma_cropped.mp4
57
+ messi2.mp4 elon.mp4
58
+ messi2.mp4 messi2.mp4
59
+ messi2.mp4 spanish_lec1.mp4
60
+ messi2.mp4 macca_cut.mp4
61
+ messi2.mp4 ca_cropped.mp4
62
+ messi2.mp4 spanish_lec0.mp4
63
+ ca_cropped.mp4 emma_cropped.mp4
64
+ ca_cropped.mp4 elon.mp4
65
+ ca_cropped.mp4 spanish_lec1.mp4
66
+ ca_cropped.mp4 ca_cropped.mp4
67
+ ca_cropped.mp4 spanish_lec0.mp4
68
+ spanish_lec1.mp4 spanish_lec1.mp4
69
+ spanish_lec1.mp4 spanish_lec0.mp4
70
+ elon.mp4 elon.mp4
71
+ elon.mp4 spanish_lec1.mp4
72
+ elon.mp4 spanish_lec0.mp4
73
+ guardiola.mp4 emma_cropped.mp4
74
+ guardiola.mp4 mourinho.mp4
75
+ guardiola.mp4 elon.mp4
76
+ guardiola.mp4 messi2.mp4
77
+ guardiola.mp4 cr1.mp4
78
+ guardiola.mp4 sachin.mp4
79
+ guardiola.mp4 sg.mp4
80
+ guardiola.mp4 fergi.mp4
81
+ guardiola.mp4 spanish_lec1.mp4
82
+ guardiola.mp4 bush_small.mp4
83
+ guardiola.mp4 macca_cut.mp4
84
+ guardiola.mp4 ca_cropped.mp4
85
+ guardiola.mp4 guardiola.mp4
86
+ guardiola.mp4 lecun.mp4
87
+ guardiola.mp4 spanish_lec0.mp4
88
+ fergi.mp4 emma_cropped.mp4
89
+ fergi.mp4 mourinho.mp4
90
+ fergi.mp4 elon.mp4
91
+ fergi.mp4 messi2.mp4
92
+ fergi.mp4 cr1.mp4
93
+ fergi.mp4 sachin.mp4
94
+ fergi.mp4 sg.mp4
95
+ fergi.mp4 fergi.mp4
96
+ fergi.mp4 spanish_lec1.mp4
97
+ fergi.mp4 bush_small.mp4
98
+ fergi.mp4 macca_cut.mp4
99
+ fergi.mp4 ca_cropped.mp4
100
+ fergi.mp4 lecun.mp4
101
+ fergi.mp4 spanish_lec0.mp4
102
+ spanish.mp4 emma_cropped.mp4
103
+ spanish.mp4 spanish.mp4
104
+ spanish.mp4 mourinho.mp4
105
+ spanish.mp4 elon.mp4
106
+ spanish.mp4 messi2.mp4
107
+ spanish.mp4 cr1.mp4
108
+ spanish.mp4 srk.mp4
109
+ spanish.mp4 sachin.mp4
110
+ spanish.mp4 sg.mp4
111
+ spanish.mp4 fergi.mp4
112
+ spanish.mp4 spanish_lec1.mp4
113
+ spanish.mp4 bush_small.mp4
114
+ spanish.mp4 macca_cut.mp4
115
+ spanish.mp4 ca_cropped.mp4
116
+ spanish.mp4 guardiola.mp4
117
+ spanish.mp4 lecun.mp4
118
+ spanish.mp4 spanish_lec0.mp4
119
+ bush_small.mp4 emma_cropped.mp4
120
+ bush_small.mp4 elon.mp4
121
+ bush_small.mp4 messi2.mp4
122
+ bush_small.mp4 spanish_lec1.mp4
123
+ bush_small.mp4 bush_small.mp4
124
+ bush_small.mp4 macca_cut.mp4
125
+ bush_small.mp4 ca_cropped.mp4
126
+ bush_small.mp4 lecun.mp4
127
+ bush_small.mp4 spanish_lec0.mp4
128
+ emma_cropped.mp4 emma_cropped.mp4
129
+ emma_cropped.mp4 elon.mp4
130
+ emma_cropped.mp4 spanish_lec1.mp4
131
+ emma_cropped.mp4 spanish_lec0.mp4
132
+ sg.mp4 emma_cropped.mp4
133
+ sg.mp4 mourinho.mp4
134
+ sg.mp4 elon.mp4
135
+ sg.mp4 messi2.mp4
136
+ sg.mp4 cr1.mp4
137
+ sg.mp4 sachin.mp4
138
+ sg.mp4 sg.mp4
139
+ sg.mp4 fergi.mp4
140
+ sg.mp4 spanish_lec1.mp4
141
+ sg.mp4 bush_small.mp4
142
+ sg.mp4 macca_cut.mp4
143
+ sg.mp4 ca_cropped.mp4
144
+ sg.mp4 lecun.mp4
145
+ sg.mp4 spanish_lec0.mp4
146
+ spanish_lec0.mp4 spanish_lec0.mp4
147
+ mourinho.mp4 emma_cropped.mp4
148
+ mourinho.mp4 mourinho.mp4
149
+ mourinho.mp4 elon.mp4
150
+ mourinho.mp4 messi2.mp4
151
+ mourinho.mp4 cr1.mp4
152
+ mourinho.mp4 sachin.mp4
153
+ mourinho.mp4 sg.mp4
154
+ mourinho.mp4 fergi.mp4
155
+ mourinho.mp4 spanish_lec1.mp4
156
+ mourinho.mp4 bush_small.mp4
157
+ mourinho.mp4 macca_cut.mp4
158
+ mourinho.mp4 ca_cropped.mp4
159
+ mourinho.mp4 lecun.mp4
160
+ mourinho.mp4 spanish_lec0.mp4
evaluation/test_filelists/ReSyncED/tts_pairs.txt ADDED
@@ -0,0 +1,18 @@
+ adam_1.mp4 andreng_optimization.wav
+ agad_2.mp4 agad_2.wav
+ agad_1.mp4 agad_1.wav
+ agad_3.mp4 agad_3.wav
+ rms_prop_1.mp4 rms_prop_tts.wav
+ tf_1.mp4 tf_1.wav
+ tf_2.mp4 tf_2.wav
+ andrew_ng_ai_business.mp4 andrewng_business_tts.wav
+ covid_autopsy_1.mp4 autopsy_tts.wav
+ news_1.mp4 news_tts.wav
+ andrew_ng_fund_1.mp4 andrewng_ai_fund.wav
+ covid_treatments_1.mp4 covid_tts.wav
+ pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav
+ pytorch_1.mp4 pytorch.wav
+ pkb_1.mp4 pkb_1.wav
+ ss_1.mp4 ss_1.wav
+ carlsen_1.mp4 carlsen_eng.wav
+ french.mp4 french.wav
evaluation/test_filelists/lrs2.txt ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/test_filelists/lrs3.txt ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/test_filelists/lrw.txt ADDED
The diff for this file is too large to render. See raw diff
 
face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # -*- coding: utf-8 -*-
+
+ __author__ = """Adrian Bulat"""
+ __email__ = 'adrian.bulat@nottingham.ac.uk'
+ __version__ = '1.0.1'
+
+ from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py ADDED
@@ -0,0 +1,79 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+ ROOT = os.path.dirname(os.path.abspath(__file__))
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ # Get the face detector
60
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
61
+ globals(), locals(), [face_detector], 0)
62
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
63
+
64
+ def get_detections_for_batch(self, images):
65
+ images = images[..., ::-1]
66
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
67
+ results = []
68
+
69
+ for i, d in enumerate(detected_faces):
70
+ if len(d) == 0:
71
+ results.append(None)
72
+ continue
73
+ d = d[0]
74
+ d = np.clip(d, 0, None)
75
+
76
+ x1, y1, x2, y2 = map(int, d[:-1])
77
+ results.append((x1, y1, x2, y2))
78
+
79
+ return results
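
For reference, ``get_detections_for_batch`` above takes a single numpy array of stacked BGR frames and returns one ``(x1, y1, x2, y2)`` box, or ``None``, per frame. A minimal usage sketch under those assumptions (file names are illustrative):

import cv2
import numpy as np
import face_detection

# The constructor needs a landmarks type even though only detection is used here.
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, device='cpu')

frames = [cv2.imread(p) for p in ['frame_000.jpg', 'frame_001.jpg']]  # same-sized BGR frames
batch = np.asarray(frames)  # shape (N, H, W, 3)

boxes = detector.get_detections_for_batch(batch)  # one (x1, y1, x2, y2) or None per frame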
face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
+ from .core import FaceDetector
face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at list one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
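
``detect_from_directory`` above is a thin convenience wrapper around ``detect_from_image``; a sketch of how a concrete subclass would typically be driven against a folder of images (the directory name is illustrative):

from face_detection.detection.sfd import FaceDetector as SFDDetector

detector = SFDDetector(device='cpu')

# Maps every matched image path to its list of [x1, y1, x2, y2, score] detections.
predictions = detector.detect_from_directory('data', extensions=['.jpg', '.png'])
for image_path, faces in predictions.items():
    print(image_path, len(faces))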
face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
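
``nms`` above takes an ``[N, 5]`` array of ``[x1, y1, x2, y2, score]`` rows and returns the indices of the boxes that survive suppression; a small self-contained check with made-up numbers:

import numpy as np
from face_detection.detection.sfd.bbox import nms

dets = np.array([[10., 10., 60., 60., 0.95],       # highest score, kept
                 [12., 12., 62., 62., 0.90],       # overlaps the first heavily, suppressed
                 [200., 200., 260., 260., 0.80]])  # disjoint, kept

keep = nms(dets, thresh=0.3)  # -> [0, 2]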
face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
face_detection/detection/sfd/s3fd.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7636d0c9d2a8f4759aef537cbcc25c5fa2eb2d5d80b1fada4dcc800e967cf381
+ size 133
face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
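
Putting the pieces together, the SFD detector can also be used on its own; a minimal single-image sketch (the weights are resolved from the bundled ``s3fd.pth`` or downloaded, and the image path is illustrative):

import cv2
from face_detection.detection.sfd.sfd_detector import SFDDetector

detector = SFDDetector(device='cpu')

image = cv2.imread('person.jpg')[..., ::-1]  # BGR -> RGB numpy array

# Each entry is [x1, y1, x2, y2, confidence]; only confidences above 0.5 are returned.
faces = detector.detect_from_image(image)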
face_detection/models.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
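
As a quick sanity check on the layer arithmetic above, FAN downsamples a 256x256 crop by a factor of 4 before the hourglass stack, so each stacked module emits a 68-channel heatmap tensor at 64x64 (the shapes below follow from the strides and poolings above, not from any pretrained checkpoint):

import torch
from face_detection.models import FAN

net = FAN(num_modules=1)
x = torch.randn(1, 3, 256, 256)        # one 256x256 RGB crop
outputs = net(x)                       # one heatmap tensor per hourglass module
print(len(outputs), outputs[0].shape)  # 1 torch.Size([1, 68, 64, 64])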
face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate an affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a target resolution, the
60
+ function generates an affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define whether the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale are provided, the function will also return the points in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales are provided, the function will also return the points in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
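
``get_preds_fromhm`` above is the inverse step: it turns such heatmap stacks back into landmark coordinates. A sketch with random heatmaps (so the values are meaningless, but the shapes hold):

import torch
from face_detection.utils import get_preds_fromhm

hm = torch.rand(1, 68, 64, 64)            # [B, N, H, W], e.g. the FAN output above
preds, preds_orig = get_preds_fromhm(hm)  # coordinates in heatmap space
print(preds.shape)                        # torch.Size([1, 68, 2])
# Passing center and scale in addition maps the points back to the original image frame.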
hparams.py ADDED
@@ -0,0 +1,101 @@
1
+ from glob import glob
2
+ import os
3
+
4
+ def get_image_list(data_root, split):
5
+ filelist = []
6
+
7
+ with open('filelists/{}.txt'.format(split)) as f:
8
+ for line in f:
9
+ line = line.strip()
10
+ if ' ' in line: line = line.split()[0]
11
+ filelist.append(os.path.join(data_root, line))
12
+
13
+ return filelist
14
+
15
+ class HParams:
16
+ def __init__(self, **kwargs):
17
+ self.data = {}
18
+
19
+ for key, value in kwargs.items():
20
+ self.data[key] = value
21
+
22
+ def __getattr__(self, key):
23
+ if key not in self.data:
24
+ raise AttributeError("'HParams' object has no attribute %s" % key)
25
+ return self.data[key]
26
+
27
+ def set_hparam(self, key, value):
28
+ self.data[key] = value
29
+
30
+
31
+ # Default hyperparameters
32
+ hparams = HParams(
33
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
34
+ # network
35
+ rescale=True, # Whether to rescale audio prior to preprocessing
36
+ rescaling_max=0.9, # Rescaling value
37
+
38
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
39
+ # It's preferred to set this to True to use with https://github.com/r9y9/wavenet_vocoder
40
+ # Does not work if n_fft is not a multiple of hop_size!
41
+ use_lws=False,
42
+
43
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
44
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
45
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
46
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
47
+
48
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
49
+
50
+ # Mel and Linear spectrograms normalization/scaling and clipping
51
+ signal_normalization=True,
52
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
53
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
54
+ symmetric_mels=True,
55
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
56
+ # faster and cleaner convergence)
57
+ max_abs_value=4.,
58
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
59
+ # be too big to avoid gradient explosion,
60
+ # not too small for fast convergence)
61
+ # Contribution by @begeekmyfriend
62
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
63
+ # levels. Also allows for better G&L phase reconstruction)
64
+ preemphasize=True, # whether to apply filter
65
+ preemphasis=0.97, # filter coefficient.
66
+
67
+ # Limits
68
+ min_level_db=-100,
69
+ ref_level_db=20,
70
+ fmin=55,
71
+ # Set this to 55 if your speaker is male! If female, 95 should help remove noise. (To
72
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
73
+ fmax=7600, # To be increased/reduced depending on data.
74
+
75
+ ###################### Our training parameters #################################
76
+ img_size=96,
77
+ fps=25,
78
+
79
+ batch_size=16,
80
+ initial_learning_rate=1e-4,
81
+ nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
82
+ num_workers=16,
83
+ checkpoint_interval=3000,
84
+ eval_interval=3000,
85
+ save_optimizer_state=True,
86
+
87
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
88
+ syncnet_batch_size=64,
89
+ syncnet_lr=1e-4,
90
+ syncnet_eval_interval=10000,
91
+ syncnet_checkpoint_interval=10000,
92
+
93
+ disc_wt=0.07,
94
+ disc_initial_learning_rate=1e-4,
95
+ )
96
+
97
+
98
+ def hparams_debug_string():
99
+ values = hparams.values()
100
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
101
+ return "Hyperparameters:\n" + "\n".join(hp)
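
The audio/video bookkeeping used by the training scripts follows directly from these values: the mel hop is hop_size / sample_rate = 200 / 16000 = 12.5 ms, i.e. 80 mel frames per second, so at fps=25 every video frame spans 3.2 mel frames, and a 16-frame mel window (the syncnet window used below) covers 0.2 s, i.e. 5 video frames. A small sketch of that arithmetic:

from hparams import hparams

mel_frames_per_second = hparams.sample_rate / hparams.hop_size    # 16000 / 200 = 80.0
mel_frames_per_video_frame = mel_frames_per_second / hparams.fps  # 80 / 25 = 3.2

def mel_start_index(video_frame_num):
    # Mirrors crop_audio_window in the training scripts, which hardcode 80 mel frames per second.
    return int(mel_frames_per_second * (video_frame_num / float(hparams.fps)))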
hq_wav2lip_train.py ADDED
@@ -0,0 +1,443 @@
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ from models import Wav2Lip, Wav2Lip_disc_qual
6
+ import audio
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch import optim
12
+ import torch.backends.cudnn as cudnn
13
+ from torch.utils import data as data_utils
14
+ import numpy as np
15
+
16
+ from glob import glob
17
+
18
+ import os, random, cv2, argparse
19
+ from hparams import hparams, get_image_list
20
+
21
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
22
+
23
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
24
+
25
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
26
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
27
+
28
+ parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
29
+ parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
30
+
31
+ args = parser.parse_args()
32
+
33
+
34
+ global_step = 0
35
+ global_epoch = 0
36
+ use_cuda = torch.cuda.is_available()
37
+ print('use_cuda: {}'.format(use_cuda))
38
+
39
+ syncnet_T = 5
40
+ syncnet_mel_step_size = 16
41
+
42
+ class Dataset(object):
43
+ def __init__(self, split):
44
+ self.all_videos = get_image_list(args.data_root, split)
45
+
46
+ def get_frame_id(self, frame):
47
+ return int(basename(frame).split('.')[0])
48
+
49
+ def get_window(self, start_frame):
50
+ start_id = self.get_frame_id(start_frame)
51
+ vidname = dirname(start_frame)
52
+
53
+ window_fnames = []
54
+ for frame_id in range(start_id, start_id + syncnet_T):
55
+ frame = join(vidname, '{}.jpg'.format(frame_id))
56
+ if not isfile(frame):
57
+ return None
58
+ window_fnames.append(frame)
59
+ return window_fnames
60
+
61
+ def read_window(self, window_fnames):
62
+ if window_fnames is None: return None
63
+ window = []
64
+ for fname in window_fnames:
65
+ img = cv2.imread(fname)
66
+ if img is None:
67
+ return None
68
+ try:
69
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
70
+ except Exception as e:
71
+ return None
72
+
73
+ window.append(img)
74
+
75
+ return window
76
+
77
+ def crop_audio_window(self, spec, start_frame):
78
+ if type(start_frame) == int:
79
+ start_frame_num = start_frame
80
+ else:
81
+ start_frame_num = self.get_frame_id(start_frame)
82
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
83
+
84
+ end_idx = start_idx + syncnet_mel_step_size
85
+
86
+ return spec[start_idx : end_idx, :]
87
+
88
+ def get_segmented_mels(self, spec, start_frame):
89
+ mels = []
90
+ assert syncnet_T == 5
91
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
92
+ if start_frame_num - 2 < 0: return None
93
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
94
+ m = self.crop_audio_window(spec, i - 2)
95
+ if m.shape[0] != syncnet_mel_step_size:
96
+ return None
97
+ mels.append(m.T)
98
+
99
+ mels = np.asarray(mels)
100
+
101
+ return mels
102
+
103
+ def prepare_window(self, window):
104
+ # 3 x T x H x W
105
+ x = np.asarray(window) / 255.
106
+ x = np.transpose(x, (3, 0, 1, 2))
107
+
108
+ return x
109
+
110
+ def __len__(self):
111
+ return len(self.all_videos)
112
+
113
+ def __getitem__(self, idx):
114
+ while 1:
115
+ idx = random.randint(0, len(self.all_videos) - 1)
116
+ vidname = self.all_videos[idx]
117
+ img_names = list(glob(join(vidname, '*.jpg')))
118
+ if len(img_names) <= 3 * syncnet_T:
119
+ continue
120
+
121
+ img_name = random.choice(img_names)
122
+ wrong_img_name = random.choice(img_names)
123
+ while wrong_img_name == img_name:
124
+ wrong_img_name = random.choice(img_names)
125
+
126
+ window_fnames = self.get_window(img_name)
127
+ wrong_window_fnames = self.get_window(wrong_img_name)
128
+ if window_fnames is None or wrong_window_fnames is None:
129
+ continue
130
+
131
+ window = self.read_window(window_fnames)
132
+ if window is None:
133
+ continue
134
+
135
+ wrong_window = self.read_window(wrong_window_fnames)
136
+ if wrong_window is None:
137
+ continue
138
+
139
+ try:
140
+ wavpath = join(vidname, "audio.wav")
141
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
142
+
143
+ orig_mel = audio.melspectrogram(wav).T
144
+ except Exception as e:
145
+ continue
146
+
147
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
148
+
149
+ if (mel.shape[0] != syncnet_mel_step_size):
150
+ continue
151
+
152
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
153
+ if indiv_mels is None: continue
154
+
155
+ window = self.prepare_window(window)
156
+ y = window.copy()
157
+ window[:, :, window.shape[2]//2:] = 0.
158
+
159
+ wrong_window = self.prepare_window(wrong_window)
160
+ x = np.concatenate([window, wrong_window], axis=0)
161
+
162
+ x = torch.FloatTensor(x)
163
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
164
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
165
+ y = torch.FloatTensor(y)
166
+ return x, indiv_mels, mel, y
167
+
168
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
169
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
170
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
171
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
172
+
173
+ refs, inps = x[..., 3:], x[..., :3]
174
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
175
+ if not os.path.exists(folder): os.mkdir(folder)
176
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
177
+ for batch_idx, c in enumerate(collage):
178
+ for t in range(len(c)):
179
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
180
+
181
+ logloss = nn.BCELoss()
182
+ def cosine_loss(a, v, y):
183
+ d = nn.functional.cosine_similarity(a, v)
184
+ loss = logloss(d.unsqueeze(1), y)
185
+
186
+ return loss
187
+
188
+ device = torch.device("cuda" if use_cuda else "cpu")
189
+ syncnet = SyncNet().to(device)
190
+ for p in syncnet.parameters():
191
+ p.requires_grad = False
192
+
193
+ recon_loss = nn.L1Loss()
194
+ def get_sync_loss(mel, g):
195
+ g = g[:, :, :, g.size(3)//2:]
196
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
197
+ # B, 3 * T, H//2, W
198
+ a, v = syncnet(mel, g)
199
+ y = torch.ones(g.size(0), 1).float().to(device)
200
+ return cosine_loss(a, v, y)
201
+
202
+ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
203
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
204
+ global global_step, global_epoch
205
+ resumed_step = global_step
206
+
207
+ while global_epoch < nepochs:
208
+ print('Starting Epoch: {}'.format(global_epoch))
209
+ running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
210
+ running_disc_real_loss, running_disc_fake_loss = 0., 0.
211
+ prog_bar = tqdm(enumerate(train_data_loader))
212
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
213
+ disc.train()
214
+ model.train()
215
+
216
+ x = x.to(device)
217
+ mel = mel.to(device)
218
+ indiv_mels = indiv_mels.to(device)
219
+ gt = gt.to(device)
220
+
221
+ ### Train generator now. Remove ALL grads.
222
+ optimizer.zero_grad()
223
+ disc_optimizer.zero_grad()
224
+
225
+ g = model(indiv_mels, x)
226
+
227
+ if hparams.syncnet_wt > 0.:
228
+ sync_loss = get_sync_loss(mel, g)
229
+ else:
230
+ sync_loss = 0.
231
+
232
+ if hparams.disc_wt > 0.:
233
+ perceptual_loss = disc.perceptual_forward(g)
234
+ else:
235
+ perceptual_loss = 0.
236
+
237
+ l1loss = recon_loss(g, gt)
238
+
239
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
240
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
241
+
242
+ loss.backward()
243
+ optimizer.step()
244
+
245
+ ### Remove all gradients before Training disc
246
+ disc_optimizer.zero_grad()
247
+
248
+ pred = disc(gt)
249
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
250
+ disc_real_loss.backward()
251
+
252
+ pred = disc(g.detach())
253
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
254
+ disc_fake_loss.backward()
255
+
256
+ disc_optimizer.step()
257
+
258
+ running_disc_real_loss += disc_real_loss.item()
259
+ running_disc_fake_loss += disc_fake_loss.item()
260
+
261
+ if global_step % checkpoint_interval == 0:
262
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
263
+
264
+ # Logs
265
+ global_step += 1
266
+ cur_session_steps = global_step - resumed_step
267
+
268
+ running_l1_loss += l1loss.item()
269
+ if hparams.syncnet_wt > 0.:
270
+ running_sync_loss += sync_loss.item()
271
+ else:
272
+ running_sync_loss += 0.
273
+
274
+ if hparams.disc_wt > 0.:
275
+ running_perceptual_loss += perceptual_loss.item()
276
+ else:
277
+ running_perceptual_loss += 0.
278
+
279
+ if global_step == 1 or global_step % checkpoint_interval == 0:
280
+ save_checkpoint(
281
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
282
+ save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
283
+
284
+
285
+ if global_step % hparams.eval_interval == 0:
286
+ with torch.no_grad():
287
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
288
+
289
+ if average_sync_loss < .75:
290
+ hparams.set_hparam('syncnet_wt', 0.03)
291
+
292
+ prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
293
+ running_sync_loss / (step + 1),
294
+ running_perceptual_loss / (step + 1),
295
+ running_disc_fake_loss / (step + 1),
296
+ running_disc_real_loss / (step + 1)))
297
+
298
+ global_epoch += 1
299
+
300
+ def eval_model(test_data_loader, global_step, device, model, disc):
301
+ eval_steps = 300
302
+ print('Evaluating for {} steps'.format(eval_steps))
303
+ running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
304
+ while 1:
305
+ for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
306
+ model.eval()
307
+ disc.eval()
308
+
309
+ x = x.to(device)
310
+ mel = mel.to(device)
311
+ indiv_mels = indiv_mels.to(device)
312
+ gt = gt.to(device)
313
+
314
+ pred = disc(gt)
315
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
316
+
317
+ g = model(indiv_mels, x)
318
+ pred = disc(g)
319
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
320
+
321
+ running_disc_real_loss.append(disc_real_loss.item())
322
+ running_disc_fake_loss.append(disc_fake_loss.item())
323
+
324
+ sync_loss = get_sync_loss(mel, g)
325
+
326
+ if hparams.disc_wt > 0.:
327
+ perceptual_loss = disc.perceptual_forward(g)
328
+ else:
329
+ perceptual_loss = 0.
330
+
331
+ l1loss = recon_loss(g, gt)
332
+
333
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
334
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
335
+
336
+ running_l1_loss.append(l1loss.item())
337
+ running_sync_loss.append(sync_loss.item())
338
+
339
+ if hparams.disc_wt > 0.:
340
+ running_perceptual_loss.append(perceptual_loss.item())
341
+ else:
342
+ running_perceptual_loss.append(0.)
343
+
344
+ if step > eval_steps: break
345
+
346
+ print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
347
+ sum(running_sync_loss) / len(running_sync_loss),
348
+ sum(running_perceptual_loss) / len(running_perceptual_loss),
349
+ sum(running_disc_fake_loss) / len(running_disc_fake_loss),
350
+ sum(running_disc_real_loss) / len(running_disc_real_loss)))
351
+ return sum(running_sync_loss) / len(running_sync_loss)
352
+
353
+
354
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
355
+ checkpoint_path = join(
356
+ checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
357
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
358
+ torch.save({
359
+ "state_dict": model.state_dict(),
360
+ "optimizer": optimizer_state,
361
+ "global_step": step,
362
+ "global_epoch": epoch,
363
+ }, checkpoint_path)
364
+ print("Saved checkpoint:", checkpoint_path)
365
+
366
+ def _load(checkpoint_path):
367
+ if use_cuda:
368
+ checkpoint = torch.load(checkpoint_path)
369
+ else:
370
+ checkpoint = torch.load(checkpoint_path,
371
+ map_location=lambda storage, loc: storage)
372
+ return checkpoint
373
+
374
+
375
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
376
+ global global_step
377
+ global global_epoch
378
+
379
+ print("Load checkpoint from: {}".format(path))
380
+ checkpoint = _load(path)
381
+ s = checkpoint["state_dict"]
382
+ new_s = {}
383
+ for k, v in s.items():
384
+ new_s[k.replace('module.', '')] = v
385
+ model.load_state_dict(new_s)
386
+ if not reset_optimizer:
387
+ optimizer_state = checkpoint["optimizer"]
388
+ if optimizer_state is not None:
389
+ print("Load optimizer state from {}".format(path))
390
+ optimizer.load_state_dict(checkpoint["optimizer"])
391
+ if overwrite_global_states:
392
+ global_step = checkpoint["global_step"]
393
+ global_epoch = checkpoint["global_epoch"]
394
+
395
+ return model
396
+
397
+ if __name__ == "__main__":
398
+ checkpoint_dir = args.checkpoint_dir
399
+
400
+ # Dataset and Dataloader setup
401
+ train_dataset = Dataset('train')
402
+ test_dataset = Dataset('val')
403
+
404
+ train_data_loader = data_utils.DataLoader(
405
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
406
+ num_workers=hparams.num_workers)
407
+
408
+ test_data_loader = data_utils.DataLoader(
409
+ test_dataset, batch_size=hparams.batch_size,
410
+ num_workers=4)
411
+
412
+ device = torch.device("cuda" if use_cuda else "cpu")
413
+
414
+ # Model
415
+ model = Wav2Lip().to(device)
416
+ disc = Wav2Lip_disc_qual().to(device)
417
+
418
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
419
+ print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
420
+
421
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
422
+ lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
423
+ disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
424
+ lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
425
+
426
+ if args.checkpoint_path is not None:
427
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
428
+
429
+ if args.disc_checkpoint_path is not None:
430
+ load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
431
+ reset_optimizer=False, overwrite_global_states=False)
432
+
433
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
434
+ overwrite_global_states=False)
435
+
436
+ if not os.path.exists(checkpoint_dir):
437
+ os.mkdir(checkpoint_dir)
438
+
439
+ # Train!
440
+ train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
441
+ checkpoint_dir=checkpoint_dir,
442
+ checkpoint_interval=hparams.checkpoint_interval,
443
+ nepochs=hparams.nepochs)
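A typical way to launch this GAN-based training, assuming the script exposes the same --data_root, --checkpoint_dir and --syncnet_checkpoint_path arguments as wav2lip_train.py further below (all paths are placeholders):

    python hq_wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir checkpoints/ --syncnet_checkpoint_path checkpoints/lipsync_expert.pth

The generator objective above mixes L1 reconstruction, the frozen expert sync loss and the discriminator's perceptual loss using hparams.syncnet_wt and hparams.disc_wt; whatever weight remains goes to the L1 term.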
inference.py ADDED
@@ -0,0 +1,280 @@
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse, audio
4
+ import json, subprocess, random, string
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch, face_detection
8
+ from models import Wav2Lip
9
+ import platform
10
+
11
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
+
13
+ parser.add_argument('--checkpoint_path', type=str,
14
+ help='Name of saved checkpoint to load weights from', required=True)
15
+
16
+ parser.add_argument('--face', type=str,
17
+ help='Filepath of video/image that contains faces to use', required=True)
18
+ parser.add_argument('--audio', type=str,
19
+ help='Filepath of video/audio file to use as raw audio source', required=True)
20
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an example.',
21
+ default='results/result_voice.mp4')
22
+
23
+ parser.add_argument('--static', action='store_true',
24
+ help='If set, use only the first video frame for inference', default=False)
25
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
26
+ default=25., required=False)
27
+
28
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
29
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
30
+
31
+ parser.add_argument('--face_det_batch_size', type=int,
32
+ help='Batch size for face detection', default=16)
33
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
34
+
35
+ parser.add_argument('--resize_factor', default=1, type=int,
36
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
37
+
38
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
39
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
40
+ 'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')
41
+
42
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
43
+ help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
44
+ 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
45
+
46
+ parser.add_argument('--rotate', default=False, action='store_true',
47
+ help='Sometimes videos taken from a phone can be rotated by 90 degrees. If set, the video will be rotated 90 degrees clockwise. '
48
+ 'Use this if you get a flipped result despite feeding a normal-looking video.')
49
+
50
+ parser.add_argument('--nosmooth', default=False, action='store_true',
51
+ help='Prevent smoothing face detections over a short temporal window')
52
+
53
+ args = parser.parse_args()
54
+ args.img_size = 96
55
+
56
+ if os.path.isfile(args.face) and args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
57
+ args.static = True
58
+
59
+ def get_smoothened_boxes(boxes, T):
60
+ for i in range(len(boxes)):
61
+ if i + T > len(boxes):
62
+ window = boxes[len(boxes) - T:]
63
+ else:
64
+ window = boxes[i : i + T]
65
+ boxes[i] = np.mean(window, axis=0)
66
+ return boxes
67
+
68
+ def face_detect(images):
69
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
70
+ flip_input=False, device=device)
71
+
72
+ batch_size = args.face_det_batch_size
73
+
74
+ while 1:
75
+ predictions = []
76
+ try:
77
+ for i in tqdm(range(0, len(images), batch_size)):
78
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
79
+ except RuntimeError:
80
+ if batch_size == 1:
81
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
82
+ batch_size //= 2
83
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
84
+ continue
85
+ break
86
+
87
+ results = []
88
+ pady1, pady2, padx1, padx2 = args.pads
89
+ for rect, image in zip(predictions, images):
90
+ if rect is None:
91
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
92
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
93
+
94
+ y1 = max(0, rect[1] - pady1)
95
+ y2 = min(image.shape[0], rect[3] + pady2)
96
+ x1 = max(0, rect[0] - padx1)
97
+ x2 = min(image.shape[1], rect[2] + padx2)
98
+
99
+ results.append([x1, y1, x2, y2])
100
+
101
+ boxes = np.array(results)
102
+ if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
103
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
104
+
105
+ del detector
106
+ return results
107
+
108
+ def datagen(frames, mels):
109
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
110
+
111
+ if args.box[0] == -1:
112
+ if not args.static:
113
+ face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
114
+ else:
115
+ face_det_results = face_detect([frames[0]])
116
+ else:
117
+ print('Using the specified bounding box instead of face detection...')
118
+ y1, y2, x1, x2 = args.box
119
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
120
+
121
+ for i, m in enumerate(mels):
122
+ idx = 0 if args.static else i%len(frames)
123
+ frame_to_save = frames[idx].copy()
124
+ face, coords = face_det_results[idx].copy()
125
+
126
+ face = cv2.resize(face, (args.img_size, args.img_size))
127
+
128
+ img_batch.append(face)
129
+ mel_batch.append(m)
130
+ frame_batch.append(frame_to_save)
131
+ coords_batch.append(coords)
132
+
133
+ if len(img_batch) >= args.wav2lip_batch_size:
134
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
135
+
136
+ img_masked = img_batch.copy()
137
+ img_masked[:, args.img_size//2:] = 0
138
+
139
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
140
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
141
+
142
+ yield img_batch, mel_batch, frame_batch, coords_batch
143
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
144
+
145
+ if len(img_batch) > 0:
146
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
147
+
148
+ img_masked = img_batch.copy()
149
+ img_masked[:, args.img_size//2:] = 0
150
+
151
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
152
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
153
+
154
+ yield img_batch, mel_batch, frame_batch, coords_batch
155
+
156
+ mel_step_size = 16
157
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
+ print('Using {} for inference.'.format(device))
159
+
160
+ def _load(checkpoint_path):
161
+ if device == 'cuda':
162
+ checkpoint = torch.load(checkpoint_path)
163
+ else:
164
+ checkpoint = torch.load(checkpoint_path,
165
+ map_location=lambda storage, loc: storage)
166
+ return checkpoint
167
+
168
+ def load_model(path):
169
+ model = Wav2Lip()
170
+ print("Load checkpoint from: {}".format(path))
171
+ checkpoint = _load(path)
172
+ s = checkpoint["state_dict"]
173
+ new_s = {}
174
+ for k, v in s.items():
175
+ new_s[k.replace('module.', '')] = v
176
+ model.load_state_dict(new_s)
177
+
178
+ model = model.to(device)
179
+ return model.eval()
180
+
181
+ def main():
182
+ if not os.path.isfile(args.face):
183
+ raise ValueError('--face argument must be a valid path to video/image file')
184
+
185
+ elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
186
+ full_frames = [cv2.imread(args.face)]
187
+ fps = args.fps
188
+
189
+ else:
190
+ video_stream = cv2.VideoCapture(args.face)
191
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
192
+
193
+ print('Reading video frames...')
194
+
195
+ full_frames = []
196
+ while 1:
197
+ still_reading, frame = video_stream.read()
198
+ if not still_reading:
199
+ video_stream.release()
200
+ break
201
+ if args.resize_factor > 1:
202
+ frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
203
+
204
+ if args.rotate:
205
+ frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
206
+
207
+ y1, y2, x1, x2 = args.crop
208
+ if x2 == -1: x2 = frame.shape[1]
209
+ if y2 == -1: y2 = frame.shape[0]
210
+
211
+ frame = frame[y1:y2, x1:x2]
212
+
213
+ full_frames.append(frame)
214
+
215
+ print ("Number of frames available for inference: "+str(len(full_frames)))
216
+
217
+ if not args.audio.endswith('.wav'):
218
+ print('Extracting raw audio...')
219
+ command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
220
+
221
+ subprocess.call(command, shell=True)
222
+ args.audio = 'temp/temp.wav'
223
+
224
+ wav = audio.load_wav(args.audio, 16000)
225
+ mel = audio.melspectrogram(wav)
226
+ print(mel.shape)
227
+
228
+ if np.isnan(mel.reshape(-1)).sum() > 0:
229
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
230
+
231
+ mel_chunks = []
232
+ mel_idx_multiplier = 80./fps
233
+ i = 0
234
+ while 1:
235
+ start_idx = int(i * mel_idx_multiplier)
236
+ if start_idx + mel_step_size > len(mel[0]):
237
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
238
+ break
239
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
240
+ i += 1
241
+
242
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
243
+
244
+ full_frames = full_frames[:len(mel_chunks)]
245
+
246
+ batch_size = args.wav2lip_batch_size
247
+ gen = datagen(full_frames.copy(), mel_chunks)
248
+
249
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
250
+ total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
251
+ if i == 0:
252
+ model = load_model(args.checkpoint_path)
253
+ print ("Model loaded")
254
+
255
+ frame_h, frame_w = full_frames[0].shape[:-1]
256
+ out = cv2.VideoWriter('temp/result.avi',
257
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
258
+
259
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
260
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
261
+
262
+ with torch.no_grad():
263
+ pred = model(mel_batch, img_batch)
264
+
265
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
266
+
267
+ for p, f, c in zip(pred, frames, coords):
268
+ y1, y2, x1, x2 = c
269
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
270
+
271
+ f[y1:y2, x1:x2] = p
272
+ out.write(f)
273
+
274
+ out.release()
275
+
276
+ command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
277
+ subprocess.call(command, shell=platform.system() != 'Windows')
278
+
279
+ if __name__ == '__main__':
280
+ main()
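For reference, a minimal inference invocation, assuming one of the uploaded checkpoints and placeholder input files; note that the script writes intermediate files to temp/ and the final video to the --outfile path, so temp/ and the output directory must exist beforehand:

    python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face input_video.mp4 --audio input_audio.wav --outfile results/result_voice.mp4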
models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
2
+ from .syncnet import SyncNet_color
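These re-exports are what the training and inference scripts rely on; for example:

    from models import Wav2Lip, Wav2Lip_disc_qual, SyncNet_color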
models/conv.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class nonorm_Conv2d(nn.Module):
22
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self.conv_block = nn.Sequential(
25
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
26
+ )
27
+ self.act = nn.LeakyReLU(0.01, inplace=True)
28
+
29
+ def forward(self, x):
30
+ out = self.conv_block(x)
31
+ return self.act(out)
32
+
33
+ class Conv2dTranspose(nn.Module):
34
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
35
+ super().__init__(*args, **kwargs)
36
+ self.conv_block = nn.Sequential(
37
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
38
+ nn.BatchNorm2d(cout)
39
+ )
40
+ self.act = nn.ReLU()
41
+
42
+ def forward(self, x):
43
+ out = self.conv_block(x)
44
+ return self.act(out)
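One detail worth noting about the Conv2d block above: residual=True adds the input tensor back onto the convolution output, so it is only valid when cin == cout and the spatial size is preserved (stride 1 with shape-preserving padding), which is exactly how the residual layers are used throughout this repo. A minimal shape check, assuming the repo layout for the import path:

    import torch
    from models.conv import Conv2d

    block = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
    x = torch.randn(2, 64, 24, 24)   # (batch, channels, height, width)
    y = block(x)                     # residual add keeps the shape: (2, 64, 24, 24)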
models/syncnet.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from .conv import Conv2d
6
+
7
+ class SyncNet_color(nn.Module):
8
+ def __init__(self):
9
+ super(SyncNet_color, self).__init__()
10
+
11
+ self.face_encoder = nn.Sequential(
12
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
13
+
14
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
15
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
16
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
17
+
18
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
19
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
20
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
22
+
23
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
24
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
25
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
26
+
27
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
28
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
29
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
30
+
31
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
32
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
33
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
34
+
35
+ self.audio_encoder = nn.Sequential(
36
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
37
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
38
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
39
+
40
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
41
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
42
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
43
+
44
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
45
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
46
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
47
+
48
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
49
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
50
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
51
+
52
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
53
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
54
+
55
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
56
+ face_embedding = self.face_encoder(face_sequences)
57
+ audio_embedding = self.audio_encoder(audio_sequences)
58
+
59
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
60
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
61
+
62
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
63
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
64
+
65
+
66
+ return audio_embedding, face_embedding
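For orientation, the two encoders are sized around the windows used elsewhere in the repo: the face branch expects 5 stacked RGB lower-half crops of 96x96 frames (15 channels, 48x96) and the audio branch a 16-step, 80-bin mel window; both collapse to L2-normalised 512-d embeddings. A minimal sketch under those assumptions:

    import torch
    from models.syncnet import SyncNet_color

    net = SyncNet_color()
    faces = torch.randn(4, 15, 48, 96)   # 5 frames x 3 channels, lower half of 96x96 crops
    mels = torch.randn(4, 1, 80, 16)     # (batch, 1, n_mels, mel_step_size)
    a, v = net(mels, faces)              # audio and face embeddings, each (4, 512)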
models/wav2lip.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ import math
5
+
6
+ from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
7
+
8
+ class Wav2Lip(nn.Module):
9
+ def __init__(self):
10
+ super(Wav2Lip, self).__init__()
11
+
12
+ self.face_encoder_blocks = nn.ModuleList([
13
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
14
+
15
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
16
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
17
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
18
+
19
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
20
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
22
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
23
+
24
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
25
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
26
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
27
+
28
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
29
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
30
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
31
+
32
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
33
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
34
+
35
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
36
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
37
+
38
+ self.audio_encoder = nn.Sequential(
39
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
40
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
41
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
42
+
43
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+
51
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
52
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
53
+
54
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
55
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
56
+
57
+ self.face_decoder_blocks = nn.ModuleList([
58
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
59
+
60
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
62
+
63
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
64
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
65
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
66
+
67
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
68
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
69
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
70
+
71
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
73
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
74
+
75
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
76
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
77
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
78
+
79
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
80
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
81
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
82
+
83
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
84
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
85
+ nn.Sigmoid())
86
+
87
+ def forward(self, audio_sequences, face_sequences):
88
+ # audio_sequences = (B, T, 1, 80, 16)
89
+ B = audio_sequences.size(0)
90
+
91
+ input_dim_size = len(face_sequences.size())
92
+ if input_dim_size > 4:
93
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
94
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
95
+
96
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
97
+
98
+ feats = []
99
+ x = face_sequences
100
+ for f in self.face_encoder_blocks:
101
+ x = f(x)
102
+ feats.append(x)
103
+
104
+ x = audio_embedding
105
+ for f in self.face_decoder_blocks:
106
+ x = f(x)
107
+ try:
108
+ x = torch.cat((x, feats[-1]), dim=1)
109
+ except Exception as e:
110
+ print(x.size())
111
+ print(feats[-1].size())
112
+ raise e
113
+
114
+ feats.pop()
115
+
116
+ x = self.output_block(x)
117
+
118
+ if input_dim_size > 4:
119
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
120
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
121
+
122
+ else:
123
+ outputs = x
124
+
125
+ return outputs
126
+
127
+ class Wav2Lip_disc_qual(nn.Module):
128
+ def __init__(self):
129
+ super(Wav2Lip_disc_qual, self).__init__()
130
+
131
+ self.face_encoder_blocks = nn.ModuleList([
132
+ nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
133
+
134
+ nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
135
+ nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
136
+
137
+ nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
138
+ nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
139
+
140
+ nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
141
+ nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
142
+
143
+ nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
144
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
145
+
146
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
147
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
148
+
149
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
150
+ nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
151
+
152
+ self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
153
+ self.label_noise = .0
154
+
155
+ def get_lower_half(self, face_sequences):
156
+ return face_sequences[:, :, face_sequences.size(2)//2:]
157
+
158
+ def to_2d(self, face_sequences):
159
+ B = face_sequences.size(0)
160
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
161
+ return face_sequences
162
+
163
+ def perceptual_forward(self, false_face_sequences):
164
+ false_face_sequences = self.to_2d(false_face_sequences)
165
+ false_face_sequences = self.get_lower_half(false_face_sequences)
166
+
167
+ false_feats = false_face_sequences
168
+ for f in self.face_encoder_blocks:
169
+ false_feats = f(false_feats)
170
+
171
+ false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
172
+ torch.ones((len(false_feats), 1), device=false_feats.device))
173
+
174
+ return false_pred_loss
175
+
176
+ def forward(self, face_sequences):
177
+ face_sequences = self.to_2d(face_sequences)
178
+ face_sequences = self.get_lower_half(face_sequences)
179
+
180
+ x = face_sequences
181
+ for f in self.face_encoder_blocks:
182
+ x = f(x)
183
+
184
+ return self.binary_pred(x).view(len(x), -1)
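A shape-level sketch of how the generator above is driven during training, assuming the 5-frame window and 96x96 crops used by the data pipeline: audio arrives as (B, T, 1, 80, 16), faces as (B, 6, T, 96, 96) (the masked target window concatenated with a reference window along the channel axis), and the output is (B, 3, T, 96, 96) with sigmoid activations:

    import torch
    from models import Wav2Lip

    model = Wav2Lip()
    audio = torch.randn(2, 5, 1, 80, 16)   # (B, T, 1, n_mels, mel_step_size)
    faces = torch.randn(2, 6, 5, 96, 96)   # masked window + reference window on the channel axis
    out = model(audio, faces)              # (2, 3, 5, 96, 96)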
preprocess.py ADDED
@@ -0,0 +1,113 @@
1
+ import sys
2
+
3
+ if sys.version_info < (3, 2):
4
+ raise Exception("Must be using >= Python 3.2")
5
+
6
+ from os import listdir, path
7
+
8
+ if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
9
+ raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth '
10
+ 'before running this script!')
11
+
12
+ import multiprocessing as mp
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+ import numpy as np
15
+ import argparse, os, cv2, traceback, subprocess
16
+ from tqdm import tqdm
17
+ from glob import glob
18
+ import audio
19
+ from hparams import hparams as hp
20
+
21
+ import face_detection
22
+
23
+ parser = argparse.ArgumentParser()
24
+
25
+ parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
26
+ parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
27
+ parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
28
+ parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
29
+
30
+ args = parser.parse_args()
31
+
32
+ fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
33
+ device='cuda:{}'.format(id)) for id in range(args.ngpu)]
34
+
35
+ template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
36
+ # template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
37
+
38
+ def process_video_file(vfile, args, gpu_id):
39
+ video_stream = cv2.VideoCapture(vfile)
40
+
41
+ frames = []
42
+ while 1:
43
+ still_reading, frame = video_stream.read()
44
+ if not still_reading:
45
+ video_stream.release()
46
+ break
47
+ frames.append(frame)
48
+
49
+ vidname = os.path.basename(vfile).split('.')[0]
50
+ dirname = vfile.split('/')[-2]
51
+
52
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
53
+ os.makedirs(fulldir, exist_ok=True)
54
+
55
+ batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
56
+
57
+ i = -1
58
+ for fb in batches:
59
+ preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
60
+
61
+ for j, f in enumerate(preds):
62
+ i += 1
63
+ if f is None:
64
+ continue
65
+
66
+ x1, y1, x2, y2 = f
67
+ cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
68
+
69
+ def process_audio_file(vfile, args):
70
+ vidname = os.path.basename(vfile).split('.')[0]
71
+ dirname = vfile.split('/')[-2]
72
+
73
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
74
+ os.makedirs(fulldir, exist_ok=True)
75
+
76
+ wavpath = path.join(fulldir, 'audio.wav')
77
+
78
+ command = template.format(vfile, wavpath)
79
+ subprocess.call(command, shell=True)
80
+
81
+
82
+ def mp_handler(job):
83
+ vfile, args, gpu_id = job
84
+ try:
85
+ process_video_file(vfile, args, gpu_id)
86
+ except KeyboardInterrupt:
87
+ exit(0)
88
+ except:
89
+ traceback.print_exc()
90
+
91
+ def main(args):
92
+ print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
93
+
94
+ filelist = glob(path.join(args.data_root, '*/*.mp4'))
95
+
96
+ jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
97
+ p = ThreadPoolExecutor(args.ngpu)
98
+ futures = [p.submit(mp_handler, j) for j in jobs]
99
+ _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
100
+
101
+ print('Dumping audios...')
102
+
103
+ for vfile in tqdm(filelist):
104
+ try:
105
+ process_audio_file(vfile, args)
106
+ except KeyboardInterrupt:
107
+ exit(0)
108
+ except:
109
+ traceback.print_exc()
110
+ continue
111
+
112
+ if __name__ == '__main__':
113
+ main(args)
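A typical preprocessing run over the raw LRS2 videos, with placeholder paths:

    python preprocess.py --data_root lrs2/main --preprocessed_root lrs2_preprocessed/ --ngpu 1 --batch_size 32

Each video ends up as a folder of cropped face frames named 0.jpg, 1.jpg, ... together with an audio.wav extracted via ffmpeg.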
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ librosa
2
+ numpy
3
+ opencv-contrib-python
4
+ opencv-python
5
+ torch
6
+ torchvision
7
+ tqdm
8
+ numba
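Assuming a CUDA-capable environment, these dependencies are installed the usual way:

    pip install -r requirements.txt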
requirementsCPU.txt ADDED
@@ -0,0 +1,9 @@
1
+ librosa
2
+ numpy
3
+ opencv-contrib-python
4
+ opencv-python
5
+ -f https://download.pytorch.org/whl/torch_stable.html
6
+ torch
7
+ torchvision
8
+ tqdm
9
+ numba
wav2lip_train.py ADDED
@@ -0,0 +1,374 @@
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ from models import Wav2Lip as Wav2Lip
6
+ import audio
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch import optim
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.utils import data as data_utils
13
+ import numpy as np
14
+
15
+ from glob import glob
16
+
17
+ import os, random, cv2, argparse
18
+ from hparams import hparams, get_image_list
19
+
20
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
21
+
22
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
23
+
24
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
25
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
26
+
27
+ parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
28
+
29
+ args = parser.parse_args()
30
+
31
+
32
+ global_step = 0
33
+ global_epoch = 0
34
+ use_cuda = torch.cuda.is_available()
35
+ print('use_cuda: {}'.format(use_cuda))
36
+
37
+ syncnet_T = 5
38
+ syncnet_mel_step_size = 16
39
+
40
+ class Dataset(object):
41
+ def __init__(self, split):
42
+ self.all_videos = get_image_list(args.data_root, split)
43
+
44
+ def get_frame_id(self, frame):
45
+ return int(basename(frame).split('.')[0])
46
+
47
+ def get_window(self, start_frame):
48
+ start_id = self.get_frame_id(start_frame)
49
+ vidname = dirname(start_frame)
50
+
51
+ window_fnames = []
52
+ for frame_id in range(start_id, start_id + syncnet_T):
53
+ frame = join(vidname, '{}.jpg'.format(frame_id))
54
+ if not isfile(frame):
55
+ return None
56
+ window_fnames.append(frame)
57
+ return window_fnames
58
+
59
+ def read_window(self, window_fnames):
60
+ if window_fnames is None: return None
61
+ window = []
62
+ for fname in window_fnames:
63
+ img = cv2.imread(fname)
64
+ if img is None:
65
+ return None
66
+ try:
67
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
68
+ except Exception as e:
69
+ return None
70
+
71
+ window.append(img)
72
+
73
+ return window
74
+
75
+ def crop_audio_window(self, spec, start_frame):
76
+ if type(start_frame) == int:
77
+ start_frame_num = start_frame
78
+ else:
79
+ start_frame_num = self.get_frame_id(start_frame)
80
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
81
+
82
+ end_idx = start_idx + syncnet_mel_step_size
83
+
84
+ return spec[start_idx : end_idx, :]
85
+
86
+ def get_segmented_mels(self, spec, start_frame):
87
+ mels = []
88
+ assert syncnet_T == 5
89
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
90
+ if start_frame_num - 2 < 0: return None
91
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
92
+ m = self.crop_audio_window(spec, i - 2)
93
+ if m.shape[0] != syncnet_mel_step_size:
94
+ return None
95
+ mels.append(m.T)
96
+
97
+ mels = np.asarray(mels)
98
+
99
+ return mels
100
+
101
+ def prepare_window(self, window):
102
+ # 3 x T x H x W
103
+ x = np.asarray(window) / 255.
104
+ x = np.transpose(x, (3, 0, 1, 2))
105
+
106
+ return x
107
+
108
+ def __len__(self):
109
+ return len(self.all_videos)
110
+
111
+ def __getitem__(self, idx):
112
+ while 1:
113
+ idx = random.randint(0, len(self.all_videos) - 1)
114
+ vidname = self.all_videos[idx]
115
+ img_names = list(glob(join(vidname, '*.jpg')))
116
+ if len(img_names) <= 3 * syncnet_T:
117
+ continue
118
+
119
+ img_name = random.choice(img_names)
120
+ wrong_img_name = random.choice(img_names)
121
+ while wrong_img_name == img_name:
122
+ wrong_img_name = random.choice(img_names)
123
+
124
+ window_fnames = self.get_window(img_name)
125
+ wrong_window_fnames = self.get_window(wrong_img_name)
126
+ if window_fnames is None or wrong_window_fnames is None:
127
+ continue
128
+
129
+ window = self.read_window(window_fnames)
130
+ if window is None:
131
+ continue
132
+
133
+ wrong_window = self.read_window(wrong_window_fnames)
134
+ if wrong_window is None:
135
+ continue
136
+
137
+ try:
138
+ wavpath = join(vidname, "audio.wav")
139
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
140
+
141
+ orig_mel = audio.melspectrogram(wav).T
142
+ except Exception as e:
143
+ continue
144
+
145
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
146
+
147
+ if (mel.shape[0] != syncnet_mel_step_size):
148
+ continue
149
+
150
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
151
+ if indiv_mels is None: continue
152
+
153
+ window = self.prepare_window(window)
154
+ y = window.copy()
155
+ window[:, :, window.shape[2]//2:] = 0.
156
+
157
+ wrong_window = self.prepare_window(wrong_window)
158
+ x = np.concatenate([window, wrong_window], axis=0)
159
+
160
+ x = torch.FloatTensor(x)
161
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
162
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
163
+ y = torch.FloatTensor(y)
164
+ return x, indiv_mels, mel, y
165
+
166
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
167
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
168
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
169
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
170
+
171
+ refs, inps = x[..., 3:], x[..., :3]
172
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
173
+ if not os.path.exists(folder): os.mkdir(folder)
174
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
175
+ for batch_idx, c in enumerate(collage):
176
+ for t in range(len(c)):
177
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
178
+
179
+ logloss = nn.BCELoss()
180
+ def cosine_loss(a, v, y):
181
+ d = nn.functional.cosine_similarity(a, v)
182
+ loss = logloss(d.unsqueeze(1), y)
183
+
184
+ return loss
185
+
186
+ device = torch.device("cuda" if use_cuda else "cpu")
187
+ syncnet = SyncNet().to(device)
188
+ for p in syncnet.parameters():
189
+ p.requires_grad = False
190
+
191
+ recon_loss = nn.L1Loss()
192
+ def get_sync_loss(mel, g):
193
+ g = g[:, :, :, g.size(3)//2:]
194
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
195
+ # B, 3 * T, H//2, W
196
+ a, v = syncnet(mel, g)
197
+ y = torch.ones(g.size(0), 1).float().to(device)
198
+ return cosine_loss(a, v, y)
199
+
200
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
201
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
202
+
203
+ global global_step, global_epoch
204
+ resumed_step = global_step
205
+
206
+ while global_epoch < nepochs:
207
+ print('Starting Epoch: {}'.format(global_epoch))
208
+ running_sync_loss, running_l1_loss = 0., 0.
209
+ prog_bar = tqdm(enumerate(train_data_loader))
210
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
211
+ model.train()
212
+ optimizer.zero_grad()
213
+
214
+ # Move data to CUDA device
215
+ x = x.to(device)
216
+ mel = mel.to(device)
217
+ indiv_mels = indiv_mels.to(device)
218
+ gt = gt.to(device)
219
+
220
+ g = model(indiv_mels, x)
221
+
222
+ if hparams.syncnet_wt > 0.:
223
+ sync_loss = get_sync_loss(mel, g)
224
+ else:
225
+ sync_loss = 0.
226
+
227
+ l1loss = recon_loss(g, gt)
228
+
229
+ loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
230
+ loss.backward()
231
+ optimizer.step()
232
+
233
+ if global_step % checkpoint_interval == 0:
234
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
235
+
236
+ global_step += 1
237
+ cur_session_steps = global_step - resumed_step
238
+
239
+ running_l1_loss += l1loss.item()
240
+ if hparams.syncnet_wt > 0.:
241
+ running_sync_loss += sync_loss.item()
242
+ else:
243
+ running_sync_loss += 0.
244
+
245
+ if global_step == 1 or global_step % checkpoint_interval == 0:
246
+ save_checkpoint(
247
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
248
+
249
+ if global_step == 1 or global_step % hparams.eval_interval == 0:
250
+ with torch.no_grad():
251
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
252
+
253
+ if average_sync_loss < .75:
254
+ hparams.set_hparam('syncnet_wt', 0.01) # a smaller sync weight suffices without the visual quality discriminator
255
+
256
+ prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
257
+ running_sync_loss / (step + 1)))
258
+
259
+ global_epoch += 1
260
+
261
+
262
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
263
+ eval_steps = 700
264
+ print('Evaluating for {} steps'.format(eval_steps))
265
+ sync_losses, recon_losses = [], []
266
+ step = 0
267
+ while 1:
268
+ for x, indiv_mels, mel, gt in test_data_loader:
269
+ step += 1
270
+ model.eval()
271
+
272
+ # Move data to CUDA device
273
+ x = x.to(device)
274
+ gt = gt.to(device)
275
+ indiv_mels = indiv_mels.to(device)
276
+ mel = mel.to(device)
277
+
278
+ g = model(indiv_mels, x)
279
+
280
+ sync_loss = get_sync_loss(mel, g)
281
+ l1loss = recon_loss(g, gt)
282
+
283
+ sync_losses.append(sync_loss.item())
284
+ recon_losses.append(l1loss.item())
285
+
286
+ if step > eval_steps:
287
+ averaged_sync_loss = sum(sync_losses) / len(sync_losses)
288
+ averaged_recon_loss = sum(recon_losses) / len(recon_losses)
289
+
290
+ print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
291
+
292
+ return averaged_sync_loss
293
+
294
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
295
+
296
+ checkpoint_path = join(
297
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
298
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
299
+ torch.save({
300
+ "state_dict": model.state_dict(),
301
+ "optimizer": optimizer_state,
302
+ "global_step": step,
303
+ "global_epoch": epoch,
304
+ }, checkpoint_path)
305
+ print("Saved checkpoint:", checkpoint_path)
306
+
307
+
308
+ def _load(checkpoint_path):
309
+ if use_cuda:
310
+ checkpoint = torch.load(checkpoint_path)
311
+ else:
312
+ checkpoint = torch.load(checkpoint_path,
313
+ map_location=lambda storage, loc: storage)
314
+ return checkpoint
315
+
316
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
317
+ global global_step
318
+ global global_epoch
319
+
320
+ print("Load checkpoint from: {}".format(path))
321
+ checkpoint = _load(path)
322
+ s = checkpoint["state_dict"]
323
+ new_s = {}
324
+ for k, v in s.items():
325
+ new_s[k.replace('module.', '')] = v
326
+ model.load_state_dict(new_s)
327
+ if not reset_optimizer:
328
+ optimizer_state = checkpoint["optimizer"]
329
+ if optimizer_state is not None:
330
+ print("Load optimizer state from {}".format(path))
331
+ optimizer.load_state_dict(checkpoint["optimizer"])
332
+ if overwrite_global_states:
333
+ global_step = checkpoint["global_step"]
334
+ global_epoch = checkpoint["global_epoch"]
335
+
336
+ return model
337
+
338
+ if __name__ == "__main__":
339
+ checkpoint_dir = args.checkpoint_dir
340
+
341
+ # Dataset and Dataloader setup
342
+ train_dataset = Dataset('train')
343
+ test_dataset = Dataset('val')
344
+
345
+ train_data_loader = data_utils.DataLoader(
346
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
347
+ num_workers=hparams.num_workers)
348
+
349
+ test_data_loader = data_utils.DataLoader(
350
+ test_dataset, batch_size=hparams.batch_size,
351
+ num_workers=4)
352
+
353
+ device = torch.device("cuda" if use_cuda else "cpu")
354
+
355
+ # Model
356
+ model = Wav2Lip().to(device)
357
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
358
+
359
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
360
+ lr=hparams.initial_learning_rate)
361
+
362
+ if args.checkpoint_path is not None:
363
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
364
+
365
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
366
+
367
+ if not os.path.exists(checkpoint_dir):
368
+ os.mkdir(checkpoint_dir)
369
+
370
+ # Train!
371
+ train(device, model, train_data_loader, test_data_loader, optimizer,
372
+ checkpoint_dir=checkpoint_dir,
373
+ checkpoint_interval=hparams.checkpoint_interval,
374
+ nepochs=hparams.nepochs)
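Finally, a typical invocation of this expert-discriminator-only training, matching the arguments declared above (paths are placeholders):

    python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir checkpoints/ --syncnet_checkpoint_path checkpoints/lipsync_expert.pth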