"""
InfiniteTalk - Talking Video Generator
Gradio Space for HuggingFace
"""
import os
import sys
# CRITICAL: Set environment variables BEFORE any torch/torchvision imports
# This prevents torchvision from registering CUDA ops that don't exist at import time
os.environ["TORCHVISION_DISABLE_META_REGISTRATIONS"] = "1"
os.environ["TORCH_LOGS"] = "-all" # Reduce torch logging noise
import random
import logging
import warnings
from pathlib import Path
import gradio as gr
import torch
import numpy as np
# Suppress warnings
warnings.filterwarnings('ignore')
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Add current directory to path
sys.path.insert(0, str(Path(__file__).parent))
# Import utilities
from utils.model_loader import ModelManager
from utils.gpu_manager import gpu_manager
# Import InfiniteTalk modules
import wan
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
from wan.utils.utils import is_video
from wan.utils.multitalk_utils import save_video_ffmpeg
# Audio processing
import librosa
import soundfile as sf
import pyloudnorm as pyln
from transformers import Wav2Vec2FeatureExtractor
from src.audio_analysis.wav2vec2 import Wav2Vec2Model
# Image/Video processing
from PIL import Image
from einops import rearrange
# Global variables
model_manager = None
models_loaded = False
def initialize_models(progress=gr.Progress()):
"""Initialize models on first use"""
global model_manager, models_loaded
if models_loaded:
return
try:
progress(0.1, desc="Initializing model manager...")
model_manager = ModelManager()
progress(0.3, desc="Downloading models (first time only - may take 2-3 minutes)...")
# Download models (lazy loading - they'll be loaded on first inference)
model_manager.get_wan_model_path()
model_manager.get_infinitetalk_weights_path()
model_manager.get_wav2vec_model_path()
models_loaded = True
progress(1.0, desc="Models ready!")
logger.info("Models initialized successfully")
except Exception as e:
logger.error(f"Error initializing models: {e}")
raise gr.Error(f"Failed to initialize models: {str(e)}")
def loudness_norm(audio_array, sr=16000, lufs=-20.0):
"""Normalize audio loudness using pyloudnorm"""
try:
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio_array)
        if abs(loudness) > 100:  # pyloudnorm reports -inf for near-silent input; skip normalization
return audio_array
normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
return normalized_audio
except Exception as e:
logger.warning(f"Loudness normalization failed: {e}, returning original audio")
return audio_array
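# Example (illustrative): bring a clip to the -20 LUFS target.
#   y, sr = librosa.load("speech.wav", sr=16000)  # "speech.wav" is a placeholder path
#   y = loudness_norm(y, sr)  # integrated loudness is now ~ -20 LUFS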
def process_audio(audio_path, target_sr=16000):
"""
    Process audio file for InfiniteTalk (matches audio_prepare_single from the reference script)
Args:
audio_path: Path to audio file
target_sr: Target sample rate
Returns:
Processed audio array and sample rate
"""
try:
# Load audio with librosa
audio, sr = librosa.load(audio_path, sr=target_sr)
# Normalize loudness
audio = loudness_norm(audio, sr)
        # Ensure mono (librosa.load already downmixes by default; kept as a safeguard)
if len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
return audio, sr
except Exception as e:
logger.error(f"Error processing audio: {e}")
raise gr.Error(f"Audio processing failed: {str(e)}")
def validate_inputs(image_or_video, audio, resolution, steps):
"""Validate user inputs"""
errors = []
if image_or_video is None:
errors.append("Please upload an image or video")
if audio is None:
errors.append("Please upload an audio file")
if resolution not in ["480p", "720p"]:
errors.append("Invalid resolution selected")
if not (20 <= steps <= 50):
errors.append("Steps must be between 20 and 50")
if errors:
raise gr.Error(" | ".join(errors))
def generate_video(
image_or_video,
audio_file,
resolution="480p",
steps=40,
audio_guide_scale=3.0,
seed=-1,
progress=gr.Progress()
):
"""
Generate talking video from image or dub existing video
Args:
image_or_video: Input image or video file
audio_file: Audio file for lip-sync
resolution: Output resolution (480p or 720p)
steps: Number of diffusion steps
audio_guide_scale: Audio conditioning strength
seed: Random seed for reproducibility
progress: Gradio progress tracker
Returns:
Path to generated video
"""
try:
# Check if GPU is available
if not torch.cuda.is_available():
raise gr.Error(
"⚠️ GPU not available. This Space requires GPU hardware to generate videos. "
"Please apply for a Community GPU Grant in the Space settings, or run this app locally with a GPU."
)
# Initialize models if needed
if not models_loaded:
initialize_models(progress)
# Validate inputs
validate_inputs(image_or_video, audio_file, resolution, steps)
# GPU memory check
gpu_manager.print_memory_usage("Initial - ")
progress(0.1, desc="Processing audio...")
# Process audio
audio, sr = process_audio(audio_file)
audio_duration = len(audio) / sr
logger.info(f"Audio duration: {audio_duration:.2f}s")
        # Estimate the ZeroGPU time budget for this request
        zerogpu_duration = gpu_manager.calculate_duration_for_zerogpu(
            audio_duration, resolution
        )
        logger.info(f"Estimated ZeroGPU duration: {zerogpu_duration}s")
progress(0.2, desc="Loading models...")
# Load models
size = f"infinitetalk-{resolution.replace('p', '')}"
# Load InfiniteTalk pipeline
wan_pipeline = model_manager.load_wan_model(size=size, device="cuda")
# Load audio encoder
audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
gpu_manager.print_memory_usage("After model loading - ")
progress(0.3, desc="Processing input...")
        # Determine whether the input is a video (dubbing) or an image (i2v).
        # The raw path is passed to the pipeline via "cond_video" below, so no
        # frame extraction is needed here (wan.utils.utils.cache_video is a save
        # helper, not a loader).
        is_input_video = is_video(image_or_video)
        if is_input_video:
            logger.info("Processing video dubbing...")
        else:
            logger.info("Processing image-to-video...")
            Image.open(image_or_video).convert("RGB")  # fail fast if the image is unreadable
progress(0.4, desc="Extracting audio features...")
        # Extract audio features (matches get_embedding from the reference script)
        video_length = audio_duration * 25  # target frame count, assuming 25 FPS
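        # Worked example: 10 s of audio at 25 FPS gives int(10 * 25) = 250 frames,
        # the sequence length the encoder resamples its features to below.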
# Extract features with wav2vec
audio_feature = np.squeeze(
feature_extractor(audio, sampling_rate=sr).input_values
)
audio_feature = torch.from_numpy(audio_feature).float().to(device="cuda")
audio_feature = audio_feature.unsqueeze(0)
# Get embeddings from audio encoder
with torch.no_grad():
embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
        if embeddings is None or getattr(embeddings, "hidden_states", None) is None:
            raise gr.Error("Failed to extract audio embeddings")
        # Stack the per-layer hidden states (matches the reference implementation);
        # hidden_states[0] is the pre-transformer embedding output, so it is skipped.
audio_embeddings = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
audio_embeddings = rearrange(audio_embeddings, "b s d -> s b d")
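        # Shape walkthrough: each hidden state is (1, seq, dim); stack + squeeze give
        # (num_layers, seq, dim), and the rearrange above yields (seq, num_layers, dim).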
audio_embeddings = audio_embeddings.cpu().detach()
logger.info(f"Audio embeddings shape: {audio_embeddings.shape}")
gpu_manager.print_memory_usage("After audio processing - ")
progress(0.5, desc="Generating video (this may take a minute)...")
# Set random seed
if seed == -1:
seed = random.randint(0, 99999999)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Generate video with InfiniteTalk
output_path = f"/tmp/output_{seed}.mp4"
# Prepare input for pipeline (following generate_infinitetalk.py structure)
with torch.no_grad():
logger.info(f"Generating {resolution} video with {steps} steps...")
            # Save the audio embeddings to files (the pipeline expects file paths)
os.makedirs("/tmp/audio_embeddings", exist_ok=True)
emb_path = "/tmp/audio_embeddings/1.pt"
audio_wav_path = "/tmp/audio_embeddings/sum.wav"
torch.save(audio_embeddings, emb_path)
sf.write(audio_wav_path, audio, sr)
# Prepare input dictionary (matches generate_infinitetalk.py format)
input_clip = {
"prompt": "", # Empty prompt for talking head
"cond_video": image_or_video,
"cond_audio": {
"person1": emb_path
},
"video_audio": audio_wav_path
}
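            # This layout follows the reference generate_infinitetalk.py inputs:
            # "cond_audio" maps speaker ids ("person1", "person2", ...) to embedding
            # files, and "video_audio" is the track muxed into the final output.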
# Calculate sample_shift based on resolution
sample_shift = 7 if resolution == "480p" else 11
# Call InfiniteTalk pipeline
video_tensor = wan_pipeline.generate_infinitetalk(
input_clip,
                size_buckget=size,  # (sic) parameter name as spelled in the upstream API
motion_frame=9, # Default motion frame
frame_num=81, # Default frame num (4n+1 format)
shift=sample_shift,
sampling_steps=steps,
text_guide_scale=5.0, # Default text guidance
audio_guide_scale=audio_guide_scale,
seed=seed,
offload_model=True,
max_frames_num=81, # For clip mode
color_correction_strength=1.0,
extra_args=None
)
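            # Worked example: frame_num=81 satisfies the 4n+1 constraint (81 = 4*20 + 1),
            # i.e. roughly 81 / 25 = 3.24 s of video per clip at 25 FPS.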
            # Mux the video tensor with the audio track (save_video_ffmpeg is imported at module level)
save_video_ffmpeg(
video_tensor,
output_path.replace(".mp4", ""), # Function adds .mp4 extension
[audio_wav_path],
high_quality_save=False
)
progress(0.9, desc="Finalizing...")
# Cleanup
gpu_manager.cleanup()
progress(1.0, desc="Complete!")
logger.info(f"Video generated successfully: {output_path}")
return output_path
except Exception as e:
logger.error(f"Error generating video: {e}")
gpu_manager.cleanup()
raise gr.Error(f"Generation failed: {str(e)}")
def create_interface():
"""Create Gradio interface"""
with gr.Blocks(title="InfiniteTalk - Talking Video Generator", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎬 InfiniteTalk - Talking Video Generator
Generate realistic talking head videos with accurate lip-sync from images or dub existing videos with new audio!
**Note**: First generation may take 2-3 minutes while models download. Subsequent generations are much faster (~40s for 10s video).
""")
with gr.Tabs():
# Tab 1: Image-to-Video
with gr.Tab("📸 Image-to-Video"):
gr.Markdown("Transform a static portrait into a talking video")
with gr.Row():
with gr.Column():
image_input = gr.Image(
type="filepath",
label="Upload Portrait Image (clear face visibility recommended)"
)
audio_input_i2v = gr.Audio(
type="filepath",
label="Upload Audio (MP3, WAV, or FLAC)"
)
with gr.Accordion("Advanced Settings", open=False):
resolution_i2v = gr.Radio(
choices=["480p", "720p"],
value="480p",
label="Resolution (480p faster, 720p higher quality)"
)
steps_i2v = gr.Slider(
minimum=20,
maximum=50,
value=40,
step=1,
label="Diffusion Steps (more = higher quality but slower)"
)
audio_scale_i2v = gr.Slider(
minimum=1.0,
maximum=5.0,
value=3.0,
step=0.5,
label="Audio Guide Scale (2-4 recommended)"
)
seed_i2v = gr.Number(
value=-1,
label="Seed (-1 for random)"
)
generate_btn_i2v = gr.Button("🎬 Generate Video", variant="primary", size="lg")
with gr.Column():
output_video_i2v = gr.Video(label="Generated Video")
gr.Markdown("**💡 Tip**: Use high-quality portrait images with clear facial features for best results")
generate_btn_i2v.click(
fn=generate_video,
inputs=[image_input, audio_input_i2v, resolution_i2v, steps_i2v, audio_scale_i2v, seed_i2v],
outputs=output_video_i2v
)
# Tab 2: Video Dubbing
with gr.Tab("🎥 Video Dubbing"):
gr.Markdown("Dub an existing video with new audio while maintaining natural movements")
with gr.Row():
with gr.Column():
video_input = gr.Video(
label="Upload Video (with visible face)"
)
audio_input_v2v = gr.Audio(
type="filepath",
label="Upload New Audio (MP3, WAV, or FLAC)"
)
with gr.Accordion("Advanced Settings", open=False):
resolution_v2v = gr.Radio(
choices=["480p", "720p"],
value="480p",
label="Resolution"
)
steps_v2v = gr.Slider(
minimum=20,
maximum=50,
value=40,
step=1,
label="Diffusion Steps"
)
audio_scale_v2v = gr.Slider(
minimum=1.0,
maximum=5.0,
value=3.0,
step=0.5,
label="Audio Guide Scale"
)
seed_v2v = gr.Number(
value=-1,
label="Seed"
)
generate_btn_v2v = gr.Button("🎬 Generate Dubbed Video", variant="primary", size="lg")
with gr.Column():
output_video_v2v = gr.Video(label="Dubbed Video")
gr.Markdown("**💡 Tip**: For best results, use videos with consistent face visibility throughout")
generate_btn_v2v.click(
fn=generate_video,
inputs=[video_input, audio_input_v2v, resolution_v2v, steps_v2v, audio_scale_v2v, seed_v2v],
outputs=output_video_v2v
)
# Footer
gr.Markdown("""
---
### About
Powered by [InfiniteTalk](https://github.com/MeiGen-AI/InfiniteTalk) - Apache 2.0 License
⚠️ **Note**: This Space requires GPU hardware to generate videos. Apply for a Community GPU Grant in Settings.
💡 **Tips**:
- First generation downloads models (~15GB) and may take 2-3 minutes
- Use 480p for faster generation (~40s for 10s video)
- Use 720p for higher quality (slower but better results)
- Clear, well-lit images produce the best results
""")
return demo
if __name__ == "__main__":
demo = create_interface()
demo.queue(max_size=10)
demo.launch()