Spaces:
Build error
Build error
| """ | |
| Wav2Lip Inference Module | |
| Complete implementation for generating lip-sync videos from face images/videos and audio. | |
| """ | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import librosa | |
| import subprocess | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional | |
| logger = logging.getLogger(__name__) | |
class Wav2LipInference:
    """Wav2Lip inference handler with face detection and audio processing."""

    def __init__(self, checkpoint_path: str, device: str = None):
        # Honor an explicit device request; otherwise prefer CUDA when present.
        if device is not None:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.checkpoint_path = checkpoint_path
        self.model = None          # populated lazily by load_model()
        self.face_detector = None  # populated lazily by _init_face_detector()
        self.img_size = 96         # square face-crop resolution fed to the model
    def load_model(self):
        """Load the Wav2Lip model from the configured checkpoint.

        Idempotent: returns immediately if a model is already loaded.
        Falls back to a tiny dummy network when either the project-local
        Wav2Lip modules or the checkpoint file are unavailable, so the rest
        of the pipeline can still run end-to-end. Always (re)initializes the
        face detector after a full load attempt.

        NOTE(review): a corrupt checkpoint (load_state_dict failure) is NOT
        caught here and will propagate — confirm whether that is intended.
        """
        if self.model is not None:
            return
        try:
            # Import Wav2Lip model architecture.
            # These are project-local modules; an ImportError here triggers
            # the dummy fallback in the except branch below.
            from models.wav2lip import Wav2Lip
            from models.face_detection import FaceDetection
            # Initialize model
            self.model = Wav2Lip()
            # Load checkpoint
            if os.path.exists(self.checkpoint_path):
                checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
                # Some checkpoints wrap the weights under a 'state_dict' key.
                state_dict = checkpoint.get('state_dict', checkpoint)
                self.model.load_state_dict(state_dict)
                self.model = self.model.to(self.device)
                self.model.eval()
                logger.info(f"Loaded Wav2Lip model from {self.checkpoint_path}")
            else:
                logger.warning(f"Checkpoint not found: {self.checkpoint_path}")
                # Create dummy model for testing
                self._create_dummy_model()
        except ImportError:
            logger.warning("Wav2Lip models not found, using dummy implementation")
            self._create_dummy_model()
        # Initialize face detector
        self._init_face_detector()
