"""Video preprocessing for LipNet: lip-region extraction and normalization."""
import os
import logging
from pathlib import Path
from typing import Iterable, List, Optional
import cv2
import numpy as np
import tensorflow as tf
import os
# Optional-dependency probe: mediapipe is imported once at module load and the
# outcome is cached, so callers get a clear, lazy error from
# _require_face_mesh_module instead of an ImportError at import time.
_mp_import_error = None
mp_solutions = None
try:
    import mediapipe as mp  # keep this for version/file debug
    try:
        # NEW: works even when mp.solutions is not exposed
        from mediapipe.python import solutions as mp_solutions  # type: ignore
    except Exception:
        # fallback for older layouts
        from mediapipe import solutions as mp_solutions  # type: ignore
except Exception as exc:
    # Remember the original failure so it can be surfaced in the runtime error.
    _mp_import_error = exc
    mp_solutions = None
# Opt-in diagnostics: set DEBUG_MEDIAPIPE=1 to print which mediapipe install
# and which solutions module were actually resolved.
if os.getenv("DEBUG_MEDIAPIPE", "0") == "1":
    try:
        import mediapipe as mp
        print("mediapipe version:", getattr(mp, "__version__", "unknown"))
        print("mediapipe file:", getattr(mp, "__file__", "unknown"))
        print("has solutions attr:", hasattr(mp, "solutions"))
        # also verify the actual module we will use:
        print("mp_solutions module:", getattr(mp_solutions, "__name__", None))
    except Exception as dbg_exc:
        print("mediapipe debug import failed:", dbg_exc)
