ShalomKing committed on
Commit
bc5110c
·
verified ·
1 Parent(s): 38572a2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +91 -54
app.py CHANGED
@@ -78,9 +78,23 @@ def initialize_models(progress=gr.Progress()):
78
  raise gr.Error(f"Failed to initialize models: {str(e)}")
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def process_audio(audio_path, target_sr=16000):
82
  """
83
- Process audio file for InfiniteTalk
84
 
85
  Args:
86
  audio_path: Path to audio file
@@ -90,18 +104,11 @@ def process_audio(audio_path, target_sr=16000):
90
  Processed audio array and sample rate
91
  """
92
  try:
93
- # Load audio
94
- audio, sr = librosa.load(audio_path, sr=None)
95
-
96
- # Resample if needed
97
- if sr != target_sr:
98
- audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
99
- sr = target_sr
100
 
101
  # Normalize loudness
102
- meter = pyln.Meter(sr)
103
- loudness = meter.integrated_loudness(audio)
104
- audio = pyln.normalize.loudness(audio, loudness, -20.0)
105
 
106
  # Ensure mono
107
  if len(audio.shape) > 1:
@@ -187,8 +194,8 @@ def generate_video(
187
  # Load models
188
  size = f"infinitetalk-{resolution.replace('p', '')}"
189
 
190
- # Load Wan model
191
- wan_model = model_manager.load_wan_model(size=size, device="cuda")
192
 
193
  # Load audio encoder
194
  audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
@@ -210,18 +217,31 @@ def generate_video(
210
 
211
  progress(0.4, desc="Extracting audio features...")
212
 
213
- # Extract audio features
214
- audio_features = feature_extractor(
215
- audio,
216
- sampling_rate=sr,
217
- return_tensors="pt"
218
- ).input_values
219
 
220
- audio_features = audio_features.to("cuda")
 
 
 
 
 
221
 
 
222
  with torch.no_grad():
223
- audio_embeddings = audio_encoder(audio_features).last_hidden_state
224
 
 
 
 
 
 
 
 
 
 
 
225
  gpu_manager.print_memory_usage("After audio processing - ")
226
 
227
  progress(0.5, desc="Generating video (this may take a minute)...")
@@ -234,44 +254,61 @@ def generate_video(
234
  if torch.cuda.is_available():
235
  torch.cuda.manual_seed(seed)
236
 
237
- # Generate video
238
- # This is a placeholder for the actual inference logic
239
- # The actual implementation would call wan_model.generate() with proper parameters
240
-
241
  output_path = f"/tmp/output_{seed}.mp4"
242
 
243
- # Simplified inference call (replace with actual InfiniteTalk logic)
244
  with torch.no_grad():
245
- # Parameters
246
- generation_args = {
247
- "input_frames": input_frames,
248
- "audio_embeddings": audio_embeddings,
249
- "num_steps": steps,
250
- "audio_guide_scale": audio_guide_scale,
251
- "size": size,
252
- "seed": seed,
253
- }
254
-
255
- # Call model inference (placeholder)
256
- # output_frames = wan_model.generate(**generation_args)
257
-
258
- # For now, just create a dummy output to test the pipeline
259
- # In production, this would be replaced with actual video generation
260
  logger.info(f"Generating {resolution} video with {steps} steps...")
261
 
262
- # Placeholder: copy input as output for testing
263
- import shutil
264
- if is_input_video:
265
- shutil.copy(image_or_video, output_path)
266
- else:
267
- # Create a short video from the image
268
- # This is just for testing - replace with actual generation
269
- logger.warning("Placeholder: actual video generation not implemented yet")
270
- raise gr.Error(
271
- "Video generation logic needs to be integrated. "
272
- "This is a template - please integrate the actual InfiniteTalk "
273
- "inference code from generate_infinitetalk.py"
274
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  progress(0.9, desc="Finalizing...")
277
 
 
78
  raise gr.Error(f"Failed to initialize models: {str(e)}")
79
 
80
 
81
+ def loudness_norm(audio_array, sr=16000, lufs=-20.0):
82
+ """Normalize audio loudness using pyloudnorm"""
83
+ try:
84
+ meter = pyln.Meter(sr)
85
+ loudness = meter.integrated_loudness(audio_array)
86
+ if abs(loudness) > 100: # Skip if loudness measurement failed
87
+ return audio_array
88
+ normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
89
+ return normalized_audio
90
+ except Exception as e:
91
+ logger.warning(f"Loudness normalization failed: {e}, returning original audio")
92
+ return audio_array
93
+
94
+
95
  def process_audio(audio_path, target_sr=16000):
96
  """
97
+ Process audio file for InfiniteTalk (matches audio_prepare_single from reference)
98
 
99
  Args:
100
  audio_path: Path to audio file
 
104
  Processed audio array and sample rate
105
  """
106
  try:
107
+ # Load audio with librosa
108
+ audio, sr = librosa.load(audio_path, sr=target_sr)
 
 
 
 
 
109
 
110
  # Normalize loudness
111
+ audio = loudness_norm(audio, sr)
 
 
112
 
113
  # Ensure mono
114
  if len(audio.shape) > 1:
 
194
  # Load models
195
  size = f"infinitetalk-{resolution.replace('p', '')}"
196
 
197
+ # Load InfiniteTalk pipeline
198
+ wan_pipeline = model_manager.load_wan_model(size=size, device="cuda")
199
 
200
  # Load audio encoder
201
  audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
 
217
 
218
  progress(0.4, desc="Extracting audio features...")
219
 
220
+ # Extract audio features (matches get_embedding from reference)
221
+ audio_duration = len(audio) / sr
222
+ video_length = audio_duration * 25 # Assume 25 FPS
 
 
 
223
 
224
+ # Extract features with wav2vec
225
+ audio_feature = np.squeeze(
226
+ feature_extractor(audio, sampling_rate=sr).input_values
227
+ )
228
+ audio_feature = torch.from_numpy(audio_feature).float().to(device="cuda")
229
+ audio_feature = audio_feature.unsqueeze(0)
230
 
231
+ # Get embeddings from audio encoder
232
  with torch.no_grad():
233
+ embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
234
 
235
+ if len(embeddings) == 0 or not hasattr(embeddings, 'hidden_states'):
236
+ raise gr.Error("Failed to extract audio embeddings")
237
+
238
+ # Stack hidden states (matches reference implementation)
239
+ from einops import rearrange
240
+ audio_embeddings = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
241
+ audio_embeddings = rearrange(audio_embeddings, "b s d -> s b d")
242
+ audio_embeddings = audio_embeddings.cpu().detach()
243
+
244
+ logger.info(f"Audio embeddings shape: {audio_embeddings.shape}")
245
  gpu_manager.print_memory_usage("After audio processing - ")
246
 
247
  progress(0.5, desc="Generating video (this may take a minute)...")
 
254
  if torch.cuda.is_available():
255
  torch.cuda.manual_seed(seed)
256
 
257
+ # Generate video with InfiniteTalk
 
 
 
258
  output_path = f"/tmp/output_{seed}.mp4"
259
 
260
+ # Prepare input for pipeline (following generate_infinitetalk.py structure)
261
  with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  logger.info(f"Generating {resolution} video with {steps} steps...")
263
 
264
+ # Save audio embeddings to temporary file (pipeline expects file path)
265
+ import tempfile
266
+ os.makedirs("/tmp/audio_embeddings", exist_ok=True)
267
+ emb_path = "/tmp/audio_embeddings/1.pt"
268
+ audio_wav_path = "/tmp/audio_embeddings/sum.wav"
269
+
270
+ torch.save(audio_embeddings, emb_path)
271
+ sf.write(audio_wav_path, audio, sr)
272
+
273
+ # Prepare input dictionary (matches generate_infinitetalk.py format)
274
+ input_clip = {
275
+ "prompt": "", # Empty prompt for talking head
276
+ "cond_video": image_or_video,
277
+ "cond_audio": {
278
+ "person1": emb_path
279
+ },
280
+ "video_audio": audio_wav_path
281
+ }
282
+
283
+ # Calculate sample_shift based on resolution
284
+ sample_shift = 7 if resolution == "480p" else 11
285
+
286
+ # Call InfiniteTalk pipeline
287
+ video_tensor = wan_pipeline.generate_infinitetalk(
288
+ input_clip,
289
+ size_buckget=size,
290
+ motion_frame=9, # Default motion frame
291
+ frame_num=81, # Default frame num (4n+1 format)
292
+ shift=sample_shift,
293
+ sampling_steps=steps,
294
+ text_guide_scale=5.0, # Default text guidance
295
+ audio_guide_scale=audio_guide_scale,
296
+ seed=seed,
297
+ offload_model=True,
298
+ max_frames_num=81, # For clip mode
299
+ color_correction_strength=1.0,
300
+ extra_args=None
301
+ )
302
+
303
+ # Save video with audio
304
+ from wan.utils.multitalk_utils import save_video_ffmpeg
305
+
306
+ save_video_ffmpeg(
307
+ video_tensor,
308
+ output_path.replace(".mp4", ""), # Function adds .mp4 extension
309
+ [audio_wav_path],
310
+ high_quality_save=False
311
+ )
312
 
313
  progress(0.9, desc="Finalizing...")
314