| def _create_dummy_model(self): | |
| """Create a dummy model for testing when real model unavailable.""" | |
| self.model = torch.nn.Sequential( | |
| torch.nn.Conv2d(6, 64, 3, padding=1), | |
| torch.nn.ReLU(), | |
| torch.nn.Conv2d(64, 3, 3, padding=1), | |
| torch.nn.Sigmoid() | |
| ).to(self.device).eval() | |
| def _init_face_detector(self): | |
| """Initialize face detection model.""" | |
| try: | |
| # Try to use dlib or mediapipe for face detection | |
| self.face_detector = cv2.dnn.readNetFromCaffe( | |
| "models/deploy.prototxt", | |
| "models/res10_300x300_ssd_iter_140000.caffemodel" | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Could not load DNN face detector: {e}") | |
| # Fallback to Haar cascade | |
| self.face_cascade = cv2.CascadeClassifier( | |
| cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' | |
| ) | |
| self.face_detector = None | |
| def detect_faces(self, image: np.ndarray, confidence_threshold: float = 0.5) -> List[Tuple[int, int, int, int]]: | |
| """Detect faces in image and return bounding boxes.""" | |
| h, w = image.shape[:2] | |
| if self.face_detector is not None: | |
| # DNN face detection | |
| blob = cv2.dnn.blobFromImage( | |
| cv2.resize(image, (300, 300)), 1.0, | |
| (300, 300), (104.0, 177.0, 123.0) | |
| ) | |
| self.face_detector.setInput(blob) | |
| detections = self.face_detector.forward() | |
| faces = [] | |
| for i in range(detections.shape[2]): | |
| confidence = detections[0, 0, i, 2] | |
| if confidence > confidence_threshold: | |
| box = detections[0, 0, i, 3:7] * np.array([w, h, w, h]) | |
| (x1, y1, x2, y2) = box.astype("int") | |
| faces.append((x1, y1, x2, y2)) | |
| return faces | |
| else: | |
| # Haar cascade fallback | |
| gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |
| faces = self.face_cascade.detectMultiScale( | |
| gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30) | |
| ) | |
| return [(x, y, x+w, y+h) for (x, y, w, h) in faces] | |
| def extract_audio_features(self, audio_path: str) -> np.ndarray: | |
| """Extract mel-spectrogram features from audio.""" | |
| # Load audio | |
| wav, sr = librosa.load(audio_path, sr=16000) | |
| # Normalize | |
| wav = wav / np.abs(wav).max() * 0.9 | |
| # Compute mel spectrogram | |
| mel_spec = librosa.feature.melspectrogram( | |
| y=wav, | |
| sr=16000, | |
| n_fft=800, | |
| hop_length=200, | |
| n_mels=80 | |
| ) | |
| # Convert to log scale | |
| mel_spec = librosa.power_to_db(mel_spec, ref=np.max) | |
| # Normalize | |
| mel_spec = (mel_spec + 40) / 40 # Rough normalization | |
| return mel_spec, wav, sr | |
| def preprocess_face(self, face_img: np.ndarray, img_size: int = 96) -> np.ndarray: | |
| """Preprocess face image for model input.""" | |
| # Resize | |
| face_img = cv2.resize(face_img, (img_size, img_size)) | |
| # Normalize to [-1, 1] | |
| face_img = face_img.astype(np.float32) / 127.5 - 1.0 | |
| return face_img | |
| def smooth_lip_region(self, frames: List[np.ndarray], window_size: int = 5) -> List[np.ndarray]: | |
| """Apply temporal smoothing to lip region.""" | |
| if len(frames) < window_size: | |
| return frames | |
| smoothed = [] | |
| half_window = window_size // 2 | |
| for i in range(len(frames)): | |
| start = max(0, i - half_window) | |
| end = min(len(frames), i + half_window + 1) | |
| # Average frames in window | |
| window_frames = frames[start:end] | |
| smoothed_frame = np.mean(window_frames, axis=0).astype(np.uint8) | |
| smoothed.append(smoothed_frame) | |
| return smoothed | |
    def generate_lip_sync_frames(
        self,
        face_sequence: List[np.ndarray],
        audio_features: np.ndarray,
        batch_size: int = 64
    ) -> List[np.ndarray]:
        """Generate lip-sync frames with the (possibly dummy) model.

        Args:
            face_sequence: preprocessed face crops, each (H, W, 3) float in [-1, 1].
            audio_features: 2-D mel array whose columns are sliced by
                video-frame index.
            batch_size: number of faces per forward pass.

        Returns:
            List of uint8 frames, one per element of face_sequence.

        NOTE(review): mel columns are indexed with the video-frame counter
        (audio_features[:, i:batch_end]) with no mel-hop-rate/fps conversion,
        so audio/video alignment is only approximate. Also,
        F.pad(mode='replicate') requires a batched (3-D+) input — confirm
        the intended audio_batch rank before wiring in the real model.
        """
        if self.model is None:
            self.load_model()
        generated_frames = []
        num_frames = len(face_sequence)
        # Process in batches
        for i in range(0, num_frames, batch_size):
            batch_end = min(i + batch_size, num_frames)
            batch_faces = face_sequence[i:batch_end]
            # Prepare batch tensors
            face_tensors = []
            for face in batch_faces:
                # Convert HWC numpy to a CHW float tensor
                face_tensor = torch.from_numpy(face).permute(2, 0, 1).float()
                face_tensors.append(face_tensor)
            # Stack and add batch dimension
            if len(face_tensors) > 0:
                face_batch = torch.stack(face_tensors).to(self.device)
                # Get corresponding audio features
                audio_batch = torch.from_numpy(
                    audio_features[:, i:batch_end]
                ).float().to(self.device)
                # Pad audio (replicating the trailing column) when it is
                # shorter than the face batch.
                if audio_batch.shape[1] < face_batch.shape[0]:
                    pad_len = face_batch.shape[0] - audio_batch.shape[1]
                    audio_batch = torch.nn.functional.pad(
                        audio_batch, (0, pad_len), mode='replicate'
                    )
                # Generate lip-sync frames (dummy implementation)
                with torch.no_grad():
                    # In real implementation, this would call the actual Wav2Lip model
                    # For now, simulate lip movement by modifying lower face region
                    output = self._simulate_lip_sync(face_batch, audio_batch)
                # Convert back to numpy: CHW [-1, 1] float -> HWC [0, 255] uint8
                for j in range(output.shape[0]):
                    frame = output[j].permute(1, 2, 0).cpu().numpy()
                    frame = ((frame + 1) * 127.5).clip(0, 255).astype(np.uint8)
                    generated_frames.append(frame)
        return generated_frames
