from typing import Iterable, Iterator, List, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model

class Wav2LipInferenceImpl:
    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
        self.model: nn.Module = load_model(
            model_name,
            device=device,
            **hp_inference_model
        )
        self.device = device
        self._params: str = self._format_param(count_params(self.model))

    @property
    def params(self):
        return self._params

    @staticmethod
    def _format_param(num_params: int) -> str:
        params_in_million = num_params / 1e6
        return f"{params_in_million:.1f}M"

    @staticmethod
    def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
        return [], [], [], []

    def get_data_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
        img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        for i, m in enumerate(audio_iterable):
            # Cycle over the video frames so the audio can be longer than the clip.
            idx = i % len(video_iterable)
            _frame_to_save, coords = video_iterable[idx]
            frame_to_save = _frame_to_save.copy()

            # Crop the detected face region and resize it to the model input size.
            face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()
            face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))

            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)

            if len(img_batch) >= hp.inference.batch_size:
                img_batch = np.asarray(img_batch)
                mel_batch = np.asarray(mel_batch)

                # Zero out the lower half of each face crop and stack it with the
                # original crop along the channel axis; the model fills in the mouth.
                img_masked = img_batch.copy()
                img_masked[:, hp.face.img_size // 2:] = 0
                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                yield img_batch, mel_batch, frame_batch, coords_batch
                img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        # Flush the final, possibly smaller, batch.
        if len(img_batch) > 0:
            img_batch = np.asarray(img_batch)
            mel_batch = np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, hp.face.img_size // 2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch

    @torch.no_grad()
    def inference_with_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[np.ndarray]:
        data_iterator = self.get_data_iterator(audio_iterable, video_iterable)

        # len() is needed for the progress bar, so audio_iterable must be a sized
        # container (e.g. a list of mel chunks), not a lazy generator.
        for (img_batch, mel_batch, frames, coords) in \
                tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):
            # NHWC float arrays -> NCHW tensors on the target device.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)

            preds: torch.Tensor = self.forward(mel_batch, img_batch)
            preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            # Paste each generated face back into its original frame at the face box.
            for pred, frame, coord in zip(preds, frames, coords):
                y1, y2, x1, x2 = coord
                pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
                frame[y1:y2, x1:x2] = pred
                yield frame

    def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
        return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
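

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this class might be driven. The model name
# ("wav2lip"), the config key (hp.inference.model), and the mel-chunk shape
# (80, 16) are assumptions, not taken from this file; adapt them to the actual
# checkpoint, config, and audio front end in use.
if __name__ == "__main__":
    wav2lip = Wav2LipInferenceImpl(
        model_name="wav2lip",                   # assumed model registry key
        hp_inference_model=hp.inference.model,  # assumed config location
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Dummy inputs: mel chunks plus full frames paired with face boxes (y1, y2, x1, x2).
    mel_chunks = [np.random.rand(80, 16).astype(np.float32) for _ in range(32)]
    frames_with_boxes = [
        (np.zeros((480, 640, 3), dtype=np.uint8), [100, 260, 200, 360])
        for _ in range(8)
    ]

    for out_frame in wav2lip.inference_with_iterator(mel_chunks, frames_with_boxes):
        pass  # each out_frame is a full frame with the generated mouth pasted back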