|
|
|
|
| import torch
|
| import numpy
|
| import time, pdb, argparse, subprocess, os, math, glob
|
| import cv2
|
| import python_speech_features
|
|
|
| from scipy import signal
|
| from scipy.io import wavfile
|
| from .syncnet import S
|
| from shutil import rmtree
|
| from latentsync.utils.util import check_model_and_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def calc_pdist(feat1, feat2, vshift=10):
|
| win_size = vshift * 2 + 1
|
|
|
| feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift))
|
|
|
| dists = []
|
|
|
| for i in range(0, len(feat1)):
|
|
|
| dists.append(
|
| torch.nn.functional.pairwise_distance(feat1[[i], :].repeat(win_size, 1), feat2p[i : i + win_size, :])
|
| )
|
|
|
| return dists
|
|
|
|
|
|
|
|
|
|
|
| class SyncNetEval(torch.nn.Module):
|
| def __init__(self, dropout=0, num_layers_in_fc_layers=1024, device="cpu"):
|
| super().__init__()
|
|
|
| self.__S__ = S(num_layers_in_fc_layers=num_layers_in_fc_layers).to(device)
|
| self.device = device
|
|
|
| def evaluate(self, video_path, temp_dir="temp", batch_size=20, vshift=15):
|
|
|
| self.__S__.eval()
|
|
|
|
|
|
|
|
|
|
|
| if os.path.exists(temp_dir):
|
| rmtree(temp_dir)
|
|
|
| os.makedirs(temp_dir)
|
|
|
|
|
|
|
|
|
|
|
| command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -f image2 {os.path.join(temp_dir, '%06d.jpg')}"
|
| subprocess.call(command, shell=True, stdout=None)
|
|
|
| command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(temp_dir, 'audio.wav')}"
|
| subprocess.call(command, shell=True, stdout=None)
|
|
|
|
|
|
|
|
|
|
|
| images = []
|
|
|
| flist = glob.glob(os.path.join(temp_dir, "*.jpg"))
|
| flist.sort()
|
|
|
| for fname in flist:
|
| img_input = cv2.imread(fname)
|
| img_input = cv2.resize(img_input, (224, 224))
|
| images.append(img_input)
|
|
|
| im = numpy.stack(images, axis=3)
|
| im = numpy.expand_dims(im, axis=0)
|
| im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
|
|
| imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
|
|
|
|
|
|
|
|
|
|
| sample_rate, audio = wavfile.read(os.path.join(temp_dir, "audio.wav"))
|
| mfcc = zip(*python_speech_features.mfcc(audio, sample_rate))
|
| mfcc = numpy.stack([numpy.array(i) for i in mfcc])
|
|
|
| cc = numpy.expand_dims(numpy.expand_dims(mfcc, axis=0), axis=0)
|
| cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| min_length = min(len(images), math.floor(len(audio) / 640))
|
|
|
|
|
|
|
|
|
|
|
| lastframe = min_length - 5
|
| im_feat = []
|
| cc_feat = []
|
|
|
| tS = time.time()
|
| for i in range(0, lastframe, batch_size):
|
|
|
| im_batch = [imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + batch_size))]
|
| im_in = torch.cat(im_batch, 0)
|
| im_out = self.__S__.forward_lip(im_in.to(self.device))
|
| im_feat.append(im_out.data.cpu())
|
|
|
| cc_batch = [
|
| cct[:, :, :, vframe * 4 : vframe * 4 + 20] for vframe in range(i, min(lastframe, i + batch_size))
|
| ]
|
| cc_in = torch.cat(cc_batch, 0)
|
| cc_out = self.__S__.forward_aud(cc_in.to(self.device))
|
| cc_feat.append(cc_out.data.cpu())
|
|
|
| im_feat = torch.cat(im_feat, 0)
|
| cc_feat = torch.cat(cc_feat, 0)
|
|
|
|
|
|
|
|
|
|
|
| dists = calc_pdist(im_feat, cc_feat, vshift=vshift)
|
| mean_dists = torch.mean(torch.stack(dists, 1), 1)
|
|
|
| min_dist, minidx = torch.min(mean_dists, 0)
|
|
|
| av_offset = vshift - minidx
|
| conf = torch.median(mean_dists) - min_dist
|
|
|
| fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
|
|
|
| fconf = torch.median(mean_dists).numpy() - fdist
|
| framewise_conf = signal.medfilt(fconf, kernel_size=9)
|
|
|
|
|
| rmtree(temp_dir)
|
| return av_offset.item(), min_dist.item(), conf.item()
|
|
|
| def extract_feature(self, opt, videofile):
|
|
|
| self.__S__.eval()
|
|
|
|
|
|
|
|
|
| cap = cv2.VideoCapture(videofile)
|
|
|
| frame_num = 1
|
| images = []
|
| while frame_num:
|
| frame_num += 1
|
| ret, image = cap.read()
|
| if ret == 0:
|
| break
|
|
|
| images.append(image)
|
|
|
| im = numpy.stack(images, axis=3)
|
| im = numpy.expand_dims(im, axis=0)
|
| im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
|
|
| imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
|
|
|
|
|
|
|
|
|
|
| lastframe = len(images) - 4
|
| im_feat = []
|
|
|
| tS = time.time()
|
| for i in range(0, lastframe, opt.batch_size):
|
|
|
| im_batch = [
|
| imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + opt.batch_size))
|
| ]
|
| im_in = torch.cat(im_batch, 0)
|
| im_out = self.__S__.forward_lipfeat(im_in.to(self.device))
|
| im_feat.append(im_out.data.cpu())
|
|
|
| im_feat = torch.cat(im_feat, 0)
|
|
|
|
|
|
|
|
|
|
|
| print("Compute time %.3f sec." % (time.time() - tS))
|
|
|
| return im_feat
|
|
|
| def loadParameters(self, path):
|
| check_model_and_download(path)
|
| loaded_state = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)
|
|
|
| self_state = self.__S__.state_dict()
|
|
|
| for name, param in loaded_state.items():
|
|
|
| self_state[name].copy_(param)
|
|
|