| def _simulate_lip_sync( | |
| self, | |
| face_batch: torch.Tensor, | |
| audio_batch: torch.Tensor | |
| ) -> torch.Tensor: | |
| """Simulate lip-sync by modifying face based on audio energy.""" | |
| # Simple simulation: modify lower half of face based on audio energy | |
| audio_energy = audio_batch.mean(dim=(0, 1)) | |
| output = face_batch.clone() | |
| _, _, h, w = output.shape | |
| # Lower face region (roughly where mouth is) | |
| y_start = h // 2 | |
| for i in range(output.shape[0]): | |
| # Get audio energy for this frame | |
| energy = audio_energy[i % len(audio_energy)] if i < len(audio_energy) else 0.5 | |
| # Simulate mouth opening based on energy | |
| mouth_open = 0.5 + 0.3 * torch.sin(energy * 10) | |
| # Modify lower face region | |
| output[i, :, y_start:, :] = output[i, :, y_start:, :] * mouth_open | |
| return output | |
    def process_video(
        self,
        face_path: str,
        audio_path: str,
        output_path: str,
        static: bool = False,
        fps: float = 25.0,
        resize_factor: int = 1,
        rotate: bool = False,
        nosmooth: bool = False,
        pads: List[int] = None,
        crop: List[int] = None,
        box: List[int] = None,
        face_det_batch_size: int = 8,
        wav2lip_batch_size: int = 64,
        img_size: int = 96
    ) -> str:
        """Main video processing pipeline.

        Loads the face image/video and audio, detects a face box in the first
        frame, generates lip-synced face crops, feather-blends them back into
        the original frames, writes a silent video, then muxes the audio with
        ffmpeg (falling back to the silent video if ffmpeg fails).

        Args:
            face_path: image (.jpg/.jpeg/.png) or video file containing the face.
            audio_path: audio file driving the lip motion.
            output_path: destination path; the temp file name is derived via
                str.replace('.mp4', ...), so a non-.mp4 path would make
                temp_video == output_path — NOTE(review): confirm callers
                always pass .mp4.
            static: accepted but unused here; image inputs are always treated
                as a static frame. (NOTE(review))
            fps: output frame rate; also used to expand a still image.
            resize_factor: integer downscale factor applied to input frames.
            rotate: rotate every frame 90 degrees clockwise before processing.
            nosmooth: skip temporal smoothing of the generated faces.
            pads: [top, bottom, left, right] padding added to the face box.
            crop: accepted but currently unused. (NOTE(review))
            box: explicit [x1, y1, x2, y2] face box; box[0] == -1 means auto-detect.
            face_det_batch_size: accepted but unused — detection runs on the
                first frame only. (NOTE(review))
            wav2lip_batch_size: batch size passed to the generator.
            img_size: square face-crop size fed to the model.

        Returns:
            output_path, the path of the written video.

        Raises:
            ValueError: unreadable input, zero frames, or no face detected.
        """
        self.img_size = img_size
        pads = pads or [0, 10, 0, 0]
        # Load face video/image
        is_image = face_path.lower().endswith(('.jpg', '.jpeg', '.png'))
        if is_image:
            # Single image - create video from static frame
            face_img = cv2.imread(face_path)
            if face_img is None:
                raise ValueError(f"Could not load image: {face_path}")
            # Get audio duration
            mel_spec, wav, sr = self.extract_audio_features(audio_path)
            duration = len(wav) / sr
            # Repeat the still image once per output frame
            num_frames = int(duration * fps)
            frame_sequence = [face_img.copy() for _ in range(num_frames)]
        else:
            # Video file
            cap = cv2.VideoCapture(face_path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {face_path}")
            frame_sequence = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_sequence.append(frame)
            cap.release()
            # Get audio
            mel_spec, wav, sr = self.extract_audio_features(audio_path)
        if len(frame_sequence) == 0:
            raise ValueError("No frames extracted from face input")
        logger.info(f"Processing {len(frame_sequence)} frames")
        # Apply resize factor
        if resize_factor > 1:
            new_frames = []
            for frame in frame_sequence:
                h, w = frame.shape[:2]
                new_frames.append(cv2.resize(
                    frame, (w // resize_factor, h // resize_factor)
                ))
            frame_sequence = new_frames
        # Rotate if needed
        if rotate:
            frame_sequence = [cv2.rotate(f, cv2.ROTATE_90_CLOCKWISE) for f in frame_sequence]
        # Detect faces in the FIRST frame only; the same box is reused for the
        # whole clip (assumes the face stays roughly in place).
        faces = self.detect_faces(frame_sequence[0])
        if len(faces) == 0:
            raise ValueError("No face detected in input")
        # Use first detected face or specified box
        if box and box[0] != -1:
            face_box = tuple(box)
        else:
            face_box = faces[0]
        # Apply padding to the face box, clamped to the frame bounds
        x1, y1, x2, y2 = face_box
        pad_t, pad_b, pad_l, pad_r = pads
        h, w = frame_sequence[0].shape[:2]
        x1 = max(0, x1 - pad_l)
        y1 = max(0, y1 - pad_t)
        x2 = min(w, x2 + pad_r)
        y2 = min(h, y2 + pad_b)
        # Extract face regions
        face_regions = []
        for frame in frame_sequence:
            face_region = frame[y1:y2, x1:x2].copy()
            face_region = self.preprocess_face(face_region, img_size)
            face_regions.append(face_region)
        # Generate lip-sync frames
        logger.info("Generating lip-sync frames...")
        lip_sync_faces = self.generate_lip_sync_frames(
            face_regions, mel_spec, wav2lip_batch_size
        )
        # Apply smoothing if not disabled
        if not nosmooth and len(lip_sync_faces) > 1:
            lip_sync_faces = self.smooth_lip_region(lip_sync_faces)
        # Composite back to original frames
        output_frames = []
        for i, (original, new_face) in enumerate(zip(frame_sequence, lip_sync_faces)):
            result = original.copy()
            # Resize generated face back to original size
            face_h, face_w = y2 - y1, x2 - x1
            new_face_resized = cv2.resize(new_face, (face_w, face_h))
            # Composite with blending for natural look
            # Create mask for smooth blending
            mask = np.ones((face_h, face_w), dtype=np.float32)
            # Feather edges.
            # NOTE(review): assumes the crop exceeds 2*feather in both
            # dimensions; a tiny crop would get a degenerate mask.
            feather = 10
            mask[:feather, :] *= np.linspace(0, 1, feather)[:, None]
            mask[-feather:, :] *= np.linspace(1, 0, feather)[:, None]
            mask[:, :feather] *= np.linspace(0, 1, feather)[None, :]
            mask[:, -feather:] *= np.linspace(1, 0, feather)[None, :]
            mask = mask[:, :, None]
            # Blend
            roi = result[y1:y2, x1:x2]
            blended = (new_face_resized * mask + roi * (1 - mask)).astype(np.uint8)
            result[y1:y2, x1:x2] = blended
            output_frames.append(result)
        # Write output video (without audio first)
        temp_video = output_path.replace('.mp4', '_temp.mp4')
        # Determine output size
        out_h, out_w = output_frames[0].shape[:2]
        # Write video
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_video, fourcc, fps, (out_w, out_h))
        for frame in output_frames:
            out.write(frame)
        out.release()
        # Add audio using ffmpeg (stream-copy video, encode audio as AAC)
        try:
            ffmpeg_cmd = [
                'ffmpeg', '-y',
                '-i', temp_video,
                '-i', audio_path,
                '-c:v', 'copy',
                '-c:a', 'aac',
                '-shortest',
                output_path
            ]
            subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
            os.remove(temp_video)
            logger.info(f"Successfully created: {output_path}")
        except Exception as e:
            logger.warning(f"FFmpeg audio merge failed: {e}")
            # Fallback: just rename temp video (output has no audio track)
            os.rename(temp_video, output_path)
        return output_path
# Module-level cache of Wav2LipInference instances, keyed by
# f"{checkpoint_path}_{img_size}" (see run_inference), so each model
# is constructed at most once per (checkpoint, face-size) combination.
_inference_cache = {}
def run_inference(
    checkpoint_path: str,
    face_path: str,
    audio_path: str,
    output_filename: str,
    static: bool = False,
    fps: float = 25.0,
    resize_factor: int = 1,
    rotate: bool = False,
    nosmooth: bool = False,
    pads: List[int] = None,
    crop: List[int] = None,
    box: List[int] = None,
    face_det_batch_size: int = 8,
    wav2lip_batch_size: int = 64,
    img_size: int = 96
) -> str:
    """
    Run Wav2Lip inference to generate lip-sync video.
    This is the main entry point used by the Streamlit application.
    """
    # Fail fast on missing inputs before touching any model state.
    if not os.path.exists(face_path):
        raise FileNotFoundError(f"Face file not found: {face_path}")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    # Make sure the destination directory exists.
    target_dir = os.path.dirname(output_filename) or '.'
    os.makedirs(target_dir, exist_ok=True)
    logger.info(f"Running inference with model: {checkpoint_path}")
    logger.info(f"Face: {face_path}, Audio: {audio_path}")
    logger.info(f"Output: {output_filename}")
    logger.info(f"Settings: static={static}, fps={fps}, resize={resize_factor}")
    # Reuse one engine per (checkpoint, face-size) combination.
    cache_key = f"{checkpoint_path}_{img_size}"
    engine = _inference_cache.get(cache_key)
    if engine is None:
        engine = Wav2LipInference(checkpoint_path)
        _inference_cache[cache_key] = engine
    # Delegate the heavy lifting to the processing pipeline.
    return engine.process_video(
        face_path=face_path,
        audio_path=audio_path,
        output_path=output_filename,
        static=static,
        fps=fps,
        resize_factor=resize_factor,
        rotate=rotate,
        nosmooth=nosmooth,
        pads=pads,
        crop=crop,
        box=box,
        face_det_batch_size=face_det_batch_size,
        wav2lip_batch_size=wav2lip_batch_size,
        img_size=img_size
    )
def clear_inference_cache():
    """Clear cached inference instances and release cached GPU memory.

    Clears the module-level cache IN PLACE (the original rebound the global,
    which left any externally-held reference to the old dict still keeping
    every engine alive) and, when CUDA is available, returns cached
    allocations to the driver. Also replaces the original's conditional
    expression used as a statement with a plain ``if``.
    """
    _inference_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()