# ------------------------------------------------------------------
# Local imports
# ------------------------------------------------------------------
from . import config
logger = logging.getLogger(__name__)
class VideoPreprocessor:
    """
    Extract and normalize lip-region frames from a video file or from an
    iterable of pre-captured BGR frames (OpenCV channel order).

    Both public entry points return a float32 tensor of shape
    (num_frames, target_size, target_size, 1), standardized to zero mean and
    unit variance over the whole clip, or ``None`` when no usable frames are
    found.
    """

    def __init__(
        self,
        target_size: int = config.TARGET_SIZE,
        max_frames: Optional[int] = config.MAX_FRAMES,
        detection_confidence: float = config.DETECTION_CONFIDENCE,
        tracking_confidence: float = config.TRACKING_CONFIDENCE,
    ):
        # Side length of the square lip crop fed to the model.
        self.target_size = target_size
        # Hard cap on extracted frames; None (or 0) disables the cap.
        self.max_frames = max_frames
        self.detection_confidence = detection_confidence
        self.tracking_confidence = tracking_confidence
        # MediaPipe Face Mesh landmark indices for the lip contour
        # (index 291 appears in both lists, matching the upstream topology).
        self.UPPER_LIP_INDICES = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291]
        self.LOWER_LIP_INDICES = [146, 91, 181, 84, 17, 314, 405, 321, 375, 291]
        self.LIP_INDICES = self.UPPER_LIP_INDICES + self.LOWER_LIP_INDICES

    def _require_face_mesh_module(self):
        """Return mediapipe's face_mesh solution module or raise a clear error.

        Raises:
            RuntimeError: if mediapipe failed to import at module load time;
                the original import error is included in the message.
        """
        if mp_solutions is None:
            raise RuntimeError(
                "Mediapipe is not installed correctly. "
                "Please install with `pip install mediapipe` (>=0.10). "
                f"Original import error: {_mp_import_error}"
            )
        return mp_solutions.face_mesh

    def _create_face_mesh(self):
        """Build a configured FaceMesh context manager.

        Shared by ``preprocess_video`` and ``preprocess_frames`` so the two
        entry points cannot drift apart in their detector settings.
        """
        face_mesh_module = self._require_face_mesh_module()
        return face_mesh_module.FaceMesh(
            static_image_mode=False,  # video mode: track landmarks across frames
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=self.detection_confidence,
            min_tracking_confidence=self.tracking_confidence,
        )

    def preprocess_video(self, video_path: str) -> Optional[tf.Tensor]:
        """
        Preprocess frames from a video file path.
        Returns a normalized tensor of shape (num_frames, target_size, target_size, 1)
        or None if no usable frames are found.
        """
        path = Path(video_path)
        if not path.exists():
            logger.error("Video path does not exist: %s", video_path)
            return None
        cap = cv2.VideoCapture(str(path))
        if not cap.isOpened():
            logger.error("Failed to open video: %s", video_path)
            return None
        frames: List[tf.Tensor] = []
        try:
            with self._create_face_mesh() as face_mesh:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    processed = self._process_frame(frame, face_mesh)
                    if processed is not None:
                        frames.append(processed)
                        if self.max_frames and len(frames) >= self.max_frames:
                            logger.info("Reached max_frames=%s; stopping early.", self.max_frames)
                            break
        finally:
            # Always release the capture, even if FaceMesh setup raised.
            cap.release()
        return self._finalize_frames(frames)

    def preprocess_frames(self, frames: Iterable[np.ndarray]) -> Optional[tf.Tensor]:
        """
        Preprocess frames that have already been captured (e.g., from a webcam).

        Frames are expected in OpenCV BGR order, same as ``preprocess_video``.
        """
        processed_frames: List[tf.Tensor] = []
        with self._create_face_mesh() as face_mesh:
            for frame in frames:
                processed = self._process_frame(frame, face_mesh)
                if processed is not None:
                    processed_frames.append(processed)
                    if self.max_frames and len(processed_frames) >= self.max_frames:
                        logger.info("Reached max_frames=%s; stopping early.", self.max_frames)
                        break
        return self._finalize_frames(processed_frames)

    def _process_frame(self, frame: np.ndarray, face_mesh) -> Optional[tf.Tensor]:
        """
        Run landmark detection on a single BGR frame and return a grayscale
        lip crop of shape (target_size, target_size, 1), or None when no face
        is detected or the bounding box is degenerate.
        """
        try:
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(rgb_frame)
            if not results.multi_face_landmarks:
                logger.debug("No face landmarks detected in frame.")
                return None
            face_landmarks = results.multi_face_landmarks[0]
            lip_landmarks = [face_landmarks.landmark[i] for i in self.LIP_INDICES]
            # Landmarks are normalized [0, 1]; scale to pixel coordinates.
            h, w, _ = frame.shape
            x_coords = [int(landmark.x * w) for landmark in lip_landmarks]
            y_coords = [int(landmark.y * h) for landmark in lip_landmarks]
            x_min, x_max = max(0, min(x_coords)), min(w, max(x_coords))
            y_min, y_max = max(0, min(y_coords)), min(h, max(y_coords))
            if x_max <= x_min or y_max <= y_min:
                logger.debug("Invalid lip bounding box; skipping frame.")
                return None
            # FIX: crop from the RGB image — tf.image.rgb_to_grayscale assumes
            # RGB channel order; cropping the original BGR frame swapped the
            # red/blue luminance weights.
            lip_frame = rgb_frame[y_min:y_max, x_min:x_max]
            lip_frame_resized = cv2.resize(lip_frame, (self.target_size, self.target_size))
            return tf.image.rgb_to_grayscale(lip_frame_resized)
        except Exception as exc:
            # Best-effort per-frame processing: a bad frame is skipped, not fatal.
            logger.warning("Error processing frame: %s", exc)
            return None

    def _finalize_frames(self, frames: List[tf.Tensor]) -> Optional[tf.Tensor]:
        """Stack per-frame crops and standardize the clip to zero mean / unit std."""
        if not frames:
            logger.error("No frames extracted during preprocessing.")
            return None
        stacked = tf.cast(tf.stack(frames), tf.float32)
        mean = tf.math.reduce_mean(stacked)
        std = tf.math.reduce_std(stacked)
        # Guard against division by zero on constant clips. tf.where keeps the
        # exact "std == 0 -> 1.0" semantics but, unlike a Python truth-test on
        # a tensor, also works outside eager execution.
        std = tf.where(tf.math.equal(std, 0.0), tf.constant(1.0, dtype=tf.float32), std)
        return (stacked - mean) / std