import cv2
import librosa
import mediapipe as mp
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from numpy.typing import NDArray
from python_speech_features import logfbank
from transformers import FeatureExtractionMixin
from transformers.feature_extraction_utils import BatchFeature

mp_face_mesh = mp.solutions.face_mesh


class AVHubertFeatureExtractor(FeatureExtractionMixin):
    """Prepares stacked log-filterbank audio features and grayscale mouth-region video frames for AV-HuBERT."""

    model_input_names = ["input_values", "pixel_values"]

    def __init__(
        self,
        max_sample_size: int | None = None,
        normalize: bool = True,
        stack_order_audio: int = 4,
        image_crop_size: int = 88,
        image_mean: float = 0.421,
        image_std: float = 0.165,
        sr: int = 16_000,
        static_image_mode: bool = False,
        refine_landmarks: bool = False,
        min_detection_confidence: float = 0.5,
        min_tracking_confidence: float = 0.5,
        landmark_indices: tuple[int, ...] = (5, 411, 199, 187),  # (top, right, bottom, left) of mouth
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.max_sample_size = max_sample_size
        self.normalize = normalize
        self.stack_order_audio = stack_order_audio
        self.image_crop_size = image_crop_size
        self.transforms = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.CenterCrop(image_crop_size),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize([image_mean], [image_std]),
            ]
        )
        self.sr = sr
        self.static_image_mode = static_image_mode
        self.refine_landmarks = refine_landmarks
        self.min_detection_confidence = min_detection_confidence
        self.min_tracking_confidence = min_tracking_confidence
        self.landmark_indices = landmark_indices

    def _load_video(self, video: str | NDArray[np.uint8], extract_mouth: bool = False) -> torch.Tensor:
        """Input video must be in RGB format if type is numpy array."""
        if isinstance(video, str):
            cap = cv2.VideoCapture(video)
            frames = []
            for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
                ret, frame = cap.read()
                if not ret:
                    break
                if not extract_mouth:  # Mouth region already extracted: keep frames as grayscale
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
                else:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            cap.release()
            frames_np = np.stack(frames, axis=0)
        else:
            frames_np = video
            if not extract_mouth:  # Mouth region already extracted: convert the RGB frames to grayscale
                frames_np = np.stack([cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames_np], axis=0)
        if extract_mouth:
            frames_np = self._extract_mouth(frames_np)
        return torch.from_numpy(frames_np).unsqueeze(dim=1)

    def _extract_mouth(self, frames: NDArray[np.uint8]) -> NDArray[np.uint8]:
        mouth_frames = []
        top_idx, right_idx, bottom_idx, left_idx = self.landmark_indices
        with mp_face_mesh.FaceMesh(
            static_image_mode=self.static_image_mode,
            max_num_faces=1,
            refine_landmarks=self.refine_landmarks,
            min_detection_confidence=self.min_detection_confidence,
            min_tracking_confidence=self.min_tracking_confidence,
        ) as face_mesh:
            for frame in frames:
                res = face_mesh.process(frame)
                if res.multi_face_landmarks is None or len(res.multi_face_landmarks) == 0:
                    # No face detected: fall back to an all-black crop
                    mouth_frames.append(np.zeros([self.image_crop_size, self.image_crop_size], dtype=np.uint8))
                    continue
                landmarks = res.multi_face_landmarks[0].landmark
                top = landmarks[top_idx]
                left = landmarks[left_idx]
                right = landmarks[right_idx]
                bottom = landmarks[bottom_idx]
                H, W = frame.shape[:2]
                xmax = max(top.x, left.x, right.x, bottom.x)
                ymax = max(top.y, left.y, right.y, bottom.y)
                xmin = min(top.x, left.x, right.x, bottom.x)
                ymin = min(top.y, left.y, right.y, bottom.y)
                patch_size = max((xmax - xmin) * W, (ymax - ymin) * H)  # To extract square region
                half = int(patch_size / 2)
                y_center = int(ymin * H) + int(((ymax - ymin) / 2) * H)
                x_center = int(xmin * W) + int(((xmax - xmin) / 2) * W)
                lip = frame[
                    y_center - half : y_center + half,
                    x_center - half : x_center + half,
                    :,
                ]
                try:
                    lip = cv2.resize(lip, (self.image_crop_size, self.image_crop_size))
                except Exception:
                    # Empty or out-of-bounds crop: fall back to an all-black patch
                    lip = np.zeros([self.image_crop_size, self.image_crop_size, 3], dtype=np.uint8)
                mouth_frames.append(cv2.cvtColor(lip, cv2.COLOR_RGB2GRAY))
        return np.stack(mouth_frames, axis=0)

    def _load_audio(self, audio: str | NDArray[np.float32]) -> torch.FloatTensor:
        def stacker(feats, stack_order):
            # Zero-pad to a multiple of `stack_order`, then stack consecutive frames along the feature axis
            feat_dim = feats.shape[1]
            if len(feats) % stack_order != 0:
                res = stack_order - len(feats) % stack_order
                res = np.zeros([res, feat_dim]).astype(feats.dtype)
                feats = np.concatenate([feats, res], axis=0)
            feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order * feat_dim)
            return feats

        sr = None
        if isinstance(audio, str):
            audio, sr = librosa.load(audio, sr=self.sr)
        if sr is None:
            sr = self.sr
        fbank = logfbank(audio, samplerate=sr).astype(np.float32)
        fbank = stacker(fbank, self.stack_order_audio)
        return torch.from_numpy(fbank)

    def _align_time_steps(
        self, audio: list[torch.FloatTensor], video: list[torch.FloatTensor]
    ) -> tuple[list[torch.FloatTensor], list[torch.FloatTensor]]:
        aligned_indices = []
        for sample_audio, sample_video in zip(audio, video):
            diff = len(sample_audio) - len(sample_video)
            if diff != 0:
                # Resample video indices so every audio frame maps onto a video frame
                aligned_indices.append(
                    torch.arange(0, len(sample_audio)).float() * len(sample_video) / len(sample_audio)
                )
            else:
                aligned_indices.append(torch.arange(0, len(sample_audio)))
        return (
            audio,
            [
                sample[torch.clamp(torch.floor(indices), max=sample.shape[0] - 1).long()]
                for sample, indices in zip(video, aligned_indices)
            ],
        )

    def __call__(
        self,
        raw_audio: NDArray[np.float32] | str | list[NDArray[np.float32]] | list[str] | None = None,
        raw_video: NDArray[np.uint8] | str | list[NDArray[np.uint8]] | list[str] | None = None,
        extract_mouth: bool = False,
        **kwargs,
    ) -> BatchFeature:
        if not isinstance(raw_audio, list):
            raw_audio = [raw_audio]
        if not isinstance(raw_video, list):
            raw_video = [raw_video]
        audio = [self._load_audio(sample) if sample is not None else None for sample in raw_audio]
        video = [self._load_video(sample, extract_mouth) if sample is not None else None for sample in raw_video]
        for batch_idx in range(len(audio)):
            sample_a = audio[batch_idx]
            sample_v = video[batch_idx]
            assert sample_a is not None or sample_v is not None
            if sample_a is None:
                # logfbank produces 26 filterbanks, so the stacked audio feature dim is 26 * stack_order_audio
                sample_a = torch.zeros((sample_v.shape[0], 26 * self.stack_order_audio))
                audio[batch_idx] = sample_a
            elif sample_v is None:  # 25 fps
                # uint8 to match real video frames so the later torch.stack does not mix dtypes
                sample_v = torch.zeros(
                    (sample_a.shape[0], 1, self.image_crop_size, self.image_crop_size), dtype=torch.uint8
                )
                video[batch_idx] = sample_v
        audio, video = self._align_time_steps(audio, video)
        max_length = max(len(data) for data in audio)
        input_values = []
        pixel_values = []
        padding_mask = []
        for feat_audio, feat_video in zip(audio, video):
            remainder_length = max_length - len(feat_audio)
            audio_remainder = torch.zeros(
                size=(remainder_length,) + feat_audio.size()[1:],
                dtype=feat_audio.dtype,
            )
            video_remainder = torch.zeros(
                size=(remainder_length,) + feat_video.size()[1:],
                dtype=feat_video.dtype,
            )
            feat_audio = torch.cat((feat_audio, audio_remainder))
            feat_video = torch.cat((feat_video, video_remainder))
            pad_mask = torch.zeros(max_length)
            pad_mask[max_length - remainder_length :] = 1
            if self.max_sample_size:
                # Truncate the padding mask together with the features so their lengths stay in sync
                feat_audio = feat_audio[: self.max_sample_size]
                feat_video = feat_video[: self.max_sample_size]
                pad_mask = pad_mask[: self.max_sample_size]
            input_values.append(feat_audio)
            pixel_values.append(feat_video)
            padding_mask.append(pad_mask)
        input_values = torch.stack(input_values)
        batch = BatchFeature(
            {
                "input_values": (
                    F.layer_norm(input_values, input_values.shape[2:]) if self.normalize else input_values
                ),
                "pixel_values": self.transforms(torch.stack(pixel_values)),
                "padding_mask": torch.stack(padding_mask),
            }
        )
        return batch

    def to_dict(self):
        output = super().to_dict()
        output["transforms"] = self._transforms_to_dict(output["transforms"])
        return output

    def _transforms_to_dict(self, composed_transforms: transforms.Compose):
        output = []
        for component in composed_transforms.transforms:
            name = component.__class__.__name__
            component_dict = {"transforms_type": name}
            for k, v in component.__dict__.items():
                if k.startswith("_"):
                    continue
                component_dict[k] = str(v)
            output.append(component_dict)
        return output
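

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the extractor API): it
# assumes mediapipe, python_speech_features, librosa, and torchvision >= 0.16
# are installed, and pushes random data through the extractor to check output
# shapes. The dummy sizes (1 s of 16 kHz audio, 25 pre-cropped 88x88 RGB mouth
# frames at 25 fps) are assumptions chosen for the demo, not values the class
# requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = AVHubertFeatureExtractor()

    dummy_audio = np.random.randn(16_000).astype(np.float32)  # 1 s waveform at 16 kHz
    dummy_video = np.random.randint(0, 256, size=(25, 88, 88, 3), dtype=np.uint8)  # pre-cropped RGB mouth frames

    features = extractor(raw_audio=[dummy_audio], raw_video=[dummy_video], extract_mouth=False)
    print(features["input_values"].shape)  # (batch, time, 26 * stack_order_audio)
    print(features["pixel_values"].shape)  # (batch, time, 1, image_crop_size, image_crop_size)
    print(features["padding_mask"].shape)  # (batch, time)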