"""
InfiniteTalk - Talking Video Generator
Gradio Space for HuggingFace
"""

import os
import sys

# CRITICAL: Set environment variables BEFORE any torch/torchvision imports
# This prevents torchvision from registering CUDA ops that don't exist at import time
os.environ["TORCHVISION_DISABLE_META_REGISTRATIONS"] = "1"
os.environ["TORCH_LOGS"] = "-all"  # Reduce torch logging noise

import random
import logging
import warnings
from pathlib import Path

import gradio as gr
import torch
import numpy as np

# Suppress warnings
warnings.filterwarnings('ignore')

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add current directory to path
sys.path.insert(0, str(Path(__file__).parent))

# Import utilities
from utils.model_loader import ModelManager
from utils.gpu_manager import gpu_manager

# Import InfiniteTalk modules (only what this entry point uses directly;
# model construction happens inside ModelManager)
from wan.utils.utils import is_video
from wan.utils.multitalk_utils import save_video_ffmpeg

# Audio processing
import librosa
import soundfile as sf
import pyloudnorm as pyln

# Tensor reshaping
from einops import rearrange

# Global variables
model_manager = None
models_loaded = False
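
# Models are loaded lazily on the first request so the Space starts quickly
# and only claims GPU memory when a generation is actually run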


def initialize_models(progress=gr.Progress()):
    """Initialize models on first use"""
    global model_manager, models_loaded

    if models_loaded:
        return

    try:
        progress(0.1, desc="Initializing model manager...")
        model_manager = ModelManager()

        progress(0.3, desc="Downloading models (first time only - may take 2-3 minutes)...")

        # Download models (lazy loading - they'll be loaded on first inference)
        model_manager.get_wan_model_path()
        model_manager.get_infinitetalk_weights_path()
        model_manager.get_wav2vec_model_path()

        models_loaded = True
        progress(1.0, desc="Models ready!")
        logger.info("Models initialized successfully")

    except Exception as e:
        logger.error(f"Error initializing models: {e}")
        raise gr.Error(f"Failed to initialize models: {str(e)}")


def loudness_norm(audio_array, sr=16000, lufs=-20.0):
    """Normalize integrated loudness to the target LUFS using pyloudnorm (ITU-R BS.1770)"""
    try:
        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(audio_array)
        if abs(loudness) > 100:  # Near-silent audio measures around -inf LUFS; skip normalization
            return audio_array
        normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
        return normalized_audio
    except Exception as e:
        logger.warning(f"Loudness normalization failed: {e}, returning original audio")
        return audio_array
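
# Usage sketch (illustrative only, not called by the app): normalizing a
# synthetic 440 Hz tone to -20 LUFS with the helper above.
#   tone = 0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
#   normalized = loudness_norm(tone, sr=16000, lufs=-20.0)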


def process_audio(audio_path, target_sr=16000):
    """
    Process audio file for InfiniteTalk (matches audio_prepare_single from reference)

    Args:
        audio_path: Path to audio file
        target_sr: Target sample rate

    Returns:
        Processed audio array and sample rate
    """
    try:
        # Load audio with librosa (resampled to target_sr; downmixed to mono by default)
        audio, sr = librosa.load(audio_path, sr=target_sr)

        # Defensive downmix: librosa returns channels-first arrays, so a
        # multichannel signal must be averaged over axis 0 (the loudness
        # meter below expects a 1-D mono signal)
        if audio.ndim > 1:
            audio = np.mean(audio, axis=0)

        # Normalize loudness on the mono signal
        audio = loudness_norm(audio, sr)

        return audio, sr

    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        raise gr.Error(f"Audio processing failed: {str(e)}")


def validate_inputs(image_or_video, audio, resolution, steps):
    """Validate user inputs"""
    errors = []

    if image_or_video is None:
        errors.append("Please upload an image or video")

    if audio is None:
        errors.append("Please upload an audio file")

    if resolution not in ["480p", "720p"]:
        errors.append("Invalid resolution selected")

    if not (20 <= steps <= 50):
        errors.append("Steps must be between 20 and 50")

    if errors:
        raise gr.Error(" | ".join(errors))


def generate_video(
    image_or_video,
    audio_file,
    resolution="480p",
    steps=40,
    audio_guide_scale=3.0,
    seed=-1,
    progress=gr.Progress()
):
    """
    Generate talking video from image or dub existing video

    Args:
        image_or_video: Input image or video file
        audio_file: Audio file for lip-sync
        resolution: Output resolution (480p or 720p)
        steps: Number of diffusion steps
        audio_guide_scale: Audio conditioning strength
        seed: Random seed for reproducibility
        progress: Gradio progress tracker

    Returns:
        Path to generated video
    """
    try:
        # Check if GPU is available
        if not torch.cuda.is_available():
            raise gr.Error(
                "⚠️ GPU not available. This Space requires GPU hardware to generate videos. "
                "Please apply for a Community GPU Grant in the Space settings, or run this app locally with a GPU."
            )

        # Initialize models if needed
        if not models_loaded:
            initialize_models(progress)

        # Validate inputs
        validate_inputs(image_or_video, audio_file, resolution, steps)

        # GPU memory check
        gpu_manager.print_memory_usage("Initial - ")

        progress(0.1, desc="Processing audio...")

        # Process audio
        audio, sr = process_audio(audio_file)
        audio_duration = len(audio) / sr
        logger.info(f"Audio duration: {audio_duration:.2f}s")

        # Estimate how much GPU time this job needs on ZeroGPU
        zerogpu_duration = gpu_manager.calculate_duration_for_zerogpu(
            audio_duration, resolution
        )
        logger.info(f"Estimated ZeroGPU duration: {zerogpu_duration}s")

        progress(0.2, desc="Loading models...")

        # Load models
        size = f"infinitetalk-{resolution.replace('p', '')}"

        # Load InfiniteTalk pipeline
        wan_pipeline = model_manager.load_wan_model(size=size, device="cuda")

        # Load audio encoder
        audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")

        gpu_manager.print_memory_usage("After model loading - ")

        progress(0.3, desc="Processing input...")

        # Determine if input is image or video; the pipeline loads frames
        # itself from the cond_video path passed below, so no pre-extraction
        # is needed here
        is_input_video = is_video(image_or_video)

        if is_input_video:
            logger.info("Processing video dubbing...")
        else:
            logger.info("Processing image-to-video...")

        progress(0.4, desc="Extracting audio features...")

        # Extract audio features (matches get_embedding from the reference
        # implementation); target frame count assumes 25 FPS output
        video_length = audio_duration * 25

        # Extract features with wav2vec
        audio_feature = np.squeeze(
            feature_extractor(audio, sampling_rate=sr).input_values
        )
        audio_feature = torch.from_numpy(audio_feature).float().to(device="cuda")
        audio_feature = audio_feature.unsqueeze(0)
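        # audio_feature is now (1, num_samples): a single-item batch for the encoder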

        # Get embeddings from audio encoder
        with torch.no_grad():
            embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)

        if not hasattr(embeddings, 'hidden_states') or embeddings.hidden_states is None:
            raise gr.Error("Failed to extract audio embeddings")

        # Stack per-layer hidden states (matches reference implementation):
        # each hidden state is (1, seq, dim); stacking on dim=1 and squeezing
        # the batch axis gives (layers, seq, dim), then rearrange to
        # (seq, layers, dim)
        audio_embeddings = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
        audio_embeddings = rearrange(audio_embeddings, "b s d -> s b d")
        audio_embeddings = audio_embeddings.cpu().detach()

        logger.info(f"Audio embeddings shape: {audio_embeddings.shape}")
        gpu_manager.print_memory_usage("After audio processing - ")

        progress(0.5, desc="Generating video (this may take a minute)...")

        # Set random seed (gr.Number delivers floats, so cast before seeding)
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, 99999999)

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Generate video with InfiniteTalk
        output_path = f"/tmp/output_{seed}.mp4"

        # Prepare input for pipeline (following generate_infinitetalk.py structure)
        with torch.no_grad():
            logger.info(f"Generating {resolution} video with {steps} steps...")

            # Save audio embeddings to temporary files (the pipeline expects
            # file paths, mirroring the reference input JSON format)
            os.makedirs("/tmp/audio_embeddings", exist_ok=True)
            emb_path = "/tmp/audio_embeddings/1.pt"
            audio_wav_path = "/tmp/audio_embeddings/sum.wav"

            torch.save(audio_embeddings, emb_path)
            sf.write(audio_wav_path, audio, sr)
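
            # The .pt file carries the conditioning embeddings; the .wav copy is
            # kept so the final mux step can attach the audio track to the video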

            # Prepare input dictionary (matches generate_infinitetalk.py format)
            input_clip = {
                "prompt": "",  # Empty prompt for talking head
                "cond_video": image_or_video,
                "cond_audio": {
                    "person1": emb_path
                },
                "video_audio": audio_wav_path
            }
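
            # "person1" keys the single speaker here; the upstream input format
            # also allows additional speaker entries for multi-person scenes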

            # Calculate sample_shift based on resolution
            sample_shift = 7 if resolution == "480p" else 11

            # Call InfiniteTalk pipeline
            video_tensor = wan_pipeline.generate_infinitetalk(
                input_clip,
                size_buckget=size,
                motion_frame=9,  # Default motion frame
                frame_num=81,  # Default frame num (4n+1 format)
                shift=sample_shift,
                sampling_steps=steps,
                text_guide_scale=5.0,  # Default text guidance
                audio_guide_scale=audio_guide_scale,
                seed=seed,
                offload_model=True,
                max_frames_num=81,  # For clip mode
                color_correction_strength=1.0,
                extra_args=None
            )

            # Mux the generated frames with the driving audio (ffmpeg)

            save_video_ffmpeg(
                video_tensor,
                output_path.replace(".mp4", ""),  # Function adds .mp4 extension
                [audio_wav_path],
                high_quality_save=False
            )

        progress(0.9, desc="Finalizing...")

        # Cleanup
        gpu_manager.cleanup()

        progress(1.0, desc="Complete!")

        logger.info(f"Video generated successfully: {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Error generating video: {e}")
        gpu_manager.cleanup()
        raise gr.Error(f"Generation failed: {str(e)}")


def create_interface():
    """Create Gradio interface"""

    with gr.Blocks(title="InfiniteTalk - Talking Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎬 InfiniteTalk - Talking Video Generator

        Generate realistic talking head videos with accurate lip-sync from images or dub existing videos with new audio!

        **Note**: First generation may take 2-3 minutes while models download. Subsequent generations are much faster (~40s for 10s video).
        """)

        with gr.Tabs():
            # Tab 1: Image-to-Video
            with gr.Tab("📸 Image-to-Video"):
                gr.Markdown("Transform a static portrait into a talking video")

                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            type="filepath",
                            label="Upload Portrait Image (clear face visibility recommended)"
                        )
                        audio_input_i2v = gr.Audio(
                            type="filepath",
                            label="Upload Audio (MP3, WAV, or FLAC)"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            resolution_i2v = gr.Radio(
                                choices=["480p", "720p"],
                                value="480p",
                                label="Resolution (480p faster, 720p higher quality)"
                            )
                            steps_i2v = gr.Slider(
                                minimum=20,
                                maximum=50,
                                value=40,
                                step=1,
                                label="Diffusion Steps (more = higher quality but slower)"
                            )
                            audio_scale_i2v = gr.Slider(
                                minimum=1.0,
                                maximum=5.0,
                                value=3.0,
                                step=0.5,
                                label="Audio Guide Scale (2-4 recommended)"
                            )
                            seed_i2v = gr.Number(
                                value=-1,
                                label="Seed (-1 for random)"
                            )

                        generate_btn_i2v = gr.Button("🎬 Generate Video", variant="primary", size="lg")

                    with gr.Column():
                        output_video_i2v = gr.Video(label="Generated Video")
                        gr.Markdown("**💡 Tip**: Use high-quality portrait images with clear facial features for best results")

                generate_btn_i2v.click(
                    fn=generate_video,
                    inputs=[image_input, audio_input_i2v, resolution_i2v, steps_i2v, audio_scale_i2v, seed_i2v],
                    outputs=output_video_i2v
                )

            # Tab 2: Video Dubbing
            with gr.Tab("🎥 Video Dubbing"):
                gr.Markdown("Dub an existing video with new audio while maintaining natural movements")

                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(
                            label="Upload Video (with visible face)"
                        )
                        audio_input_v2v = gr.Audio(
                            type="filepath",
                            label="Upload New Audio (MP3, WAV, or FLAC)"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            resolution_v2v = gr.Radio(
                                choices=["480p", "720p"],
                                value="480p",
                                label="Resolution"
                            )
                            steps_v2v = gr.Slider(
                                minimum=20,
                                maximum=50,
                                value=40,
                                step=1,
                                label="Diffusion Steps"
                            )
                            audio_scale_v2v = gr.Slider(
                                minimum=1.0,
                                maximum=5.0,
                                value=3.0,
                                step=0.5,
                                label="Audio Guide Scale"
                            )
                            seed_v2v = gr.Number(
                                value=-1,
                                label="Seed"
                            )

                        generate_btn_v2v = gr.Button("🎬 Generate Dubbed Video", variant="primary", size="lg")

                    with gr.Column():
                        output_video_v2v = gr.Video(label="Dubbed Video")
                        gr.Markdown("**💡 Tip**: For best results, use videos with consistent face visibility throughout")

                generate_btn_v2v.click(
                    fn=generate_video,
                    inputs=[video_input, audio_input_v2v, resolution_v2v, steps_v2v, audio_scale_v2v, seed_v2v],
                    outputs=output_video_v2v
                )

        # Footer
        gr.Markdown("""
        ---
        ### About
        Powered by [InfiniteTalk](https://github.com/MeiGen-AI/InfiniteTalk) - Apache 2.0 License

        ⚠️ **Note**: This Space requires GPU hardware to generate videos. Apply for a Community GPU Grant in Settings.

        💡 **Tips**:
        - First generation downloads models (~15GB) and may take 2-3 minutes
        - Use 480p for faster generation (~40s for 10s video)
        - Use 720p for higher quality (slower but better results)
        - Clear, well-lit images produce the best results
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
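    # queue() serializes requests under Gradio's default concurrency limit,
    # which also keeps the fixed /tmp output paths from colliding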
    demo.queue(max_size=10)
    demo.launch()