Chenghao Mou committed on
Commit
03f211f
·
unverified ·
1 Parent(s): f50f534

fix: floor (#293)

Browse files
Files changed (1) hide show
  1. musetalk/utils/audio_processor.py +22 -22
musetalk/utils/audio_processor.py CHANGED
@@ -1,16 +1,17 @@
1
- import os
2
  import math
 
 
3
  import librosa
4
  import numpy as np
5
  import torch
6
-
7
  from einops import rearrange
8
  from transformers import AutoFeatureExtractor
9
 
 
10
  class AudioProcessor:
11
  def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
12
  self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
13
-
14
  def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
15
  if not os.path.exists(wav_path):
16
  return None
@@ -19,11 +20,11 @@ class AudioProcessor:
19
  # Split audio into 30s segments
20
  segment_length = 30 * sampling_rate
21
  segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
22
-
23
  features = []
24
  for segment in segments:
25
  audio_feature = self.feature_extractor(
26
- segment,
27
  return_tensors="pt",
28
  sampling_rate=sampling_rate
29
  ).input_features
@@ -32,13 +33,13 @@ class AudioProcessor:
32
  features.append(audio_feature)
33
 
34
  return features, len(librosa_output)
35
-
36
  def get_whisper_chunk(
37
- self,
38
- whisper_input_features,
39
- device,
40
- weight_dtype,
41
- whisper,
42
  librosa_length,
43
  fps=25,
44
  audio_padding_length_left=2,
@@ -48,30 +49,30 @@ class AudioProcessor:
48
  whisper_feature = []
49
  # Process multiple 30s mel input features
50
  for input_feature in whisper_input_features:
51
- audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
52
  audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
53
  whisper_feature.append(audio_feats)
54
-
55
  whisper_feature = torch.cat(whisper_feature, dim=1)
56
  # Trim the last segment to remove padding
57
  sr = 16000
58
  audio_fps = 50
59
  fps = int(fps)
60
  whisper_idx_multiplier = audio_fps / fps
61
- num_frames = math.floor((librosa_length / sr)) * fps
62
- actual_length = math.floor((librosa_length / sr)) * audio_fps
63
  whisper_feature = whisper_feature[:,:actual_length,...]
64
-
65
  # Calculate padding amount
66
  padding_nums = math.floor(whisper_idx_multiplier)
67
  # Add padding at start and end
68
  whisper_feature = torch.cat([
69
- torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
70
- whisper_feature,
71
  # Add extra padding to prevent out of bounds
72
  torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
73
  ], 1)
74
-
75
  audio_prompts = []
76
  for frame_index in range(num_frames):
77
  try:
@@ -86,7 +87,7 @@ class AudioProcessor:
86
  print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
87
  print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
88
  exit()
89
-
90
  audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
91
  audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
92
  return audio_prompts
@@ -97,5 +98,4 @@ if __name__ == "__main__":
97
  audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
98
  print("Audio Feature shape:", audio_feature.shape)
99
  print("librosa_feature_length:", librosa_feature_length)
100
-
101
-
 
 
1
  import math
2
+ import os
3
+
4
  import librosa
5
  import numpy as np
6
  import torch
 
7
  from einops import rearrange
8
  from transformers import AutoFeatureExtractor
9
 
10
+
11
  class AudioProcessor:
12
  def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
13
  self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
14
+
15
  def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
16
  if not os.path.exists(wav_path):
17
  return None
 
20
  # Split audio into 30s segments
21
  segment_length = 30 * sampling_rate
22
  segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
23
+
24
  features = []
25
  for segment in segments:
26
  audio_feature = self.feature_extractor(
27
+ segment,
28
  return_tensors="pt",
29
  sampling_rate=sampling_rate
30
  ).input_features
 
33
  features.append(audio_feature)
34
 
35
  return features, len(librosa_output)
36
+
37
  def get_whisper_chunk(
38
+ self,
39
+ whisper_input_features,
40
+ device,
41
+ weight_dtype,
42
+ whisper,
43
  librosa_length,
44
  fps=25,
45
  audio_padding_length_left=2,
 
49
  whisper_feature = []
50
  # Process multiple 30s mel input features
51
  for input_feature in whisper_input_features:
52
+ audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
53
  audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
54
  whisper_feature.append(audio_feats)
55
+
56
  whisper_feature = torch.cat(whisper_feature, dim=1)
57
  # Trim the last segment to remove padding
58
  sr = 16000
59
  audio_fps = 50
60
  fps = int(fps)
61
  whisper_idx_multiplier = audio_fps / fps
62
+ num_frames = math.floor((librosa_length / sr) * fps)
63
+ actual_length = math.floor((librosa_length / sr) * audio_fps)
64
  whisper_feature = whisper_feature[:,:actual_length,...]
65
+
66
  # Calculate padding amount
67
  padding_nums = math.floor(whisper_idx_multiplier)
68
  # Add padding at start and end
69
  whisper_feature = torch.cat([
70
+ torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
71
+ whisper_feature,
72
  # Add extra padding to prevent out of bounds
73
  torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
74
  ], 1)
75
+
76
  audio_prompts = []
77
  for frame_index in range(num_frames):
78
  try:
 
87
  print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
88
  print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
89
  exit()
90
+
91
  audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
92
  audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
93
  return audio_prompts
 
98
  audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
99
  print("Audio Feature shape:", audio_feature.shape)
100
  print("librosa_feature_length:", librosa_feature_length)
101
+