)
+
+    frame_shift_ms=None,  # Can replace the hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+    # Whether to normalize mel spectrograms to a predefined range (per the parameters below)
+    allow_clipping_in_normalization=True,  # Only relevant if signal_normalization=True
+    symmetric_mels=True,
+    # Whether to scale the data to be symmetric around 0. (Also doubles the output range,
+    # giving faster and cleaner convergence)
+    max_abs_value=4.,
+    # Maximum absolute value of the data. If symmetric, data lies in [-max, max], else [0, max].
+    # (Must not be too large, to avoid gradient explosion, nor too small, for fast convergence)
+    # Contribution by @begeekmyfriend
+    # Spectrogram pre-emphasis (lfilter): reduces spectrogram noise and improves the model's
+    # confidence; also allows better Griffin-Lim phase reconstruction.
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+    # Set this to 55 if your speaker is male; if female, 95 should help remove noise. (Tune
+    # per dataset. Pitch info: male ~[65, 260] Hz, female ~[100, 525] Hz)
+ fmax=7600, # To be increased/reduced depending on data.
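+
+    # A rough sketch (assuming the Tacotron-style audio.py that Wav2Lip ships with) of how
+    # the normalization settings above combine when symmetric_mels=True:
+    #   S_norm = np.clip((2 * max_abs_value) * ((S - min_level_db) / -min_level_db)
+    #                    - max_abs_value, -max_abs_value, max_abs_value)
+    # i.e. mel values land in [-max_abs_value, max_abs_value], clipped only when
+    # allow_clipping_in_normalization=True.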
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=16,
+ initial_learning_rate=1e-4,
+    nepochs=200000000000000000,  # effectively infinite; stop manually (Ctrl+C) once eval loss stays above train loss for ~10 epochs
+ num_workers=16,
+ checkpoint_interval=3000,
+ eval_interval=3000,
+ save_optimizer_state=True,
+
+    syncnet_wt=0.0,  # initially zero; automatically raised to 0.03 later, which leads to faster convergence
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=10000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+ values = hparams.values()
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+ return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/Wav2Lip/hq_wav2lip_train.py b/Wav2Lip/hq_wav2lip_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8994adba8f9196d628079cb2e28af57cef4bf34
--- /dev/null
+++ b/Wav2Lip/hq_wav2lip_train.py
@@ -0,0 +1,443 @@
+from os.path import dirname, join, basename, isfile
+from tqdm import tqdm
+
+from models import SyncNet_color as SyncNet
+from models import Wav2Lip, Wav2Lip_disc_qual
+import audio
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch import optim
+import torch.backends.cudnn as cudnn
+from torch.utils import data as data_utils
+import numpy as np
+
+from glob import glob
+
+import os, random, cv2, argparse
+from hparams import hparams, get_image_list
+
+parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
+
+parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
+
+parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
+parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
+
+parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
+parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
+
+args = parser.parse_args()
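+# Example invocation (paths below are placeholders):
+#   python hq_wav2lip_train.py --data_root lrs2_preprocessed/ \
+#       --checkpoint_dir checkpoints/ --syncnet_checkpoint_path syncnet_checkpoint.pth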
+
+
+global_step = 0
+global_epoch = 0
+use_cuda = torch.cuda.is_available()
+print('use_cuda: {}'.format(use_cuda))
+
+syncnet_T = 5
+syncnet_mel_step_size = 16
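+# 5 consecutive frames at 25 fps span 0.2 s of video, which matches 16 mel steps at
+# 80 mel frames per second (assuming the default 16 kHz / 200-sample-hop audio settings).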
+
+class Dataset(object):
+ def __init__(self, split):
+ self.all_videos = get_image_list(args.data_root, split)
+
+ def get_frame_id(self, frame):
+ return int(basename(frame).split('.')[0])
+
+ def get_window(self, start_frame):
+ start_id = self.get_frame_id(start_frame)
+ vidname = dirname(start_frame)
+
+ window_fnames = []
+ for frame_id in range(start_id, start_id + syncnet_T):
+ frame = join(vidname, '{}.jpg'.format(frame_id))
+ if not isfile(frame):
+ return None
+ window_fnames.append(frame)
+ return window_fnames
+
+ def read_window(self, window_fnames):
+ if window_fnames is None: return None
+ window = []
+ for fname in window_fnames:
+ img = cv2.imread(fname)
+ if img is None:
+ return None
+ try:
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
+ except Exception as e:
+ return None
+
+ window.append(img)
+
+ return window
+
+ def crop_audio_window(self, spec, start_frame):
+ if type(start_frame) == int:
+ start_frame_num = start_frame
+ else:
+ start_frame_num = self.get_frame_id(start_frame)
+        start_idx = int(80. * (start_frame_num / float(hparams.fps)))  # 80 mel frames per second of audio
+
+ end_idx = start_idx + syncnet_mel_step_size
+
+ return spec[start_idx : end_idx, :]
+
+ def get_segmented_mels(self, spec, start_frame):
+ mels = []
+ assert syncnet_T == 5
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
+ if start_frame_num - 2 < 0: return None
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
+ m = self.crop_audio_window(spec, i - 2)
+ if m.shape[0] != syncnet_mel_step_size:
+ return None
+ mels.append(m.T)
+
+ mels = np.asarray(mels)
+
+ return mels
+
+ def prepare_window(self, window):
+ # 3 x T x H x W
+ x = np.asarray(window) / 255.
+ x = np.transpose(x, (3, 0, 1, 2))
+
+ return x
+
+ def __len__(self):
+ return len(self.all_videos)
+
+ def __getitem__(self, idx):
+ while 1:
+ idx = random.randint(0, len(self.all_videos) - 1)
+ vidname = self.all_videos[idx]
+ img_names = list(glob(join(vidname, '*.jpg')))
+ if len(img_names) <= 3 * syncnet_T:
+ continue
+
+ img_name = random.choice(img_names)
+ wrong_img_name = random.choice(img_names)
+ while wrong_img_name == img_name:
+ wrong_img_name = random.choice(img_names)
+
+ window_fnames = self.get_window(img_name)
+ wrong_window_fnames = self.get_window(wrong_img_name)
+ if window_fnames is None or wrong_window_fnames is None:
+ continue
+
+ window = self.read_window(window_fnames)
+ if window is None:
+ continue
+
+ wrong_window = self.read_window(wrong_window_fnames)
+ if wrong_window is None:
+ continue
+
+ try:
+ wavpath = join(vidname, "audio.wav")
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
+
+ orig_mel = audio.melspectrogram(wav).T
+ except Exception as e:
+ continue
+
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
+
+ if (mel.shape[0] != syncnet_mel_step_size):
+ continue
+
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
+ if indiv_mels is None: continue
+
+ window = self.prepare_window(window)
+ y = window.copy()
+ window[:, :, window.shape[2]//2:] = 0.
+
+ wrong_window = self.prepare_window(wrong_window)
+ x = np.concatenate([window, wrong_window], axis=0)
+
+ x = torch.FloatTensor(x)
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
+ y = torch.FloatTensor(y)
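+            # Shapes: x (6, T, H, W) = masked target window stacked with the reference window;
+            # indiv_mels (T, 1, 80, 16); mel (1, 80, 16); y (3, T, H, W).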
+ return x, indiv_mels, mel, y
+
+def save_sample_images(x, g, gt, global_step, checkpoint_dir):
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+
+ refs, inps = x[..., 3:], x[..., :3]
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
+ if not os.path.exists(folder): os.mkdir(folder)
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
+ for batch_idx, c in enumerate(collage):
+ for t in range(len(c)):
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
+
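+# Sync loss: BCE on the cosine similarity between audio and face embeddings,
+# with target 1 for in-sync pairs.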
+logloss = nn.BCELoss()
+def cosine_loss(a, v, y):
+ d = nn.functional.cosine_similarity(a, v)
+ loss = logloss(d.unsqueeze(1), y)
+
+ return loss
+
+device = torch.device("cuda" if use_cuda else "cpu")
+syncnet = SyncNet().to(device)
+for p in syncnet.parameters():
+ p.requires_grad = False
+
+recon_loss = nn.L1Loss()
+def get_sync_loss(mel, g):
+ g = g[:, :, :, g.size(3)//2:]
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
+ # B, 3 * T, H//2, W
+ a, v = syncnet(mel, g)
+ y = torch.ones(g.size(0), 1).float().to(device)
+ return cosine_loss(a, v, y)
+
+def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
+ global global_step, global_epoch
+ resumed_step = global_step
+
+ while global_epoch < nepochs:
+ print('Starting Epoch: {}'.format(global_epoch))
+ running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
+ running_disc_real_loss, running_disc_fake_loss = 0., 0.
+ prog_bar = tqdm(enumerate(train_data_loader))
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
+ disc.train()
+ model.train()
+
+ x = x.to(device)
+ mel = mel.to(device)
+ indiv_mels = indiv_mels.to(device)
+ gt = gt.to(device)
+
+ ### Train generator now. Remove ALL grads.
+ optimizer.zero_grad()
+ disc_optimizer.zero_grad()
+
+ g = model(indiv_mels, x)
+
+ if hparams.syncnet_wt > 0.:
+ sync_loss = get_sync_loss(mel, g)
+ else:
+ sync_loss = 0.
+
+ if hparams.disc_wt > 0.:
+ perceptual_loss = disc.perceptual_forward(g)
+ else:
+ perceptual_loss = 0.
+
+ l1loss = recon_loss(g, gt)
+
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
+
+ loss.backward()
+ optimizer.step()
+
+ ### Remove all gradients before Training disc
+ disc_optimizer.zero_grad()
+
+ pred = disc(gt)
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
+ disc_real_loss.backward()
+
+ pred = disc(g.detach())
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
+ disc_fake_loss.backward()
+
+ disc_optimizer.step()
+
+ running_disc_real_loss += disc_real_loss.item()
+ running_disc_fake_loss += disc_fake_loss.item()
+
+ if global_step % checkpoint_interval == 0:
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
+
+ # Logs
+ global_step += 1
+ cur_session_steps = global_step - resumed_step
+
+ running_l1_loss += l1loss.item()
+ if hparams.syncnet_wt > 0.:
+ running_sync_loss += sync_loss.item()
+ else:
+ running_sync_loss += 0.
+
+ if hparams.disc_wt > 0.:
+ running_perceptual_loss += perceptual_loss.item()
+ else:
+ running_perceptual_loss += 0.
+
+ if global_step == 1 or global_step % checkpoint_interval == 0:
+ save_checkpoint(
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
+ save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
+
+
+ if global_step % hparams.eval_interval == 0:
+ with torch.no_grad():
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
+
+ if average_sync_loss < .75:
+ hparams.set_hparam('syncnet_wt', 0.03)
+
+ prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
+ running_sync_loss / (step + 1),
+ running_perceptual_loss / (step + 1),
+ running_disc_fake_loss / (step + 1),
+ running_disc_real_loss / (step + 1)))
+
+ global_epoch += 1
+
+def eval_model(test_data_loader, global_step, device, model, disc):
+ eval_steps = 300
+ print('Evaluating for {} steps'.format(eval_steps))
+ running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
+ while 1:
+ for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
+ model.eval()
+ disc.eval()
+
+ x = x.to(device)
+ mel = mel.to(device)
+ indiv_mels = indiv_mels.to(device)
+ gt = gt.to(device)
+
+ pred = disc(gt)
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
+
+ g = model(indiv_mels, x)
+ pred = disc(g)
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
+
+ running_disc_real_loss.append(disc_real_loss.item())
+ running_disc_fake_loss.append(disc_fake_loss.item())
+
+ sync_loss = get_sync_loss(mel, g)
+
+ if hparams.disc_wt > 0.:
+ perceptual_loss = disc.perceptual_forward(g)
+ else:
+ perceptual_loss = 0.
+
+ l1loss = recon_loss(g, gt)
+
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
+
+ running_l1_loss.append(l1loss.item())
+ running_sync_loss.append(sync_loss.item())
+
+ if hparams.disc_wt > 0.:
+ running_perceptual_loss.append(perceptual_loss.item())
+ else:
+ running_perceptual_loss.append(0.)
+
+ if step > eval_steps: break
+
+ print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
+ sum(running_sync_loss) / len(running_sync_loss),
+ sum(running_perceptual_loss) / len(running_perceptual_loss),
+ sum(running_disc_fake_loss) / len(running_disc_fake_loss),
+ sum(running_disc_real_loss) / len(running_disc_real_loss)))
+ return sum(running_sync_loss) / len(running_sync_loss)
+
+
+def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
+ checkpoint_path = join(
+ checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
+ torch.save({
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer_state,
+ "global_step": step,
+ "global_epoch": epoch,
+ }, checkpoint_path)
+ print("Saved checkpoint:", checkpoint_path)
+
+def _load(checkpoint_path):
+ if use_cuda:
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ checkpoint = torch.load(checkpoint_path,
+ map_location=lambda storage, loc: storage)
+ return checkpoint
+
+
+def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
+ global global_step
+ global global_epoch
+
+ print("Load checkpoint from: {}".format(path))
+ checkpoint = _load(path)
+ s = checkpoint["state_dict"]
+ new_s = {}
+ for k, v in s.items():
+ new_s[k.replace('module.', '')] = v
+ model.load_state_dict(new_s)
+ if not reset_optimizer:
+ optimizer_state = checkpoint["optimizer"]
+ if optimizer_state is not None:
+ print("Load optimizer state from {}".format(path))
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if overwrite_global_states:
+ global_step = checkpoint["global_step"]
+ global_epoch = checkpoint["global_epoch"]
+
+ return model
+
+if __name__ == "__main__":
+ checkpoint_dir = args.checkpoint_dir
+
+ # Dataset and Dataloader setup
+ train_dataset = Dataset('train')
+ test_dataset = Dataset('val')
+
+ train_data_loader = data_utils.DataLoader(
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
+ num_workers=hparams.num_workers)
+
+ test_data_loader = data_utils.DataLoader(
+ test_dataset, batch_size=hparams.batch_size,
+ num_workers=4)
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+
+ # Model
+ model = Wav2Lip().to(device)
+ disc = Wav2Lip_disc_qual().to(device)
+
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+ print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
+
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
+ lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
+ disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
+ lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
+
+ if args.checkpoint_path is not None:
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
+
+ if args.disc_checkpoint_path is not None:
+ load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
+ reset_optimizer=False, overwrite_global_states=False)
+
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
+ overwrite_global_states=False)
+
+ if not os.path.exists(checkpoint_dir):
+ os.mkdir(checkpoint_dir)
+
+ # Train!
+ train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_interval=hparams.checkpoint_interval,
+ nepochs=hparams.nepochs)
diff --git a/Wav2Lip/inference.py b/Wav2Lip/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..143537969f96fc72f41c5ef4136c679f5b5751f8
--- /dev/null
+++ b/Wav2Lip/inference.py
@@ -0,0 +1,280 @@
+from os import listdir, path
+import numpy as np
+import scipy, cv2, os, sys, argparse, audio
+import json, subprocess, random, string
+from tqdm import tqdm
+from glob import glob
+import torch, face_detection
+from models import Wav2Lip
+import platform
+
+parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
+
+parser.add_argument('--checkpoint_path', type=str,
+ help='Name of saved checkpoint to load weights from', required=True)
+
+parser.add_argument('--face', type=str,
+ help='Filepath of video/image that contains faces to use', required=True)
+parser.add_argument('--audio', type=str,
+ help='Filepath of video/audio file to use as raw audio source', required=True)
+parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
+ default='results/result_voice.mp4')
+
+parser.add_argument('--static', type=lambda x: x.lower() == 'true',
+                    help='If True, then use only first video frame for inference', default=False)
+parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
+ default=25., required=False)
+
+parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
+
+parser.add_argument('--face_det_batch_size', type=int,
+ help='Batch size for face detection', default=16)
+parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
+
+parser.add_argument('--resize_factor', default=1, type=int,
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
+
+parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
+                    'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')
+
+parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
+                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
+                    'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
+
+parser.add_argument('--rotate', default=False, action='store_true',
+                    help='Sometimes videos taken from a phone can be rotated 90deg. If set, rotates the video clockwise by 90deg. '
+                    'Use if you get a rotated result despite feeding a normal-looking video')
+
+parser.add_argument('--nosmooth', default=False, action='store_true',
+ help='Prevent smoothing face detections over a short temporal window')
+
+args = parser.parse_args()
+args.img_size = 96
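+
+# Example invocation (paths are placeholders):
+#   python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth \
+#       --face input_video.mp4 --audio input_audio.wav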
+
+if os.path.isfile(args.face) and os.path.splitext(args.face)[1].lower() in ['.jpg', '.png', '.jpeg']:
+ args.static = True
+
+def get_smoothened_boxes(boxes, T):
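+    # Moving-average each detection box over a window of T frames to reduce jitter.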
+ for i in range(len(boxes)):
+ if i + T > len(boxes):
+ window = boxes[len(boxes) - T:]
+ else:
+ window = boxes[i : i + T]
+ boxes[i] = np.mean(window, axis=0)
+ return boxes
+
+def face_detect(images):
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+ flip_input=False, device=device)
+
+ batch_size = args.face_det_batch_size
+
+ while 1:
+ predictions = []
+ try:
+ for i in tqdm(range(0, len(images), batch_size)):
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+ except RuntimeError:
+ if batch_size == 1:
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
+ batch_size //= 2
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+ continue
+ break
+
+ results = []
+ pady1, pady2, padx1, padx2 = args.pads
+ for rect, image in zip(predictions, images):
+ if rect is None:
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+ y1 = max(0, rect[1] - pady1)
+ y2 = min(image.shape[0], rect[3] + pady2)
+ x1 = max(0, rect[0] - padx1)
+ x2 = min(image.shape[1], rect[2] + padx2)
+
+ results.append([x1, y1, x2, y2])
+
+ boxes = np.array(results)
+ if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+ del detector
+ return results
+
+def datagen(frames, mels):
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if args.box[0] == -1:
+ if not args.static:
+ face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
+ else:
+ face_det_results = face_detect([frames[0]])
+ else:
+ print('Using the specified bounding box instead of face detection...')
+ y1, y2, x1, x2 = args.box
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+ for i, m in enumerate(mels):
+ idx = 0 if args.static else i%len(frames)
+ frame_to_save = frames[idx].copy()
+ face, coords = face_det_results[idx].copy()
+
+ face = cv2.resize(face, (args.img_size, args.img_size))
+
+ img_batch.append(face)
+ mel_batch.append(m)
+ frame_batch.append(frame_to_save)
+ coords_batch.append(coords)
+
+ if len(img_batch) >= args.wav2lip_batch_size:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+ img_masked = img_batch.copy()
+ img_masked[:, args.img_size//2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if len(img_batch) > 0:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+ img_masked = img_batch.copy()
+ img_masked[:, args.img_size//2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
+
+mel_step_size = 16
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('Using {} for inference.'.format(device))
+
+def _load(checkpoint_path):
+ if device == 'cuda':
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ checkpoint = torch.load(checkpoint_path,
+ map_location=lambda storage, loc: storage)
+ return checkpoint
+
+def load_model(path):
+ model = Wav2Lip()
+ print("Load checkpoint from: {}".format(path))
+ checkpoint = _load(path)
+ s = checkpoint["state_dict"]
+ new_s = {}
+ for k, v in s.items():
+ new_s[k.replace('module.', '')] = v
+ model.load_state_dict(new_s)
+
+ model = model.to(device)
+ return model.eval()
+
+def main():
+ if not os.path.isfile(args.face):
+ raise ValueError('--face argument must be a valid path to video/image file')
+
+    elif os.path.splitext(args.face)[1].lower() in ['.jpg', '.png', '.jpeg']:
+ full_frames = [cv2.imread(args.face)]
+ fps = args.fps
+
+ else:
+ video_stream = cv2.VideoCapture(args.face)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+ print('Reading video frames...')
+
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ if args.resize_factor > 1:
+ frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
+
+ if args.rotate:
+ frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
+
+ y1, y2, x1, x2 = args.crop
+ if x2 == -1: x2 = frame.shape[1]
+ if y2 == -1: y2 = frame.shape[0]
+
+ frame = frame[y1:y2, x1:x2]
+
+ full_frames.append(frame)
+
+    print("Number of frames available for inference: " + str(len(full_frames)))
+
+ if not args.audio.endswith('.wav'):
+ print('Extracting raw audio...')
+ command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
+
+ subprocess.call(command, shell=True)
+ args.audio = 'temp/temp.wav'
+
+ wav = audio.load_wav(args.audio, 16000)
+ mel = audio.melspectrogram(wav)
+ print(mel.shape)
+
+ if np.isnan(mel.reshape(-1)).sum() > 0:
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+ mel_chunks = []
+ mel_idx_multiplier = 80./fps
+ i = 0
+ while 1:
+ start_idx = int(i * mel_idx_multiplier)
+ if start_idx + mel_step_size > len(mel[0]):
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+ break
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+ i += 1
+
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
+
+ full_frames = full_frames[:len(mel_chunks)]
+
+ batch_size = args.wav2lip_batch_size
+ gen = datagen(full_frames.copy(), mel_chunks)
+
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
+ total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
+ if i == 0:
+ model = load_model(args.checkpoint_path)
+            print("Model loaded")
+
+ frame_h, frame_w = full_frames[0].shape[:-1]
+ out = cv2.VideoWriter('temp/result.avi',
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
+
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+ with torch.no_grad():
+ pred = model(mel_batch, img_batch)
+
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+ for p, f, c in zip(pred, frames, coords):
+ y1, y2, x1, x2 = c
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+ f[y1:y2, x1:x2] = p
+ out.write(f)
+
+ out.release()
+
+ command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
+if __name__ == '__main__':
+ main()
diff --git a/Wav2Lip/models/__init__.py b/Wav2Lip/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3c7fdf25db091078de11fff81ee59b818c092f0
--- /dev/null
+++ b/Wav2Lip/models/__init__.py
@@ -0,0 +1,2 @@
+from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
+from .syncnet import SyncNet_color
\ No newline at end of file
diff --git a/Wav2Lip/models/conv.py b/Wav2Lip/models/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..73529ffd03be8357dab5369d46a049c85b67a520
--- /dev/null
+++ b/Wav2Lip/models/conv.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ )
+ self.act = nn.LeakyReLU(0.01, inplace=True)
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
diff --git a/Wav2Lip/models/syncnet.py b/Wav2Lip/models/syncnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d452f8d63686e666b81a55893d8f9a921724afc9
--- /dev/null
+++ b/Wav2Lip/models/syncnet.py
@@ -0,0 +1,66 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv2d
+
+class SyncNet_color(nn.Module):
+ def __init__(self):
+ super(SyncNet_color, self).__init__()
+
+ self.face_encoder = nn.Sequential(
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
+
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
+ face_embedding = self.face_encoder(face_sequences)
+ audio_embedding = self.audio_encoder(audio_sequences)
+
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
+
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
+
+
+ return audio_embedding, face_embedding
diff --git a/Wav2Lip/models/wav2lip.py b/Wav2Lip/models/wav2lip.py
new file mode 100644
index 0000000000000000000000000000000000000000..a198113dff2149f3f13a5bcc29b73eca5469abe0
--- /dev/null
+++ b/Wav2Lip/models/wav2lip.py
@@ -0,0 +1,184 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+import math
+
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+class Wav2Lip(nn.Module):
+ def __init__(self):
+ super(Wav2Lip, self).__init__()
+
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
+
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ self.face_decoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
+
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
+
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
+
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
+
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
+
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid())
+
+ def forward(self, audio_sequences, face_sequences):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ input_dim_size = len(face_sequences.size())
+ if input_dim_size > 4:
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+
+ feats = []
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+ feats.append(x)
+
+ x = audio_embedding
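+        # U-Net style decoding: each decoder block upsamples x, which is then concatenated
+        # with the matching encoder feature map (consumed in reverse order via feats.pop()).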
+ for f in self.face_decoder_blocks:
+ x = f(x)
+ try:
+ x = torch.cat((x, feats[-1]), dim=1)
+ except Exception as e:
+ print(x.size())
+ print(feats[-1].size())
+ raise e
+
+ feats.pop()
+
+ x = self.output_block(x)
+
+ if input_dim_size > 4:
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
+
+ else:
+ outputs = x
+
+ return outputs
+
+class Wav2Lip_disc_qual(nn.Module):
+ def __init__(self):
+ super(Wav2Lip_disc_qual, self).__init__()
+
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
+
+ nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
+ nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
+ nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
+ nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
+
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
+
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
+ nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+
+ self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
+ self.label_noise = .0
+
+ def get_lower_half(self, face_sequences):
+ return face_sequences[:, :, face_sequences.size(2)//2:]
+
+ def to_2d(self, face_sequences):
+ B = face_sequences.size(0)
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+ return face_sequences
+
+ def perceptual_forward(self, false_face_sequences):
+ false_face_sequences = self.to_2d(false_face_sequences)
+ false_face_sequences = self.get_lower_half(false_face_sequences)
+
+ false_feats = false_face_sequences
+ for f in self.face_encoder_blocks:
+ false_feats = f(false_feats)
+
+        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
+                                                 torch.ones((len(false_feats), 1)).to(false_feats.device))
+
+ return false_pred_loss
+
+ def forward(self, face_sequences):
+ face_sequences = self.to_2d(face_sequences)
+ face_sequences = self.get_lower_half(face_sequences)
+
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+
+ return self.binary_pred(x).view(len(x), -1)
diff --git a/Wav2Lip/preprocess.py b/Wav2Lip/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad24b1599d961152c68ac388574ee60b5d925ad1
--- /dev/null
+++ b/Wav2Lip/preprocess.py
@@ -0,0 +1,113 @@
+import sys
+
+if sys.version_info < (3, 2):
+    raise Exception("Must be using >= Python 3.2")
+
+from os import listdir, path
+
+if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
+    raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth '
+                            'before running this script!')
+
+import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import numpy as np
+import argparse, os, cv2, traceback, subprocess
+from tqdm import tqdm
+from glob import glob
+import audio
+from hparams import hparams as hp
+
+import face_detection
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
+parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
+parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
+parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
+
+args = parser.parse_args()
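+# Example invocation (paths are placeholders):
+#   python preprocess.py --data_root lrs2/main --preprocessed_root lrs2_preprocessed/ --ngpu 2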
+
+fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
+ device='cuda:{}'.format(id)) for id in range(args.ngpu)]
+
+template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
+# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
+
+def process_video_file(vfile, args, gpu_id):
+ video_stream = cv2.VideoCapture(vfile)
+
+ frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ frames.append(frame)
+
+ vidname = os.path.basename(vfile).split('.')[0]
+ dirname = vfile.split('/')[-2]
+
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
+ os.makedirs(fulldir, exist_ok=True)
+
+ batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
+
+ i = -1
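+    # Global frame counter: frames with no detected face still advance the index,
+    # so saved crops keep their original frame numbers.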
+ for fb in batches:
+ preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
+
+ for j, f in enumerate(preds):
+ i += 1
+ if f is None:
+ continue
+
+ x1, y1, x2, y2 = f
+ cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
+
+def process_audio_file(vfile, args):
+ vidname = os.path.basename(vfile).split('.')[0]
+ dirname = vfile.split('/')[-2]
+
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
+ os.makedirs(fulldir, exist_ok=True)
+
+ wavpath = path.join(fulldir, 'audio.wav')
+
+ command = template.format(vfile, wavpath)
+ subprocess.call(command, shell=True)
+
+
+def mp_handler(job):
+ vfile, args, gpu_id = job
+ try:
+ process_video_file(vfile, args, gpu_id)
+ except KeyboardInterrupt:
+ exit(0)
+ except:
+ traceback.print_exc()
+
+def main(args):
+ print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
+
+ filelist = glob(path.join(args.data_root, '*/*.mp4'))
+
+ jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
+ p = ThreadPoolExecutor(args.ngpu)
+ futures = [p.submit(mp_handler, j) for j in jobs]
+ _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
+
+ print('Dumping audios...')
+
+ for vfile in tqdm(filelist):
+ try:
+ process_audio_file(vfile, args)
+ except KeyboardInterrupt:
+ exit(0)
+ except:
+ traceback.print_exc()
+ continue
+
+if __name__ == '__main__':
+ main(args)
\ No newline at end of file
diff --git a/Wav2Lip/requirements.txt b/Wav2Lip/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..77353f8a0a6068fc26cd693a779e2d7b4000b0e5
--- /dev/null
+++ b/Wav2Lip/requirements.txt
@@ -0,0 +1,8 @@
+librosa==0.7.0
+numpy==1.17.1
+opencv-contrib-python>=4.2.0.34
+opencv-python==4.1.0.25
+torch==1.1.0
+torchvision==0.3.0
+tqdm==4.45.0
+numba==0.48
diff --git a/Wav2Lip/results/README.md b/Wav2Lip/results/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1bbfd53fded37aefe0f4fc97adc8de343341b7a
--- /dev/null
+++ b/Wav2Lip/results/README.md
@@ -0,0 +1 @@
+Generated results will be placed in this folder by default.
\ No newline at end of file
diff --git a/Wav2Lip/temp/README.md b/Wav2Lip/temp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..04c910499300fa8dc05c317d7d30cb29f31ff836
--- /dev/null
+++ b/Wav2Lip/temp/README.md
@@ -0,0 +1 @@
+Temporary files at the time of inference/testing will be saved here. You can ignore them.
\ No newline at end of file
diff --git a/Wav2Lip/wav2lip_train.py b/Wav2Lip/wav2lip_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d12d4cca29292d6a5b4c78de0854a454fde5d3
--- /dev/null
+++ b/Wav2Lip/wav2lip_train.py
@@ -0,0 +1,374 @@
+from os.path import dirname, join, basename, isfile
+from tqdm import tqdm
+
+from models import SyncNet_color as SyncNet
+from models import Wav2Lip as Wav2Lip
+import audio
+
+import torch
+from torch import nn
+from torch import optim
+import torch.backends.cudnn as cudnn
+from torch.utils import data as data_utils
+import numpy as np
+
+from glob import glob
+
+import os, random, cv2, argparse
+from hparams import hparams, get_image_list
+
+parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
+
+parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
+
+parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
+parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
+
+parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
+
+args = parser.parse_args()
+
+
+global_step = 0
+global_epoch = 0
+use_cuda = torch.cuda.is_available()
+print('use_cuda: {}'.format(use_cuda))
+
+syncnet_T = 5
+syncnet_mel_step_size = 16
+
+class Dataset(object):
+ def __init__(self, split):
+ self.all_videos = get_image_list(args.data_root, split)
+
+ def get_frame_id(self, frame):
+ return int(basename(frame).split('.')[0])
+
+ def get_window(self, start_frame):
+ start_id = self.get_frame_id(start_frame)
+ vidname = dirname(start_frame)
+
+ window_fnames = []
+ for frame_id in range(start_id, start_id + syncnet_T):
+ frame = join(vidname, '{}.jpg'.format(frame_id))
+ if not isfile(frame):
+ return None
+ window_fnames.append(frame)
+ return window_fnames
+
+ def read_window(self, window_fnames):
+ if window_fnames is None: return None
+ window = []
+ for fname in window_fnames:
+ img = cv2.imread(fname)
+ if img is None:
+ return None
+ try:
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
+ except Exception as e:
+ return None
+
+ window.append(img)
+
+ return window
+
+ def crop_audio_window(self, spec, start_frame):
+ if type(start_frame) == int:
+ start_frame_num = start_frame
+ else:
+            start_frame_num = self.get_frame_id(start_frame)
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
+
+ end_idx = start_idx + syncnet_mel_step_size
+
+ return spec[start_idx : end_idx, :]
+
+ def get_segmented_mels(self, spec, start_frame):
+ mels = []
+ assert syncnet_T == 5
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
+ if start_frame_num - 2 < 0: return None
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
+ m = self.crop_audio_window(spec, i - 2)
+ if m.shape[0] != syncnet_mel_step_size:
+ return None
+ mels.append(m.T)
+
+ mels = np.asarray(mels)
+
+ return mels
+
+ def prepare_window(self, window):
+ # 3 x T x H x W
+ x = np.asarray(window) / 255.
+ x = np.transpose(x, (3, 0, 1, 2))
+
+ return x
+
+ def __len__(self):
+ return len(self.all_videos)
+
+ def __getitem__(self, idx):
+ while 1:
+ idx = random.randint(0, len(self.all_videos) - 1)
+ vidname = self.all_videos[idx]
+ img_names = list(glob(join(vidname, '*.jpg')))
+ if len(img_names) <= 3 * syncnet_T:
+ continue
+
+ img_name = random.choice(img_names)
+ wrong_img_name = random.choice(img_names)
+ while wrong_img_name == img_name:
+ wrong_img_name = random.choice(img_names)
+
+ window_fnames = self.get_window(img_name)
+ wrong_window_fnames = self.get_window(wrong_img_name)
+ if window_fnames is None or wrong_window_fnames is None:
+ continue
+
+ window = self.read_window(window_fnames)
+ if window is None:
+ continue
+
+ wrong_window = self.read_window(wrong_window_fnames)
+ if wrong_window is None:
+ continue
+
+ try:
+ wavpath = join(vidname, "audio.wav")
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
+
+ orig_mel = audio.melspectrogram(wav).T
+ except Exception as e:
+ continue
+
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
+
+ if (mel.shape[0] != syncnet_mel_step_size):
+ continue
+
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
+ if indiv_mels is None: continue
+
+ window = self.prepare_window(window)
+ y = window.copy()
+ window[:, :, window.shape[2]//2:] = 0.
+
+ wrong_window = self.prepare_window(wrong_window)
+ x = np.concatenate([window, wrong_window], axis=0)
+
+ x = torch.FloatTensor(x)
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
+ y = torch.FloatTensor(y)
+ return x, indiv_mels, mel, y
+
+def save_sample_images(x, g, gt, global_step, checkpoint_dir):
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+
+ refs, inps = x[..., 3:], x[..., :3]
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
+ if not os.path.exists(folder): os.mkdir(folder)
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
+ for batch_idx, c in enumerate(collage):
+ for t in range(len(c)):
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
+
+logloss = nn.BCELoss()
+def cosine_loss(a, v, y):
+ d = nn.functional.cosine_similarity(a, v)
+ loss = logloss(d.unsqueeze(1), y)
+
+ return loss
+
+device = torch.device("cuda" if use_cuda else "cpu")
+syncnet = SyncNet().to(device)
+for p in syncnet.parameters():
+ p.requires_grad = False
+
+recon_loss = nn.L1Loss()
+def get_sync_loss(mel, g):
+ g = g[:, :, :, g.size(3)//2:]
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
+ # B, 3 * T, H//2, W
+ a, v = syncnet(mel, g)
+ y = torch.ones(g.size(0), 1).float().to(device)
+ return cosine_loss(a, v, y)
+
+def train(device, model, train_data_loader, test_data_loader, optimizer,
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
+
+ global global_step, global_epoch
+ resumed_step = global_step
+
+ while global_epoch < nepochs:
+ print('Starting Epoch: {}'.format(global_epoch))
+ running_sync_loss, running_l1_loss = 0., 0.
+ prog_bar = tqdm(enumerate(train_data_loader))
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
+ model.train()
+ optimizer.zero_grad()
+
+ # Move data to CUDA device
+ x = x.to(device)
+ mel = mel.to(device)
+ indiv_mels = indiv_mels.to(device)
+ gt = gt.to(device)
+
+ g = model(indiv_mels, x)
+
+ if hparams.syncnet_wt > 0.:
+ sync_loss = get_sync_loss(mel, g)
+ else:
+ sync_loss = 0.
+
+ l1loss = recon_loss(g, gt)
+
+ loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
+ loss.backward()
+ optimizer.step()
+
+ if global_step % checkpoint_interval == 0:
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
+
+ global_step += 1
+ cur_session_steps = global_step - resumed_step
+
+ running_l1_loss += l1loss.item()
+ if hparams.syncnet_wt > 0.:
+ running_sync_loss += sync_loss.item()
+ else:
+ running_sync_loss += 0.
+
+ if global_step == 1 or global_step % checkpoint_interval == 0:
+ save_checkpoint(
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
+
+ if global_step == 1 or global_step % hparams.eval_interval == 0:
+ with torch.no_grad():
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
+
+ if average_sync_loss < .75:
+ hparams.set_hparam('syncnet_wt', 0.01) # without image GAN a lesser weight is sufficient
+
+ prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
+ running_sync_loss / (step + 1)))
+
+ global_epoch += 1
+
+
+def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
+ eval_steps = 700
+ print('Evaluating for {} steps'.format(eval_steps))
+ sync_losses, recon_losses = [], []
+ step = 0
+ while 1:
+ for x, indiv_mels, mel, gt in test_data_loader:
+ step += 1
+ model.eval()
+
+ # Move data to CUDA device
+ x = x.to(device)
+ gt = gt.to(device)
+ indiv_mels = indiv_mels.to(device)
+ mel = mel.to(device)
+
+ g = model(indiv_mels, x)
+
+ sync_loss = get_sync_loss(mel, g)
+ l1loss = recon_loss(g, gt)
+
+ sync_losses.append(sync_loss.item())
+ recon_losses.append(l1loss.item())
+
+ if step > eval_steps:
+ averaged_sync_loss = sum(sync_losses) / len(sync_losses)
+ averaged_recon_loss = sum(recon_losses) / len(recon_losses)
+
+ print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
+
+ return averaged_sync_loss
+
+def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
+
+ checkpoint_path = join(
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
+ torch.save({
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer_state,
+ "global_step": step,
+ "global_epoch": epoch,
+ }, checkpoint_path)
+ print("Saved checkpoint:", checkpoint_path)
+
+
+def _load(checkpoint_path):
+ if use_cuda:
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ checkpoint = torch.load(checkpoint_path,
+ map_location=lambda storage, loc: storage)
+ return checkpoint
+
+def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
+ global global_step
+ global global_epoch
+
+ print("Load checkpoint from: {}".format(path))
+ checkpoint = _load(path)
+ s = checkpoint["state_dict"]
+ new_s = {}
+ for k, v in s.items():
+ new_s[k.replace('module.', '')] = v
+ model.load_state_dict(new_s)
+ if not reset_optimizer:
+ optimizer_state = checkpoint["optimizer"]
+ if optimizer_state is not None:
+ print("Load optimizer state from {}".format(path))
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if overwrite_global_states:
+ global_step = checkpoint["global_step"]
+ global_epoch = checkpoint["global_epoch"]
+
+ return model
+
+if __name__ == "__main__":
+ checkpoint_dir = args.checkpoint_dir
+
+ # Dataset and Dataloader setup
+ train_dataset = Dataset('train')
+ test_dataset = Dataset('val')
+
+ train_data_loader = data_utils.DataLoader(
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
+ num_workers=hparams.num_workers)
+
+ test_data_loader = data_utils.DataLoader(
+ test_dataset, batch_size=hparams.batch_size,
+ num_workers=4)
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+
+ # Model
+ model = Wav2Lip().to(device)
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
+ lr=hparams.initial_learning_rate)
+
+ if args.checkpoint_path is not None:
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
+
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
+
+ if not os.path.exists(checkpoint_dir):
+ os.mkdir(checkpoint_dir)
+
+ # Train!
+ train(device, model, train_data_loader, test_data_loader, optimizer,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_interval=hparams.checkpoint_interval,
+ nepochs=hparams.nepochs)
diff --git a/__pycache__/ollama_chatbotTTS.cpython-312.pyc b/__pycache__/ollama_chatbotTTS.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49694cb390aba460ffa50d3eb093590752a0914c
Binary files /dev/null and b/__pycache__/ollama_chatbotTTS.cpython-312.pyc differ
diff --git a/__pycache__/sync_audio_video.cpython-312.pyc b/__pycache__/sync_audio_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14531b9817437504df5cf5e18872406755dd3b63
Binary files /dev/null and b/__pycache__/sync_audio_video.cpython-312.pyc differ
diff --git a/__pycache__/text_to_speech.cpython-312.pyc b/__pycache__/text_to_speech.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..321250ba2b48b1da8bc9f0a1124a5c33a30883e3
Binary files /dev/null and b/__pycache__/text_to_speech.cpython-312.pyc differ
diff --git a/__pycache__/whisper_tts.cpython-310.pyc b/__pycache__/whisper_tts.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05ad2931b1737d8f684604833eb9f43da8af5117
Binary files /dev/null and b/__pycache__/whisper_tts.cpython-310.pyc differ
diff --git a/__pycache__/whisper_tts.cpython-312.pyc b/__pycache__/whisper_tts.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..236c599d0b342feecbb9fb7db73afa0613ee606c
Binary files /dev/null and b/__pycache__/whisper_tts.cpython-312.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..143929b4ca38001fd92da2d644ac3845c09f84e5
--- /dev/null
+++ b/app.py
@@ -0,0 +1,69 @@
+from flask import Flask, request, render_template, jsonify
+import os, re
+from whisper_tts import WhisperTTS
+from ollama_chatbotTTS import OllamaChat
+from text_to_speech import TextToSpeech
+from sync_audio_video import AudioVideoSync
+
+app = Flask(__name__, static_folder='.', static_url_path='')
+
+THUMBNAILS_DIR = "thumbnails"
+VIDEO_DIR = "sample_video"
+UPLOAD_DIR = "uploads"
+
+def get_thumbnail_images():
+ if not os.path.exists(THUMBNAILS_DIR):
+ return []
+ return [os.path.splitext(f)[0] for f in os.listdir(THUMBNAILS_DIR) if f.lower().endswith((".png",".jpg",".jpeg"))]
+
+@app.route('/')
+def index():
+ avatars = get_thumbnail_images()
+ return render_template('index.html', avatars=avatars)
+
+@app.route('/transcribe', methods=['POST'])
+def transcribe():
+ f = request.files.get('audio')
+ if f:
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ path = os.path.join(UPLOAD_DIR, f.filename)
+ f.save(path)
+ text = WhisperTTS().transcribe_audio(path)
+ return jsonify(text=text)
+ return jsonify(text="")
+
+@app.route('/chat', methods=['POST'])
+def chat():
+ data = request.get_json()
+ text = data.get('text', "")
+ r = OllamaChat().get_response(text)
+ # Strip any <think>/</think> reasoning tags from the model output
+ resp = re.sub(r"<think>|</think>", "", r).strip()
+ return jsonify(response=resp)
+
+@app.route('/tts', methods=['POST'])
+def tts():
+ data = request.get_json()
+ text = data.get('text', "")
+ if not text:
+ return jsonify(audio_url="")
+ path = TextToSpeech().synthesize(text)
+ return jsonify(audio_url="/" + path)
+
+@app.route('/sync', methods=['POST'])
+def sync_audio_video():
+ data = request.get_json()
+ avatar = data.get('avatar', "")
+ audio_url = data.get('audio_url', "")
+ audio_path = audio_url.lstrip('/')
+ vid = None
+ for v in os.listdir(VIDEO_DIR):
+ if os.path.splitext(v)[0].lower() == avatar.lower():
+ vid = os.path.join(VIDEO_DIR, v)
+ break
+ if not vid or not audio_path:
+ return jsonify(video_url="")
+ out = AudioVideoSync().sync_audio_video(vid, audio_path)
+ return jsonify(video_url="/" + out)
+
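+# Example requests against the Flask dev server (default port 5000; the file
+# and text payloads below are placeholders):
+#   curl -X POST -F "audio=@sample.wav" http://localhost:5000/transcribe
+#   curl -X POST -H "Content-Type: application/json" -d '{"text": "Hi"}' http://localhost:5000/chat
+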
+if __name__ == '__main__':
+ app.run(debug=True)
diff --git a/gradio_ui.py b/gradio_ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..21ef96493daac005ae6f75a854b231b538484dd2
--- /dev/null
+++ b/gradio_ui.py
@@ -0,0 +1,118 @@
+import gradio as gr
+import os
+from whisper_tts import WhisperTTS
+from ollama_chatbotTTS import OllamaChat
+from text_to_speech import TextToSpeech
+from sync_audio_video import AudioVideoSync
+import re
+
+os.system("ollama serve &")
+
+# Paths
+THUMBNAILS_DIR = "thumbnails"
+VIDEO_DIR = "sample_video"
+
+def get_thumbnail_images():
+ if not os.path.exists(THUMBNAILS_DIR):
+ return []
+ return [(os.path.splitext(f)[0], os.path.join(THUMBNAILS_DIR, f))
+ for f in os.listdir(THUMBNAILS_DIR) if f.endswith((".png", ".jpg", ".jpeg"))]
+
+thumbnail_images = get_thumbnail_images()
+avatar_names = [name for name, _ in thumbnail_images]
+
+def find_matching_video(file_name):
+ file_name = file_name.lower()
+ if not os.path.exists(VIDEO_DIR):
+ return None
+ for video in os.listdir(VIDEO_DIR):
+ video_name, ext = os.path.splitext(video)
+ if video_name.lower() == file_name and ext in [".mp4", ".avi", ".mov"]:
+ return os.path.join(VIDEO_DIR, video)
+ return None
+
+def update_avatar_display(selected_name):
+ for name, img_path in thumbnail_images:
+ if name == selected_name:
+ return img_path
+ return None
+
+def check_enable_process_button(selected_name, audio_file, transcribed_text):
+ if selected_name and (audio_file or transcribed_text.strip()):
+ return gr.update(interactive=True)
+ return gr.update(interactive=False)
+
+def process_pipeline(audio_file, transcribed_text, selected_name):
+ if audio_file:
+ whisper_tts = WhisperTTS()
+ transcribed_text = whisper_tts.transcribe_audio(audio_file)
+ yield transcribed_text, "", None, None # Show transcribed text first
+
+ if not transcribed_text.strip():
+ yield "Warning: Please provide valid text.", "", None, None
+ return
+
+ ollama_chat = OllamaChat()
+ chatbot_response = ollama_chat.get_response(transcribed_text)
+ # Strip any <think>/</think> reasoning tags from the model output
+ chatbot_response = re.sub(r"<think>|</think>", "", chatbot_response).strip()
+ yield transcribed_text, chatbot_response, None, None # Show chatbot response next
+
+ if not chatbot_response:
+ yield transcribed_text, "Warning: No chatbot response.", None, None
+ return
+
+ tts = TextToSpeech()
+ output_audio_path = tts.synthesize(chatbot_response)
+ yield transcribed_text, chatbot_response, output_audio_path, None # Show generated speech
+
+ if not selected_name:
+ yield transcribed_text, chatbot_response, output_audio_path, "Warning: Select an avatar."
+ return
+
+ input_video = find_matching_video(selected_name.lower())
+ if not input_video:
+ yield transcribed_text, chatbot_response, output_audio_path, "Warning: No matching video."
+ return
+
+ sync = AudioVideoSync()
+ output_video_path = sync.sync_audio_video(input_video, output_audio_path)
+ yield transcribed_text, chatbot_response, output_audio_path, output_video_path # Show final video
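+
+# Note: process_pipeline is a generator; each `yield` streams a partial update
+# to the four outputs, so the UI shows the transcription, chat reply, speech,
+# and video progressively instead of waiting for the whole pipeline to finish.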
+
+with gr.Blocks() as demo:
+ gr.Markdown("## Personalized Avatar Video")
+
+ with gr.Row():
+ with gr.Column():
+ audio_input = gr.Audio(type="filepath", label="Audio Input")
+ transcribed_text_output = gr.Textbox(label="Edit and Process Text")
+ chatbot_response_output = gr.Textbox(label="Assistant Response")
+ gr.Markdown("### Select an Avatar")
+ selected_avatar = gr.Radio(choices=avatar_names, label="Select an Avatar")
+ avatar_display = gr.Image(label="Selected Avatar", width=150, height=150)
+ process_button = gr.Button("Generate Lip-Sync Video", interactive=False)
+
+ with gr.Column():
+ tts_audio_output = gr.Audio(label="Generated Speech")
+ video_output = gr.Video(label="Final Lip-Synced Video")
+
+ selected_avatar.change(update_avatar_display, inputs=[selected_avatar], outputs=[avatar_display])
+ selected_avatar.change(check_enable_process_button, inputs=[selected_avatar, audio_input, transcribed_text_output], outputs=[process_button])
+ audio_input.change(check_enable_process_button, inputs=[selected_avatar, audio_input, transcribed_text_output], outputs=[process_button])
+ transcribed_text_output.change(check_enable_process_button, inputs=[selected_avatar, audio_input, transcribed_text_output], outputs=[process_button])
+
+ process_button.click(
+ process_pipeline,
+ inputs=[audio_input, transcribed_text_output, selected_avatar],
+ outputs=[transcribed_text_output, chatbot_response_output, tts_audio_output, video_output]
+ )
+
+if __name__ == "__main__":
+ demo.launch(
+ server_name="0.0.0.0",
+ server_port=7860,
+ # optionally:
+ share=True, # to get a public link
+ inbrowser=True, # to open the browser automatically
+# prevent_thread_lock=True # if you don't want the script to block the main thread
+ )
+
diff --git a/image_generator.py b/image_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..105665f96fbc763c52c269d2e06512ebf9bdd324
--- /dev/null
+++ b/image_generator.py
@@ -0,0 +1,30 @@
+import os
+import subprocess
+
+def generate_thumbnails(video_folder, thumbnail_folder, timestamp="00:00:05"):
+ if not os.path.exists(thumbnail_folder):
+ os.makedirs(thumbnail_folder)
+
+ for video_file in os.listdir(video_folder):
+ video_path = os.path.join(video_folder, video_file)
+
+ if os.path.isfile(video_path) and video_file.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv")):
+ thumbnail_name = os.path.splitext(video_file)[0] + ".png"
+ thumbnail_path = os.path.join(thumbnail_folder, thumbnail_name)
+
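+ # Grab one frame: -ss seeks to `timestamp`, -vframes 1 extracts a single
+ # frame, and -q:v 2 requests high image quality (scale 2-31, lower = better).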
+ command = [
+ "ffmpeg", "-i", video_path,
+ "-ss", timestamp, "-vframes", "1", "-q:v", "2",
+ thumbnail_path
+ ]
+
+ try:
+ subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ print(f"Thumbnail generated: {thumbnail_path}")
+ except subprocess.CalledProcessError as e:
+ print(f"Error generating thumbnail for {video_file}: {e}")
+
+if __name__ == "__main__":
+ video_folder = "sample_video" # Change to your video folder path
+ thumbnail_folder = "thumbnails"
+ generate_thumbnails(video_folder, thumbnail_folder)
\ No newline at end of file
diff --git a/ollama_chatbotTTS.py b/ollama_chatbotTTS.py
new file mode 100644
index 0000000000000000000000000000000000000000..889390e936f36d27d04707bfba3cc167b302a411
--- /dev/null
+++ b/ollama_chatbotTTS.py
@@ -0,0 +1,48 @@
+import ollama
+import json
+from pydantic import BaseModel
+
+class ChatResponse(BaseModel):
+ response: str
+
+class OllamaChat:
+
+ def __init__(self):
+ """Initialize the Ollama model and pull it locally if needed."""
+ # Model to use
+ self.model = "hf.co/jnjj/gemma-3-1b-it-qat-q4_0-unquantized-Q2_K-GGUF:Q2_K"
+ self.system_prompt = (
+ "You are a concise and natural-sounding assistant. "
+ "Answer questions briefly, in one or two sentences at most, "
+ "as if responding for text-to-speech (TTS). Keep it natural and conversational."
+ )
+
+ # Try to pull the model so it is available locally before the first chat call
+ try:
+ print(f"Downloading model {self.model}…")
+ ollama.pull(model=self.model)
+ print("Model downloaded successfully.")
+ except Exception as e:
+ print(f"Could not download the model: {e}")
+
+ def get_response(self, user_input: str) -> str:
+ """Processes the input text using Ollama and returns only the response string."""
+ result = ollama.chat(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": user_input}
+ ],
+ format=ChatResponse.model_json_schema()
+ )
+ # Extract the "response" field from the returned JSON
+ response = json.loads(result["message"]["content"])["response"]
+ return response
+
+
+# Example usage
+if __name__ == "__main__":
+ ollama_chat = OllamaChat()
+ user_text = input("Enter your question: ")
+ print("\nOllama Response:\n", ollama_chat.get_response(user_text))
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..809f403def95cd25ed4ae684551def8fb7add385
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,13 @@
+fakeroot
+git
+git-lfs
+ffmpeg
+libsm6
+libxext6
+cmake
+rsync
+libgl1-mesa-glx
+espeak-ng
+curl
+nodejs
+sudo
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26332835732c88126868844ca8feda5492c55bb4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+gradio==4.19.2 # Or the version you are using
+ollama
+torch
+ffmpeg-python
+TTS
+openai-whisper
+gTTS
+soundfile
+# Wav2Lip requirements
+librosa
+numpy
+opencv-contrib-python>=4.2.0.34
+opencv-python
+torchvision
+tqdm
+numba
+flask
diff --git a/sample_video/Amicia.mp4 b/sample_video/Amicia.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f6e79576349305ba4f2feaee114867ca80d8be5d
--- /dev/null
+++ b/sample_video/Amicia.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0eaed827867ed845acd798228f23ac1a56dab26e53c59cbd591aa2a40786aff9
+size 1367849
diff --git a/sample_video/Claire.mp4 b/sample_video/Claire.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0d225a57fb42eb8b9f1005885bbb08ecc1edc4a3
--- /dev/null
+++ b/sample_video/Claire.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1d2a4fff81f005af0488b97d238ee426e4011953310dd15183bfc38772d76a5
+size 1290787
diff --git a/sample_video/Gina.mp4 b/sample_video/Gina.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b34fea7ee236246db55c2b806a728e62ec7225ab
--- /dev/null
+++ b/sample_video/Gina.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95250199d6a1f2f66a9f9b529899bd8469bc7eab41f4820ab5f7962e20941a19
+size 1649442
diff --git a/sample_video/Ruth.mp4 b/sample_video/Ruth.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d24695ef6d759a36a62d7c6f47158dcedc121a37
--- /dev/null
+++ b/sample_video/Ruth.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1ac91edfd378e6c25baa1b063b9bf4168149cb4bdeebce3cc86fd2e9642d2cd
+size 1493882
diff --git a/sample_video/Samantha.mp4 b/sample_video/Samantha.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3157a60cffcdf709e9563e4599111c107043ce1e
--- /dev/null
+++ b/sample_video/Samantha.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41156d86c16c4d5a0ec80d24701a17fb412c119568beb6307c1c2bc8655351d8
+size 2750728
diff --git a/sync_audio_video.py b/sync_audio_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..393761a22df21b8da93623ad3f877f3249772ab1
--- /dev/null
+++ b/sync_audio_video.py
@@ -0,0 +1,57 @@
+import torch
+import gc
+import os
+import subprocess
+import datetime
+
+class AudioVideoSync:
+ def __init__(self, wav2lip_dir="Wav2Lip"):
+ self.wav2lip_dir = wav2lip_dir
+ self.checkpoint_path = os.path.join(wav2lip_dir, "checkpoints", "wav2lip_gan.pth")
+
+ def print_memory_usage(self, stage=""):
+ """Prints GPU memory usage at different stages (no-op when CUDA is unavailable)."""
+ if not torch.cuda.is_available():
+ return
+ os.system('nvidia-smi')
+ print(f"[{stage}] Allocated Memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+ print(f"[{stage}] Reserved Memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+ print("-" * 50)
+
+ def sync_audio_video(self, video_path, audio_path, output_video=None):
+ """Syncs audio and video using Wav2Lip."""
+ if not os.path.exists(video_path) or not os.path.exists(audio_path):
+ raise FileNotFoundError("Video or Audio file not found.")
+
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ result_video = output_video or f"result_{timestamp}.mp4"
+
+ # Before running Wav2Lip, clear cache and check memory
+ gc.collect()
+ torch.cuda.empty_cache()
+ self.print_memory_usage("Before Wav2Lip Inference")
+
+ # Run Wav2Lip inference
+ print("Running Wav2Lip for better lip movement...")
+ result = subprocess.run([
+ "python", os.path.join(self.wav2lip_dir, "inference.py"),
+ "--checkpoint_path", self.checkpoint_path,
+ "--face", video_path,
+ "--audio", audio_path,
+ "--outfile", result_video,
+ "--wav2lip_batch_size", "1",
+ "--resize_factor", "2", # Better accuracy for lips
+ "--nosmooth" # Ensures smoother transitions
+ ], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ if result.returncode != 0:
+ print("Error in Wav2Lip inference:", result.stderr.decode())
+ return None
+
+ # After inference, free memory
+ gc.collect()
+ torch.cuda.empty_cache()
+ self.print_memory_usage("After Wav2Lip Inference")
+
+ print(f" Output saved at: {result_video}")
+ return result_video
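+
+# Example usage (assumes Wav2Lip/checkpoints/wav2lip_gan.pth is present; the
+# audio path is a placeholder for a file produced by TextToSpeech.synthesize):
+#   AudioVideoSync().sync_audio_video("sample_video/Amicia.mp4", "output_audio/reply.wav")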
+
diff --git a/templates/index.html b/templates/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..cea57dd6e5122836c5eb062d1a1d841d2c3dc013
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,14 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="UTF-8">
+<title>Personalized Avatar Video</title>
+</head>
+<body>
+<h1>Personalized Avatar Video</h1>
+<!-- Avatar thumbnails: `avatars` is passed in by app.py's index route -->
+{% for avatar in avatars %}
+<img src="/thumbnails/{{ avatar }}.jpg" alt="{{ avatar }}" width="150">
+{% endfor %}
+</body>
+</html>
diff --git a/text_to_speech.py b/text_to_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecdc466cc818a3e8943a6e2c476499499c4a5267
--- /dev/null
+++ b/text_to_speech.py
@@ -0,0 +1,52 @@
+import os
+from TTS.api import TTS
+from datetime import datetime
+
+class TextToSpeech:
+ def __init__(self, model_name="tts_models/en/ljspeech/vits", device="cpu"):
+ """
+ Initialize the TTS model.
+ :param model_name: The name of the TTS model to use.
+ :param device: The device to run the model on ("cuda" for GPU, "cpu" for CPU).
+ """
+ self.model_name = model_name
+ self.device = device
+ self.tts = TTS(model_name).to(device)
+
+ def synthesize(self, text, output_dir="output_audio"):
+ """
+ Convert text to speech and save as a .wav file.
+ :param text: The text to convert to speech.
+ :param output_dir: Directory to save the audio file.
+ :return: The path of the saved audio file.
+ """
+ # Allow output_dir to be either a directory or a full .wav file path
+ if output_dir.endswith(".wav"):
+ output_file = output_dir
+ os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
+ else:
+ os.makedirs(output_dir, exist_ok=True)
+ # Name the file with a timestamp so repeated calls don't collide
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_file = os.path.join(output_dir, f"tts_output_{timestamp}.wav")
+
+ # Generate speech with custom parameters
+ self.tts.tts_to_file(
+ text=text,
+ file_path=output_file,
+ speed=1.4, # Adjust speed (1.0 = normal)
+ noise_scale=0.8, # Control expressiveness
+ noise_scale_w=0.5, # Control speech rhythm variation
+ length_scale=1.1 # Control speech speed and pauses
+ )
+
+ print(f"Speech synthesis complete. Saved as {output_file}.")
+ return output_file
+
+# Example usage
+""" if __name__ == "__main__":
+ tts = TextToSpeech()
+ text_input = "Hello! This is a test for text-to-speech conversion."
+
+ output_path = tts.synthesize(text_input)
+ print("\nGenerated Audio File:", output_path) """
diff --git a/thumbnails/Amicia.jpg b/thumbnails/Amicia.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..be96da71ea0ca29d4943fad9a52b083223bf14c7
--- /dev/null
+++ b/thumbnails/Amicia.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ef8d32bac35aa41c535ef4470b4c03e7278527fa4e2a8f98c830a90b400230a
+size 124232
diff --git a/thumbnails/Claire.jpg b/thumbnails/Claire.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df4599839790e901fd6566f587370be299bafc16
--- /dev/null
+++ b/thumbnails/Claire.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4ddc65b378a4d7604c1f30f46f7c331b2a68313b0da52a1f6e2c05f94250982
+size 143992
diff --git a/thumbnails/Gina.jpg b/thumbnails/Gina.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b8a5098cfd90e628f833d9f6aa07a0e5f8a53de
--- /dev/null
+++ b/thumbnails/Gina.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c531a66aa4938693338d1af44adeab46236a0ac6fa98eafeecc5cc2d2f5e270
+size 110105
diff --git a/thumbnails/Ruth.jpg b/thumbnails/Ruth.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0bb86a45d28d88878c37ea71295295d21b973867
--- /dev/null
+++ b/thumbnails/Ruth.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bedca2deaac0215fd0ff03087a13eb8784653a93dd91bf19b49e8e9f3a90d20
+size 120337
diff --git a/thumbnails/Samantha.jpg b/thumbnails/Samantha.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1e6f4c7bebec1778d316fbffd1b65740973e2730
--- /dev/null
+++ b/thumbnails/Samantha.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0920a19a1c0cbe0e35e963039376dd8b632fe8068d71d9138e11263c8927418
+size 140448
diff --git a/whisper_tts.py b/whisper_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..93763193ab837f88819d2225709d14f4593a7d8b
--- /dev/null
+++ b/whisper_tts.py
@@ -0,0 +1,49 @@
+import os
+import whisper
+from gtts import gTTS
+import soundfile as sf
+
+class WhisperTTS:
+ def __init__(self, model_size="base"):
+ """Initialize the Whisper model."""
+ self.model = whisper.load_model(model_size)
+
+ def transcribe_audio(self, input_audio, output_text_file="transcription.txt"):
+ """Transcribes audio and saves text."""
+ result = self.model.transcribe(input_audio)
+
+ with open(output_text_file, "w") as f:
+ f.write(result["text"])
+
+ print("\nTranscription Saved:", output_text_file)
+ return result["text"]
+
+ def text_to_speech(self, text, output_audio="output.wav"):
+ """Converts transcribed text to speech and saves it as WAV."""
+ tts = gTTS(text, lang="en") # Convert text to speech
+ tts.save("temp.mp3") # Save as temporary MP3
+
+ # Convert MP3 to WAV (soundfile needs libsndfile >= 1.1.0 for MP3 decoding)
+ data, samplerate = sf.read("temp.mp3")
+ sf.write(output_audio, data, samplerate)
+ os.remove("temp.mp3") # clean up the temporary MP3
+
+ print("\nTTS Audio Saved:", output_audio)
+
+ def process_audio(self, input_audio):
+ """Full pipeline: Transcribe and generate speech."""
+ transcribed_text = self.transcribe_audio(input_audio)
+ print("\nTranscribed Text:\n", transcribed_text)
+
+ output_wav = "transcribed_audio.wav"
+ self.text_to_speech(transcribed_text, output_wav)
+
+ return transcribed_text, output_wav
+
+# Usage Example
+""" if __name__ == "__main__":
+ whisper_tts = WhisperTTS()
+
+ input_audio_file = "sample_audio/signal-2025-03-29-153916.mp3" # Change this to your actual file
+ text, wav_file = whisper_tts.process_audio(input_audio_file)
+
+ print("\nFinal Output:\nText File: transcription.txt\nWAV File:", wav_file)
+ """
\ No newline at end of file