File size: 17,841 Bytes
a261d90
0ca18f4
 
a261d90
 
 
0ca18f4
 
 
 
 
a261d90
 
0ca18f4
a261d90
 
 
 
0ca18f4
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
a261d90
0ca18f4
 
a261d90
0ca18f4
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
a261d90
0ca18f4
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
 
a261d90
0ca18f4
 
 
 
 
a261d90
0ca18f4
 
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
 
 
 
a261d90
0ca18f4
a261d90
0ca18f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a261d90
0ca18f4
a261d90
 
0ca18f4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
"""
Wav2Lip Inference Module
Complete implementation for generating lip-sync videos from face images/videos and audio.
"""

import os
import cv2
import numpy as np
import torch
import librosa
import subprocess
import logging
from pathlib import Path
from typing import List, Tuple, Optional

logger = logging.getLogger(__name__)


class Wav2LipInference:
    """Wav2Lip inference handler with face detection and audio processing.

    Loads a Wav2Lip checkpoint (falling back to a small dummy network when
    the model package or the checkpoint is unavailable), locates a face in
    the first input frame, generates per-frame face crops driven by the
    audio's mel spectrogram and composites them back into the original
    frames before muxing the audio with ffmpeg.

    NOTE(review): ``generate_lip_sync_frames`` currently calls
    ``_simulate_lip_sync`` rather than ``self.model``; the loaded network is
    never invoked by this class.
    """

    # Sample rate that audio is resampled to before feature extraction.
    _SAMPLE_RATE = 16000
    # STFT hop length in samples; at 16 kHz this gives 80 mel frames/second.
    _HOP_LENGTH = 200
    # Extensions treated as a single still image rather than a video.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, checkpoint_path: str, device: Optional[str] = None):
        """Store configuration; the model and face detector load lazily.

        Args:
            checkpoint_path: Path to a Wav2Lip checkpoint file.
            device: Torch device string. Defaults to ``'cuda'`` when CUDA is
                available, otherwise ``'cpu'``.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.checkpoint_path = checkpoint_path
        self.model = None
        self.face_detector = None
        # Haar-cascade fallback; set by _init_face_detector() when the DNN
        # detector cannot be loaded.
        self.face_cascade = None
        self.img_size = 96

    def load_model(self):
        """Load the Wav2Lip model (or a dummy) and the face detector once.

        ``self.model`` is only assigned after the checkpoint has been loaded
        successfully, so a failed load does not leave a half-initialised
        model cached for later calls.
        """
        if self.model is not None:
            return

        try:
            # Only the architecture is needed here; an unrelated missing
            # module must not force the dummy fallback.
            from models.wav2lip import Wav2Lip
        except ImportError:
            logger.warning("Wav2Lip models not found, using dummy implementation")
            self._create_dummy_model()
        else:
            if os.path.exists(self.checkpoint_path):
                model = Wav2Lip()
                checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
                state_dict = checkpoint.get('state_dict', checkpoint)
                # Checkpoints saved from a DataParallel-wrapped model prefix
                # every key with 'module.'; strip it so keys match Wav2Lip().
                prefix = 'module.'
                state_dict = {
                    (key[len(prefix):] if key.startswith(prefix) else key): value
                    for key, value in state_dict.items()
                }
                model.load_state_dict(state_dict)
                self.model = model.to(self.device).eval()
                logger.info(f"Loaded Wav2Lip model from {self.checkpoint_path}")
            else:
                logger.warning(f"Checkpoint not found: {self.checkpoint_path}")
                # Create dummy model for testing
                self._create_dummy_model()

        self._init_face_detector()

    def _create_dummy_model(self):
        """Create a dummy model for testing when real model unavailable."""
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(6, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 3, 3, padding=1),
            torch.nn.Sigmoid()
        ).to(self.device).eval()

    def _init_face_detector(self):
        """Load the OpenCV SSD face detector, falling back to a Haar cascade."""
        try:
            self.face_detector = cv2.dnn.readNetFromCaffe(
                "models/deploy.prototxt",
                "models/res10_300x300_ssd_iter_140000.caffemodel"
            )
        except Exception as e:
            logger.warning(f"Could not load DNN face detector: {e}")
            # Fallback to Haar cascade
            self.face_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )
            self.face_detector = None

    def _ensure_face_detector(self):
        """Initialise a face detector if neither the DNN nor the cascade exists."""
        if (getattr(self, 'face_detector', None) is None
                and getattr(self, 'face_cascade', None) is None):
            self._init_face_detector()

    def detect_faces(self, image: np.ndarray, confidence_threshold: float = 0.5) -> List[Tuple[int, int, int, int]]:
        """Detect faces in a BGR image.

        Args:
            image: BGR image as produced by ``cv2.imread``.
            confidence_threshold: Minimum DNN confidence to keep a detection
                (ignored by the Haar-cascade fallback).

        Returns:
            List of ``(x1, y1, x2, y2)`` integer boxes clipped to the image;
            empty when no face is found.
        """
        # process_video() calls this before load_model(), so make sure a
        # detector exists instead of relying on load order.
        self._ensure_face_detector()
        img_h, img_w = image.shape[:2]

        if self.face_detector is not None:
            blob = cv2.dnn.blobFromImage(
                cv2.resize(image, (300, 300)), 1.0,
                (300, 300), (104.0, 177.0, 123.0)
            )
            self.face_detector.setInput(blob)
            detections = self.face_detector.forward()

            scale = np.array([img_w, img_h, img_w, img_h])
            faces = []
            for i in range(detections.shape[2]):
                if detections[0, 0, i, 2] <= confidence_threshold:
                    continue
                coords = (detections[0, 0, i, 3:7] * scale).astype(int)
                # SSD boxes may extend past the frame; clip and drop empties.
                x1, y1 = max(0, int(coords[0])), max(0, int(coords[1]))
                x2, y2 = min(img_w, int(coords[2])), min(img_h, int(coords[3]))
                if x2 > x1 and y2 > y1:
                    faces.append((x1, y1, x2, y2))
            return faces

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        rects = self.face_cascade.detectMultiScale(
            gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
        )
        return [(int(x), int(y), int(x + fw), int(y + fh)) for (x, y, fw, fh) in rects]

    @staticmethod
    def _peak_normalize(wav: np.ndarray, target: float = 0.9) -> np.ndarray:
        """Scale ``wav`` so its absolute peak equals ``target``.

        Silent or empty input is returned unchanged instead of producing
        NaNs from a division by zero.
        """
        peak = float(np.abs(wav).max()) if wav.size else 0.0
        if peak <= 0.0:
            return wav
        return wav / peak * target

    def extract_audio_features(self, audio_path: str) -> Tuple[np.ndarray, np.ndarray, int]:
        """Extract a normalised log-mel spectrogram from an audio file.

        Returns:
            ``(mel_spec, wav, sr)``: an ``(80, T)`` mel array with
            ``sr / 200`` columns per second, the peak-normalised waveform and
            its sample rate (16 kHz).
        """
        wav, sr = librosa.load(audio_path, sr=self._SAMPLE_RATE)
        wav = self._peak_normalize(wav)

        mel_spec = librosa.feature.melspectrogram(
            y=wav,
            sr=sr,
            n_fft=800,
            hop_length=self._HOP_LENGTH,
            n_mels=80
        )
        # dB relative to the loudest bin, then roughly mapped to [-1, 1].
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        mel_spec = (mel_spec + 40) / 40

        return mel_spec, wav, sr

    @staticmethod
    def _mel_for_frames(mel_spec: np.ndarray, num_frames: int, fps: float,
                        sr: int, hop_length: int = _HOP_LENGTH) -> np.ndarray:
        """Pick one mel column per video frame.

        Mel columns advance at ``sr / hop_length`` per second while video
        frames advance at ``fps``. Indexing mel columns directly by frame
        index would misalign audio and video. Indices past the end of the
        spectrogram are clamped to its last column.
        """
        if mel_spec.shape[1] == 0:
            return mel_spec
        mel_per_frame = sr / hop_length / fps
        idx = (np.arange(num_frames) * mel_per_frame).astype(int)
        idx = np.minimum(idx, mel_spec.shape[1] - 1)
        return mel_spec[:, idx]

    def preprocess_face(self, face_img: np.ndarray, img_size: int = 96) -> np.ndarray:
        """Resize a face crop to ``img_size`` square and scale it to [-1, 1] float32."""
        face_img = cv2.resize(face_img, (img_size, img_size))
        return face_img.astype(np.float32) / 127.5 - 1.0

    def smooth_lip_region(self, frames: List[np.ndarray], window_size: int = 5) -> List[np.ndarray]:
        """Temporally smooth frames with a centred moving average.

        The window is truncated at the sequence edges. Sequences shorter than
        ``window_size`` are returned unchanged. Each full frame is averaged,
        not just the lip area.
        """
        if len(frames) < window_size:
            return frames

        smoothed = []
        half_window = window_size // 2

        for i in range(len(frames)):
            start = max(0, i - half_window)
            end = min(len(frames), i + half_window + 1)
            smoothed.append(np.mean(frames[start:end], axis=0).astype(np.uint8))

        return smoothed

    @staticmethod
    def _audio_window(audio_features: np.ndarray, start: int, end: int) -> np.ndarray:
        """Return audio columns ``start..end-1``, clamping past-the-end indices.

        Clamping repeats the last column (the same result as replicate
        padding). It also avoids padding an empty slice, which fails when
        the frames outnumber the audio columns. An empty spectrogram yields
        zeros.
        """
        n_cols = audio_features.shape[1]
        if n_cols == 0:
            return np.zeros((audio_features.shape[0], end - start), dtype=np.float32)
        idx = np.minimum(np.arange(start, end), n_cols - 1)
        return audio_features[:, idx]

    def generate_lip_sync_frames(
        self,
        face_sequence: List[np.ndarray],
        audio_features: np.ndarray,
        batch_size: int = 64
    ) -> List[np.ndarray]:
        """Generate lip-synced face crops in batches.

        Args:
            face_sequence: Preprocessed ``(H, W, 3)`` float crops in [-1, 1].
            audio_features: ``(n_mels, T)`` array whose column ``j`` belongs
                to frame ``j``. Missing trailing columns reuse the last one.
            batch_size: Number of frames processed per batch.

        Returns:
            One ``(H, W, 3)`` uint8 crop per input frame.
        """
        if self.model is None:
            self.load_model()

        generated_frames = []
        num_frames = len(face_sequence)

        for start in range(0, num_frames, batch_size):
            end = min(start + batch_size, num_frames)

            # [B, C, H, W] face batch
            face_batch = torch.stack([
                torch.from_numpy(face).permute(2, 0, 1).float()
                for face in face_sequence[start:end]
            ]).to(self.device)
            audio_batch = torch.from_numpy(
                self._audio_window(audio_features, start, end)
            ).float().to(self.device)

            with torch.no_grad():
                # Placeholder: the real Wav2Lip forward pass is not wired in.
                output = self._simulate_lip_sync(face_batch, audio_batch)

            frames = output.permute(0, 2, 3, 1).cpu().numpy()
            frames = ((frames + 1) * 127.5).clip(0, 255).astype(np.uint8)
            generated_frames.extend(frames)

        return generated_frames

    def _simulate_lip_sync(
        self,
        face_batch: torch.Tensor,
        audio_batch: torch.Tensor
    ) -> torch.Tensor:
        """Simulate lip-sync by scaling the lower half of each face by audio energy.

        Args:
            face_batch: ``[B, C, H, W]`` faces.
            audio_batch: ``[n_mels, B']`` features, one column per frame.
                Frames without a column use an energy of 0.5.
        """
        # Average over the mel axis to get one energy value per frame.
        audio_energy = audio_batch.mean(dim=0)

        output = face_batch.clone()
        num = output.shape[0]
        energy = torch.full((num,), 0.5, dtype=output.dtype, device=output.device)
        available = min(num, audio_energy.shape[0])
        energy[:available] = audio_energy[:available]

        mouth_open = 0.5 + 0.3 * torch.sin(energy * 10)
        # Lower face region (roughly where the mouth is).
        y_start = output.shape[2] // 2
        output[:, :, y_start:, :] *= mouth_open.view(-1, 1, 1, 1)
        return output

    def _read_face_input(self, face_path: str, fps: float, static: bool):
        """Read the face image/video.

        Returns:
            ``(frames, fps, replicate)``. ``fps`` is the video's own frame
            rate when OpenCV reports one. ``replicate`` is True when the
            first frame should be repeated for the whole audio duration
            (still image or ``static``).
        """
        if face_path.lower().endswith(self._IMAGE_EXTENSIONS):
            image = cv2.imread(face_path)
            if image is None:
                raise ValueError(f"Could not load image: {face_path}")
            return [image], fps, True

        cap = cv2.VideoCapture(face_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {face_path}")
        try:
            native_fps = cap.get(cv2.CAP_PROP_FPS)
            frames = []
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                frames.append(frame)
                if static:
                    # Only the first frame is used in static mode.
                    break
        finally:
            cap.release()

        # Writing at the caller's fps instead of the source's would drift
        # the video against the audio.
        if native_fps > 0:
            fps = float(native_fps)
        return frames, fps, static

    @staticmethod
    def _transform_frames(frames: List[np.ndarray], resize_factor: int, rotate: bool) -> List[np.ndarray]:
        """Downscale by ``resize_factor`` (when > 1), then optionally rotate 90° clockwise."""
        if resize_factor > 1:
            frames = [
                cv2.resize(f, (f.shape[1] // resize_factor, f.shape[0] // resize_factor))
                for f in frames
            ]
        if rotate:
            frames = [cv2.rotate(f, cv2.ROTATE_90_CLOCKWISE) for f in frames]
        return frames

    def _resolve_face_box(self, frame: np.ndarray, box: Optional[List[int]],
                          pads: List[int]) -> Tuple[int, int, int, int]:
        """Return the padded ``(x1, y1, x2, y2)`` face box for ``frame``.

        An explicit ``box`` is ``(top, bottom, left, right)``, the same order
        as ``pads``. It is used instead of detection when ``box[0] != -1``.
        """
        h, w = frame.shape[:2]
        if box and box[0] != -1:
            top, bottom, left, right = box
            x1, y1, x2, y2 = left, top, right, bottom
        else:
            faces = self.detect_faces(frame)
            if len(faces) == 0:
                raise ValueError("No face detected in input")
            x1, y1, x2, y2 = faces[0]

        pad_t, pad_b, pad_l, pad_r = pads
        x1 = max(0, int(x1) - pad_l)
        y1 = max(0, int(y1) - pad_t)
        x2 = min(w, int(x2) + pad_r)
        y2 = min(h, int(y2) + pad_b)
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"Invalid face box after padding: {(x1, y1, x2, y2)}")
        return x1, y1, x2, y2

    @staticmethod
    def _feather_mask(face_h: int, face_w: int, feather: int = 10) -> np.ndarray:
        """Build an ``(h, w, 1)`` float32 blend mask with linearly ramped edges.

        The ramp width is capped at half the crop size, so small faces do not
        cause shape mismatches.
        """
        mask = np.ones((face_h, face_w), dtype=np.float32)
        feather = min(feather, face_h // 2, face_w // 2)
        if feather > 0:
            ramp = np.linspace(0, 1, feather)
            mask[:feather, :] *= ramp[:, None]
            mask[-feather:, :] *= ramp[::-1][:, None]
            mask[:, :feather] *= ramp[None, :]
            mask[:, -feather:] *= ramp[::-1][None, :]
        return mask[:, :, None]

    def _composite_faces(self, frames: List[np.ndarray], faces: List[np.ndarray],
                         face_box: Tuple[int, int, int, int]) -> List[np.ndarray]:
        """Blend each generated face back into a copy of its source frame."""
        x1, y1, x2, y2 = face_box
        face_h, face_w = y2 - y1, x2 - x1
        # The mask depends only on the box size, so build it once.
        mask = self._feather_mask(face_h, face_w)

        output_frames = []
        for original, new_face in zip(frames, faces):
            result = original.copy()
            new_face_resized = cv2.resize(new_face, (face_w, face_h))
            roi = result[y1:y2, x1:x2]
            result[y1:y2, x1:x2] = (new_face_resized * mask + roi * (1 - mask)).astype(np.uint8)
            output_frames.append(result)
        return output_frames

    @staticmethod
    def _temp_video_path(output_path: str) -> str:
        """Return a sibling ``<stem>_temp.mp4`` path that never equals ``output_path``."""
        path = Path(output_path)
        return str(path.with_name(f"{path.stem}_temp.mp4"))

    def _write_video(self, frames: List[np.ndarray], output_path: str,
                     audio_path: str, fps: float) -> None:
        """Write ``frames`` to a temp file, then mux in the audio with ffmpeg.

        If ffmpeg is unavailable or fails, the silent video is moved to
        ``output_path`` instead.

        Raises:
            OSError: If OpenCV cannot open the temporary video for writing.
        """
        temp_video = self._temp_video_path(output_path)
        out_h, out_w = frames[0].shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(temp_video, fourcc, fps, (out_w, out_h))
        if not writer.isOpened():
            raise OSError(f"Could not open video writer: {temp_video}")
        try:
            for frame in frames:
                writer.write(frame)
        finally:
            writer.release()

        ffmpeg_cmd = [
            'ffmpeg', '-y',
            '-i', temp_video,
            '-i', audio_path,
            '-c:v', 'copy',
            '-c:a', 'aac',
            '-shortest',
            output_path
        ]
        try:
            subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
        except (subprocess.CalledProcessError, OSError) as e:
            logger.warning(f"FFmpeg audio merge failed: {e}")
            # Fallback: keep the silent video; os.replace overwrites on all OSes.
            os.replace(temp_video, output_path)
        else:
            # Only reached on success, so a cleanup error can never clobber
            # the merged output with the silent temp file.
            os.remove(temp_video)
            logger.info(f"Successfully created: {output_path}")

    def process_video(
        self,
        face_path: str,
        audio_path: str,
        output_path: str,
        static: bool = False,
        fps: float = 25.0,
        resize_factor: int = 1,
        rotate: bool = False,
        nosmooth: bool = False,
        pads: List[int] = None,
        crop: List[int] = None,
        box: List[int] = None,
        face_det_batch_size: int = 8,
        wav2lip_batch_size: int = 64,
        img_size: int = 96
    ) -> str:
        """Main video processing pipeline.

        Args:
            face_path: Still image (.jpg/.jpeg/.png) or video with the face.
            audio_path: Audio file that drives the lip movement.
            output_path: Destination video path.
            static: For videos, use only the first frame for the whole output.
            fps: Frame rate for still images. Videos use their own reported
                frame rate when OpenCV provides one.
            resize_factor: Integer downscale applied to every frame.
            rotate: Rotate frames 90° clockwise before processing.
            nosmooth: Disable temporal smoothing of generated faces.
            pads: ``[top, bottom, left, right]`` padding around the face box.
            crop: Accepted for Wav2Lip CLI compatibility; currently unused.
            box: Fixed face box ``[top, bottom, left, right]``. It replaces
                detection when ``box[0] != -1``.
            face_det_batch_size: Accepted for compatibility; currently unused.
            wav2lip_batch_size: Batch size for face generation.
            img_size: Square size the face crop is resized to.

        Returns:
            ``output_path``.

        Raises:
            ValueError: Unreadable input, no frames, or no usable face box.
            OSError: The output video cannot be opened for writing.
        """
        self.img_size = img_size
        pads = pads or [0, 10, 0, 0]

        frame_sequence, fps, replicate = self._read_face_input(face_path, fps, static)
        mel_spec, wav, sr = self.extract_audio_features(audio_path)

        # Transform before replicating so a still frame is processed once.
        frame_sequence = self._transform_frames(frame_sequence, resize_factor, rotate)
        if replicate:
            num_frames = int(len(wav) / sr * fps)
            # Frames are never modified in place, so sharing one array is safe.
            frame_sequence = frame_sequence[:1] * num_frames

        if len(frame_sequence) == 0:
            raise ValueError("No frames extracted from face input")

        logger.info(f"Processing {len(frame_sequence)} frames")

        # A single box from the first frame is used for the whole clip.
        face_box = self._resolve_face_box(frame_sequence[0], box, pads)
        x1, y1, x2, y2 = face_box
        face_regions = [
            self.preprocess_face(frame[y1:y2, x1:x2], img_size)
            for frame in frame_sequence
        ]

        frame_mel = self._mel_for_frames(mel_spec, len(frame_sequence), fps, sr)

        logger.info("Generating lip-sync frames...")
        lip_sync_faces = self.generate_lip_sync_frames(
            face_regions, frame_mel, wav2lip_batch_size
        )

        if not nosmooth and len(lip_sync_faces) > 1:
            lip_sync_faces = self.smooth_lip_region(lip_sync_faces)

        output_frames = self._composite_faces(frame_sequence, lip_sync_faces, face_box)
        self._write_video(output_frames, output_path, audio_path, fps)

        return output_path


# Module-wide store of Wav2LipInference objects. run_inference() keys it by
# "<checkpoint_path>_<img_size>" so repeated calls reuse one instance.
_inference_cache: dict = dict()


def run_inference(
    checkpoint_path: str,
    face_path: str,
    audio_path: str,
    output_filename: str,
    static: bool = False,
    fps: float = 25.0,
    resize_factor: int = 1,
    rotate: bool = False,
    nosmooth: bool = False,
    pads: List[int] = None,
    crop: List[int] = None,
    box: List[int] = None,
    face_det_batch_size: int = 8,
    wav2lip_batch_size: int = 64,
    img_size: int = 96
) -> str:
    """
    Generate a lip-sync video and return the path it was written to.

    This is the entry point the Streamlit application calls. It checks that
    both inputs exist, creates the output directory, reuses (or creates) a
    cached ``Wav2LipInference`` for this checkpoint and image size, and then
    delegates to ``Wav2LipInference.process_video``.

    Raises:
        FileNotFoundError: If ``face_path`` or ``audio_path`` does not exist.
    """
    # Validate inputs, face first, then audio.
    for label, path in (("Face", face_path), ("Audio", audio_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    target_dir = os.path.dirname(output_filename) or '.'
    os.makedirs(target_dir, exist_ok=True)

    logger.info(f"Running inference with model: {checkpoint_path}")
    logger.info(f"Face: {face_path}, Audio: {audio_path}")
    logger.info(f"Output: {output_filename}")
    logger.info(f"Settings: static={static}, fps={fps}, resize={resize_factor}")

    key = f"{checkpoint_path}_{img_size}"
    engine = _inference_cache.get(key)
    if engine is None:
        engine = Wav2LipInference(checkpoint_path)
        _inference_cache[key] = engine

    options = dict(
        static=static,
        fps=fps,
        resize_factor=resize_factor,
        rotate=rotate,
        nosmooth=nosmooth,
        pads=pads,
        crop=crop,
        box=box,
        face_det_batch_size=face_det_batch_size,
        wav2lip_batch_size=wav2lip_batch_size,
        img_size=img_size,
    )
    return engine.process_video(
        face_path=face_path,
        audio_path=audio_path,
        output_path=output_filename,
        **options
    )


# Cleanup function
def clear_inference_cache():
    """Drop all cached inference instances and release cached GPU memory.

    The cache dict is emptied in place rather than rebound, so any other
    reference to it (e.g. ``from ... import _inference_cache``) also stops
    holding the models. A garbage-collection pass runs first so that freed
    model tensors can actually be returned by ``torch.cuda.empty_cache()``.
    """
    import gc

    _inference_cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()