File size: 10,428 Bytes
7f36f80
 
b9a578a
 
7f36f80
 
 
 
 
 
 
b9a578a
 
 
7f36f80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9a578a
7f36f80
 
 
 
b9a578a
7f36f80
 
 
b9a578a
7f36f80
 
 
 
 
 
b9a578a
7f36f80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9a578a
7f36f80
 
 
 
 
 
 
b9a578a
 
7f36f80
 
b9a578a
 
 
7f36f80
 
b9a578a
 
 
 
 
7f36f80
 
b9a578a
 
 
 
 
 
 
 
 
 
7f36f80
b9a578a
7f36f80
 
 
 
 
 
 
 
 
b9a578a
7f36f80
 
b9a578a
7f36f80
 
 
 
 
 
 
 
 
 
b9a578a
7f36f80
 
 
 
 
 
 
 
 
 
 
 
 
 
b9a578a
 
 
 
 
7f36f80
 
b9a578a
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
 
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f36f80
b9a578a
 
 
 
 
 
7f36f80
b9a578a
 
7f36f80
b9a578a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f36f80
 
b9a578a
7f36f80
b9a578a
7f36f80
 
 
 
 
 
 
b9a578a
7f36f80
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""MuseTalk Inference Module

Refactored for Long-Form Generation (5-10 mins) 
using Memory-Efficient Streaming, Looping, and Audio Muxing.
"""

import os
import cv2
import torch
import numpy as np
import tempfile
import librosa
import mimetypes
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Union


class MuseTalkInference:
    """MuseTalk inference engine for audio-driven video generation."""

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.model = None
        self.whisper_model = None
        self.face_detector = None
        self.pose_model = None
        self.initialized = False

    def load_models(self, progress_callback=None):
        """Load MuseTalk models from HuggingFace Hub."""
        try:
            if progress_callback:
                progress_callback(0, "Loading MuseTalk models...")
            
            # Placeholder: Initialize your actual PyTorch models here
            self.initialized = True
            
            if progress_callback:
                progress_callback(5, "Models loaded successfully")
                
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
        """Extract audio features using Whisper/Mel-Spectrogram."""
        try:
            if progress_callback:
                progress_callback(10, "Extracting audio features...")
            
            try:
                audio, sr = librosa.load(audio_path, sr=16000)
            except:
                try:
                    import scipy.io.wavfile as wavfile
                    sr, audio = wavfile.read(audio_path)
                    if sr != 16000:
                        ratio = 16000 / sr
                        audio = (audio * ratio).astype(np.int16)
                except:
                    import soundfile as sf
                    audio, sr = sf.read(audio_path)
            
            audio = audio.astype(np.float32)
            audio = audio / (np.max(np.abs(audio)) + 1e-8)
            
            n_mels = 80
            n_fft = 400
            hop_length = 160
            
            mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
            
            if progress_callback:
                progress_callback(15, "Audio features extracted")
                
            return mel_features
            
        except Exception as e:
            print(f"Error extracting audio features: {e}")
            raise

    def extract_source_frames(self, file_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
        """Read every frame of the source image or video into memory.

        The file is treated as a single image when ``mimetypes`` guesses an
        ``image/*`` type from its name; anything else is opened as a video
        and read to the end.

        Args:
            file_path: Path of the source image or video.
            fps: Accepted for interface compatibility; not used here.
            progress_callback: Optional callable invoked as
                ``progress_callback(percent, message)`` with 20 at the start.

        Returns:
            ``(frames, width, height)`` where ``frames`` is the list of
            frames as returned by OpenCV and the dimensions come from the
            first frame.

        Raises:
            ValueError: If the image cannot be read or no frame was obtained.
            Exception: Any failure is printed and re-raised unchanged.
        """
        try:
            if progress_callback:
                progress_callback(20, "Reading source image/video...")

            mime_type, _ = mimetypes.guess_type(file_path)
            frames = []

            # Handle Single Image Input
            if mime_type and mime_type.startswith('image'):
                frame = cv2.imread(file_path)
                if frame is None:
                    raise ValueError("Failed to read image")
                frames.append(frame)

            # Handle Short Video Input
            else:
                cap = cv2.VideoCapture(file_path)
                try:
                    while True:
                        ret, frame = cap.read()
                        if not ret:
                            break
                        frames.append(frame)
                finally:
                    # Release the capture even if reading fails part-way,
                    # so the underlying file/decoder handle is not leaked.
                    cap.release()

            if not frames:
                raise ValueError("No frames extracted from source file")

            height, width = frames[0].shape[:2]
            return frames, width, height

        except Exception as e:
            print(f"Error extracting video frames: {e}")
            raise

    def detect_faces(self, frames: list, progress_callback=None) -> list:
        """Pick one face box per frame using OpenCV's frontal-face Haar cascade.

        For each frame the detection with the largest ``w * h`` is kept.
        When a frame yields no detection, the previous frame's box is reused;
        if there is no previous box, a centered box covering half the frame
        in each dimension is used instead.

        Args:
            frames: Frames to scan (converted from BGR to grayscale first).
            progress_callback: Optional callable invoked as
                ``progress_callback(percent, message)`` with 25 at the start.

        Returns:
            A list with one 4-element box per input frame; ``generate``
            unpacks each box as ``x, y, w, h``.

        Raises:
            Exception: Any failure is printed and re-raised unchanged.
        """
        try:
            if progress_callback:
                progress_callback(25, "Detecting face in source media...")

            detector = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )
            boxes = []

            for image in frames:
                grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                candidates = detector.detectMultiScale(grayscale, 1.1, 4)

                if len(candidates) > 0:
                    # Keep the largest detection by area.
                    chosen = max(candidates, key=lambda box: box[2] * box[3])
                elif boxes:
                    # Nothing found: carry the previous frame's box forward.
                    chosen = boxes[-1]
                else:
                    # Nothing found and no history: fall back to the center.
                    frame_h, frame_w = image.shape[:2]
                    chosen = np.array(
                        [frame_w // 4, frame_h // 4, frame_w // 2, frame_h // 2]
                    )
                boxes.append(chosen)

            return boxes
        except Exception as err:
            print(f"Error detecting faces: {err}")
            raise

    def generate(self, audio_path: str, video_path: str, output_path: str,
                 fps: int = 25, progress_callback=None) -> str:
        """Render a video as long as the audio by looping the source frames.

        Frames are written to a temporary silent video one at a time instead
        of being accumulated in memory, then ffmpeg muxes that video with the
        original audio into ``output_path``. If ffmpeg fails or cannot be
        started, the silent video is moved to ``output_path`` instead.

        The lip-sync step is still a placeholder: each frame only gets a
        green rectangle drawn around the cached face box, and the extracted
        audio features are not consumed yet.

        Args:
            audio_path: Path of the driving audio file.
            video_path: Path of the source image or short video.
            output_path: Path of the file to produce.
            fps: Frame rate of the generated video.
            progress_callback: Optional callable invoked as
                ``progress_callback(percent, message)``.

        Returns:
            ``output_path``.

        Raises:
            ValueError: If the audio is too short to yield a single frame.
            RuntimeError: If the temporary video file cannot be opened for
                writing.
            Exception: Any failure is printed and re-raised unchanged.
        """
        try:
            if not self.initialized:
                self.load_models(progress_callback)

            # 1. Extract audio features (reserved for the real model; the
            #    placeholder rendering below does not read them).
            audio_features = self.extract_audio_features(audio_path, progress_callback)

            # 2. Determine Total Output Frames based on Audio Length
            audio_data, sr = librosa.load(audio_path, sr=16000)
            audio_duration = len(audio_data) / sr
            del audio_data  # only the duration is needed from here on
            total_target_frames = int(audio_duration * fps)

            if total_target_frames <= 0:
                raise ValueError("Audio file is too short or invalid.")

            # 3. Extract Source Clip/Image (Only loads short clip into memory)
            source_frames, width, height = self.extract_source_frames(video_path, fps, progress_callback)

            # 4. Detect faces on the short source clip (Pre-cached)
            source_faces = self.detect_faces(source_frames, progress_callback)

            # 5. Stream Process (Write directly to file to avoid OOM crash)
            # Build the temp name from the file stem so it can never collide
            # with output_path, whatever extension or directory names it has.
            target = Path(output_path)
            temp_silent_video = str(target.with_name(f"{target.stem}_silent.mp4"))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_silent_video, fourcc, fps, (width, height))

            try:
                if not out.isOpened():
                    raise RuntimeError(
                        f"Could not open video writer for {temp_silent_video}"
                    )

                if progress_callback:
                    progress_callback(30, f"Generating {total_target_frames} frames (Streaming)...")

                source_count = len(source_frames)
                report_every = max(1, total_target_frames // 20)

                for i in range(total_target_frames):
                    # LOOPING LOGIC: Loop the short video or image continuously
                    src_idx = i % source_count
                    frame = source_frames[src_idx].copy()
                    face = source_faces[src_idx]

                    # --- START AI LIP-SYNC INFERENCE ---
                    # NOTE: Put your actual AI model generation code here.
                    # Right now, this just draws a box around the face.
                    # Example: frame = self.model.infer(frame, face, audio_features[:, i])

                    x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                    # --- END AI LIP-SYNC INFERENCE ---

                    # Write directly to disk instead of keeping frames in RAM
                    out.write(frame)

                    # Report progress periodically (frames map onto 30-90%)
                    done = i + 1
                    if progress_callback and done % report_every == 0:
                        progress_pct = 30 + int((done / total_target_frames) * 60)
                        progress_callback(progress_pct, f"Generated frames: {done}/{total_target_frames}")
            finally:
                # Always finalize the container, even if rendering failed.
                out.release()

            # 6. MUX AUDIO (Combine the generated silent video with original audio)
            if progress_callback:
                progress_callback(95, "Merging final audio and video...")

            cmd = [
                "ffmpeg", "-y",
                "-i", temp_silent_video,   # The generated silent video
                "-i", audio_path,          # The original audio
                "-c:v", "libx264",         # Re-encode video for broad web compatibility
                "-c:a", "aac",             # Re-encode audio to AAC
                "-map", "0:v:0",
                "-map", "1:a:0",
                "-shortest",               # Cut at the shortest stream
                output_path
            ]
            try:
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except (subprocess.CalledProcessError, OSError) as e:
                # OSError covers a missing or non-executable ffmpeg binary,
                # which would otherwise discard the fully rendered video.
                stderr = getattr(e, "stderr", None)
                if isinstance(stderr, bytes):
                    detail = stderr.decode(errors="replace")
                else:
                    detail = stderr or e
                print(f"FFMPEG Error: {detail}")
                # Fallback to silent video; os.replace also overwrites a
                # partial output file that ffmpeg may have left behind.
                os.replace(temp_silent_video, output_path)
            else:
                # Cleanup temp file
                if os.path.exists(temp_silent_video):
                    os.remove(temp_silent_video)

            if progress_callback:
                progress_callback(100, "Generation Complete!")

            return output_path

        except Exception as e:
            print(f"Error during generation: {e}")
            raise

    def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int, 
                                n_fft: int, hop_length: int) -> np.ndarray:
        """Compute mel-spectrogram from audio."""
        try:
            import librosa
            mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft, 
                                                     hop_length=hop_length, n_mels=n_mels)
            mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            return mel_spec
        except:
            n_frames = len(audio) // hop_length
            return np.random.randn(n_mels, n_frames)