Anjan9320 committed on
Commit
9fc138a
·
verified ·
1 Parent(s): 0476a91

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +34 -32
model.py CHANGED
@@ -33,36 +33,6 @@ class INF5Config(PretrainedConfig):
33
  self.speed = speed
34
  self.remove_sil = remove_sil
35
 
36
- def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
37
- """
38
- Extract speaker embedding or reference features from audio and text.
39
- Converts audio to WAV if needed. Returns NumPy array for saving/reuse.
40
- """
41
- if not os.path.exists(ref_audio_path):
42
- raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")
43
-
44
- ext = os.path.splitext(ref_audio_path)[-1].lower()
45
-
46
- # Convert to WAV if input is MP3 or MP4
47
- if ext not in [".wav"]:
48
- audio = AudioSegment.from_file(ref_audio_path)
49
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
50
- temp_path = temp_wav_file.name
51
- audio.export(temp_path, format="wav")
52
- ref_audio_path = temp_path # Use converted path
53
-
54
- # Extract embedding
55
- speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
56
-
57
- # Clean up if we created a temp file
58
- if ext not in [".wav"] and os.path.exists(ref_audio_path):
59
- os.remove(ref_audio_path)
60
-
61
- # Convert to NumPy for easy saving
62
- if isinstance(speaker_embedding, torch.Tensor):
63
- speaker_embedding = speaker_embedding.detach().cpu().numpy()
64
-
65
- return speaker_embedding
66
 
67
  class INF5Model(PreTrainedModel):
68
  config_class = INF5Config
@@ -94,6 +64,37 @@ class INF5Model(PreTrainedModel):
94
  # # Load state dict into model
95
  self.ema_model.load_state_dict(state_dict, strict=False)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
98
  # Validate input
99
  if speaker_embedding is None:
@@ -154,8 +155,9 @@ if __name__ == '__main__':
154
  import soundfile as sf
155
  from transformers import AutoConfig, AutoModel
156
  from f5_tts.infer.utils_infer import (
157
- preprocess_ref_audio_text,
158
- )
 
159
  AutoConfig.register("inf5", INF5Config)
160
  AutoModel.register(INF5Config, INF5Model)
161
 
 
33
  self.speed = speed
34
  self.remove_sil = remove_sil
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  class INF5Model(PreTrainedModel):
38
  config_class = INF5Config
 
64
  # # Load state dict into model
65
  self.ema_model.load_state_dict(state_dict, strict=False)
66
 
67
+ def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
68
+ """
69
+ Extract speaker embedding or reference features from audio and text.
70
+ Converts audio to WAV if needed. Returns NumPy array for saving/reuse.
71
+ """
72
+ if not os.path.exists(ref_audio_path):
73
+ raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")
74
+
75
+ ext = os.path.splitext(ref_audio_path)[-1].lower()
76
+
77
+ # Convert to WAV if input is MP3 or MP4
78
+ if ext not in [".wav"]:
79
+ audio = AudioSegment.from_file(ref_audio_path)
80
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
81
+ temp_path = temp_wav_file.name
82
+ audio.export(temp_path, format="wav")
83
+ ref_audio_path = temp_path # Use converted path
84
+
85
+ # Extract embedding
86
+ speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
87
+
88
+ # Clean up if we created a temp file
89
+ if ext not in [".wav"] and os.path.exists(ref_audio_path):
90
+ os.remove(ref_audio_path)
91
+
92
+ # Convert to NumPy for easy saving
93
+ if isinstance(speaker_embedding, torch.Tensor):
94
+ speaker_embedding = speaker_embedding.detach().cpu().numpy()
95
+
96
+ return speaker_embedding
97
+
98
  def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
99
  # Validate input
100
  if speaker_embedding is None:
 
155
  import soundfile as sf
156
  from transformers import AutoConfig, AutoModel
157
  from f5_tts.infer.utils_infer import (
158
+ preprocess_ref_audio_text,
159
+ )
160
+
161
  AutoConfig.register("inf5", INF5Config)
162
  AutoModel.register(INF5Config, INF5Model)
163