Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -78,9 +78,23 @@ def initialize_models(progress=gr.Progress()):
|
|
| 78 |
raise gr.Error(f"Failed to initialize models: {str(e)}")
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def process_audio(audio_path, target_sr=16000):
|
| 82 |
"""
|
| 83 |
-
Process audio file for InfiniteTalk
|
| 84 |
|
| 85 |
Args:
|
| 86 |
audio_path: Path to audio file
|
|
@@ -90,18 +104,11 @@ def process_audio(audio_path, target_sr=16000):
|
|
| 90 |
Processed audio array and sample rate
|
| 91 |
"""
|
| 92 |
try:
|
| 93 |
-
# Load audio
|
| 94 |
-
audio, sr = librosa.load(audio_path, sr=
|
| 95 |
-
|
| 96 |
-
# Resample if needed
|
| 97 |
-
if sr != target_sr:
|
| 98 |
-
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
|
| 99 |
-
sr = target_sr
|
| 100 |
|
| 101 |
# Normalize loudness
|
| 102 |
-
|
| 103 |
-
loudness = meter.integrated_loudness(audio)
|
| 104 |
-
audio = pyln.normalize.loudness(audio, loudness, -20.0)
|
| 105 |
|
| 106 |
# Ensure mono
|
| 107 |
if len(audio.shape) > 1:
|
|
@@ -187,8 +194,8 @@ def generate_video(
|
|
| 187 |
# Load models
|
| 188 |
size = f"infinitetalk-{resolution.replace('p', '')}"
|
| 189 |
|
| 190 |
-
# Load
|
| 191 |
-
|
| 192 |
|
| 193 |
# Load audio encoder
|
| 194 |
audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
|
|
@@ -210,18 +217,31 @@ def generate_video(
|
|
| 210 |
|
| 211 |
progress(0.4, desc="Extracting audio features...")
|
| 212 |
|
| 213 |
-
# Extract audio features
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
sampling_rate=sr,
|
| 217 |
-
return_tensors="pt"
|
| 218 |
-
).input_values
|
| 219 |
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
|
|
|
| 222 |
with torch.no_grad():
|
| 223 |
-
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
gpu_manager.print_memory_usage("After audio processing - ")
|
| 226 |
|
| 227 |
progress(0.5, desc="Generating video (this may take a minute)...")
|
|
@@ -234,44 +254,61 @@ def generate_video(
|
|
| 234 |
if torch.cuda.is_available():
|
| 235 |
torch.cuda.manual_seed(seed)
|
| 236 |
|
| 237 |
-
# Generate video
|
| 238 |
-
# This is a placeholder for the actual inference logic
|
| 239 |
-
# The actual implementation would call wan_model.generate() with proper parameters
|
| 240 |
-
|
| 241 |
output_path = f"/tmp/output_{seed}.mp4"
|
| 242 |
|
| 243 |
-
#
|
| 244 |
with torch.no_grad():
|
| 245 |
-
# Parameters
|
| 246 |
-
generation_args = {
|
| 247 |
-
"input_frames": input_frames,
|
| 248 |
-
"audio_embeddings": audio_embeddings,
|
| 249 |
-
"num_steps": steps,
|
| 250 |
-
"audio_guide_scale": audio_guide_scale,
|
| 251 |
-
"size": size,
|
| 252 |
-
"seed": seed,
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
-
# Call model inference (placeholder)
|
| 256 |
-
# output_frames = wan_model.generate(**generation_args)
|
| 257 |
-
|
| 258 |
-
# For now, just create a dummy output to test the pipeline
|
| 259 |
-
# In production, this would be replaced with actual video generation
|
| 260 |
logger.info(f"Generating {resolution} video with {steps} steps...")
|
| 261 |
|
| 262 |
-
#
|
| 263 |
-
import
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
progress(0.9, desc="Finalizing...")
|
| 277 |
|
|
|
|
| 78 |
raise gr.Error(f"Failed to initialize models: {str(e)}")
|
| 79 |
|
| 80 |
|
| 81 |
+
def loudness_norm(audio_array, sr=16000, lufs=-20.0):
|
| 82 |
+
"""Normalize audio loudness using pyloudnorm"""
|
| 83 |
+
try:
|
| 84 |
+
meter = pyln.Meter(sr)
|
| 85 |
+
loudness = meter.integrated_loudness(audio_array)
|
| 86 |
+
if abs(loudness) > 100: # Skip if loudness measurement failed
|
| 87 |
+
return audio_array
|
| 88 |
+
normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
|
| 89 |
+
return normalized_audio
|
| 90 |
+
except Exception as e:
|
| 91 |
+
logger.warning(f"Loudness normalization failed: {e}, returning original audio")
|
| 92 |
+
return audio_array
|
| 93 |
+
|
| 94 |
+
|
| 95 |
def process_audio(audio_path, target_sr=16000):
|
| 96 |
"""
|
| 97 |
+
Process audio file for InfiniteTalk (matches audio_prepare_single from reference)
|
| 98 |
|
| 99 |
Args:
|
| 100 |
audio_path: Path to audio file
|
|
|
|
| 104 |
Processed audio array and sample rate
|
| 105 |
"""
|
| 106 |
try:
|
| 107 |
+
# Load audio with librosa
|
| 108 |
+
audio, sr = librosa.load(audio_path, sr=target_sr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# Normalize loudness
|
| 111 |
+
audio = loudness_norm(audio, sr)
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Ensure mono
|
| 114 |
if len(audio.shape) > 1:
|
|
|
|
| 194 |
# Load models
|
| 195 |
size = f"infinitetalk-{resolution.replace('p', '')}"
|
| 196 |
|
| 197 |
+
# Load InfiniteTalk pipeline
|
| 198 |
+
wan_pipeline = model_manager.load_wan_model(size=size, device="cuda")
|
| 199 |
|
| 200 |
# Load audio encoder
|
| 201 |
audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
|
|
|
|
| 217 |
|
| 218 |
progress(0.4, desc="Extracting audio features...")
|
| 219 |
|
| 220 |
+
# Extract audio features (matches get_embedding from reference)
|
| 221 |
+
audio_duration = len(audio) / sr
|
| 222 |
+
video_length = audio_duration * 25 # Assume 25 FPS
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
+
# Extract features with wav2vec
|
| 225 |
+
audio_feature = np.squeeze(
|
| 226 |
+
feature_extractor(audio, sampling_rate=sr).input_values
|
| 227 |
+
)
|
| 228 |
+
audio_feature = torch.from_numpy(audio_feature).float().to(device="cuda")
|
| 229 |
+
audio_feature = audio_feature.unsqueeze(0)
|
| 230 |
|
| 231 |
+
# Get embeddings from audio encoder
|
| 232 |
with torch.no_grad():
|
| 233 |
+
embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
|
| 234 |
|
| 235 |
+
if len(embeddings) == 0 or not hasattr(embeddings, 'hidden_states'):
|
| 236 |
+
raise gr.Error("Failed to extract audio embeddings")
|
| 237 |
+
|
| 238 |
+
# Stack hidden states (matches reference implementation)
|
| 239 |
+
from einops import rearrange
|
| 240 |
+
audio_embeddings = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
|
| 241 |
+
audio_embeddings = rearrange(audio_embeddings, "b s d -> s b d")
|
| 242 |
+
audio_embeddings = audio_embeddings.cpu().detach()
|
| 243 |
+
|
| 244 |
+
logger.info(f"Audio embeddings shape: {audio_embeddings.shape}")
|
| 245 |
gpu_manager.print_memory_usage("After audio processing - ")
|
| 246 |
|
| 247 |
progress(0.5, desc="Generating video (this may take a minute)...")
|
|
|
|
| 254 |
if torch.cuda.is_available():
|
| 255 |
torch.cuda.manual_seed(seed)
|
| 256 |
|
| 257 |
+
# Generate video with InfiniteTalk
|
|
|
|
|
|
|
|
|
|
| 258 |
output_path = f"/tmp/output_{seed}.mp4"
|
| 259 |
|
| 260 |
+
# Prepare input for pipeline (following generate_infinitetalk.py structure)
|
| 261 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
logger.info(f"Generating {resolution} video with {steps} steps...")
|
| 263 |
|
| 264 |
+
# Save audio embeddings to temporary file (pipeline expects file path)
|
| 265 |
+
import tempfile
|
| 266 |
+
os.makedirs("/tmp/audio_embeddings", exist_ok=True)
|
| 267 |
+
emb_path = "/tmp/audio_embeddings/1.pt"
|
| 268 |
+
audio_wav_path = "/tmp/audio_embeddings/sum.wav"
|
| 269 |
+
|
| 270 |
+
torch.save(audio_embeddings, emb_path)
|
| 271 |
+
sf.write(audio_wav_path, audio, sr)
|
| 272 |
+
|
| 273 |
+
# Prepare input dictionary (matches generate_infinitetalk.py format)
|
| 274 |
+
input_clip = {
|
| 275 |
+
"prompt": "", # Empty prompt for talking head
|
| 276 |
+
"cond_video": image_or_video,
|
| 277 |
+
"cond_audio": {
|
| 278 |
+
"person1": emb_path
|
| 279 |
+
},
|
| 280 |
+
"video_audio": audio_wav_path
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
# Calculate sample_shift based on resolution
|
| 284 |
+
sample_shift = 7 if resolution == "480p" else 11
|
| 285 |
+
|
| 286 |
+
# Call InfiniteTalk pipeline
|
| 287 |
+
video_tensor = wan_pipeline.generate_infinitetalk(
|
| 288 |
+
input_clip,
|
| 289 |
+
size_buckget=size,
|
| 290 |
+
motion_frame=9, # Default motion frame
|
| 291 |
+
frame_num=81, # Default frame num (4n+1 format)
|
| 292 |
+
shift=sample_shift,
|
| 293 |
+
sampling_steps=steps,
|
| 294 |
+
text_guide_scale=5.0, # Default text guidance
|
| 295 |
+
audio_guide_scale=audio_guide_scale,
|
| 296 |
+
seed=seed,
|
| 297 |
+
offload_model=True,
|
| 298 |
+
max_frames_num=81, # For clip mode
|
| 299 |
+
color_correction_strength=1.0,
|
| 300 |
+
extra_args=None
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Save video with audio
|
| 304 |
+
from wan.utils.multitalk_utils import save_video_ffmpeg
|
| 305 |
+
|
| 306 |
+
save_video_ffmpeg(
|
| 307 |
+
video_tensor,
|
| 308 |
+
output_path.replace(".mp4", ""), # Function adds .mp4 extension
|
| 309 |
+
[audio_wav_path],
|
| 310 |
+
high_quality_save=False
|
| 311 |
+
)
|
| 312 |
|
| 313 |
progress(0.9, desc="Finalizing...")
|
| 314 |
|