Update model.py
Browse files
model.py
CHANGED
|
@@ -33,36 +33,6 @@ class INF5Config(PretrainedConfig):
|
|
| 33 |
self.speed = speed
|
| 34 |
self.remove_sil = remove_sil
|
| 35 |
|
| 36 |
-
def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
|
| 37 |
-
"""
|
| 38 |
-
Extract speaker embedding or reference features from audio and text.
|
| 39 |
-
Converts audio to WAV if needed. Returns NumPy array for saving/reuse.
|
| 40 |
-
"""
|
| 41 |
-
if not os.path.exists(ref_audio_path):
|
| 42 |
-
raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")
|
| 43 |
-
|
| 44 |
-
ext = os.path.splitext(ref_audio_path)[-1].lower()
|
| 45 |
-
|
| 46 |
-
# Convert to WAV if input is MP3 or MP4
|
| 47 |
-
if ext not in [".wav"]:
|
| 48 |
-
audio = AudioSegment.from_file(ref_audio_path)
|
| 49 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
|
| 50 |
-
temp_path = temp_wav_file.name
|
| 51 |
-
audio.export(temp_path, format="wav")
|
| 52 |
-
ref_audio_path = temp_path # Use converted path
|
| 53 |
-
|
| 54 |
-
# Extract embedding
|
| 55 |
-
speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
|
| 56 |
-
|
| 57 |
-
# Clean up if we created a temp file
|
| 58 |
-
if ext not in [".wav"] and os.path.exists(ref_audio_path):
|
| 59 |
-
os.remove(ref_audio_path)
|
| 60 |
-
|
| 61 |
-
# Convert to NumPy for easy saving
|
| 62 |
-
if isinstance(speaker_embedding, torch.Tensor):
|
| 63 |
-
speaker_embedding = speaker_embedding.detach().cpu().numpy()
|
| 64 |
-
|
| 65 |
-
return speaker_embedding
|
| 66 |
|
| 67 |
class INF5Model(PreTrainedModel):
|
| 68 |
config_class = INF5Config
|
|
@@ -94,6 +64,37 @@ class INF5Model(PreTrainedModel):
|
|
| 94 |
# # Load state dict into model
|
| 95 |
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
|
| 98 |
# Validate input
|
| 99 |
if speaker_embedding is None:
|
|
@@ -154,8 +155,9 @@ if __name__ == '__main__':
|
|
| 154 |
import soundfile as sf
|
| 155 |
from transformers import AutoConfig, AutoModel
|
| 156 |
from f5_tts.infer.utils_infer import (
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
AutoConfig.register("inf5", INF5Config)
|
| 160 |
AutoModel.register(INF5Config, INF5Model)
|
| 161 |
|
|
|
|
| 33 |
self.speed = speed
|
| 34 |
self.remove_sil = remove_sil
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
class INF5Model(PreTrainedModel):
|
| 38 |
config_class = INF5Config
|
|
|
|
| 64 |
# # Load state dict into model
|
| 65 |
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 66 |
|
| 67 |
+
def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
    """Extract reference features for a speaker from an audio clip and its transcript.

    Any input whose extension is not ``.wav`` is first converted to a
    temporary WAV file. That file is always removed before this method
    returns or raises, so a failure during extraction no longer leaves a
    stray file behind.

    Args:
        ref_audio_path: Path to the reference audio file.
        ref_text: Transcript handed to ``preprocess_ref_audio_text`` together
            with the audio path.

    Returns:
        The first element returned by ``preprocess_ref_audio_text``. A
        ``torch.Tensor`` is detached, moved to the CPU and converted to a
        NumPy array; any other type is returned unchanged.
        NOTE(review): the concrete type depends on
        ``preprocess_ref_audio_text`` -- confirm against that helper.

    Raises:
        FileNotFoundError: If ``ref_audio_path`` does not exist.
    """
    if not os.path.exists(ref_audio_path):
        raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")

    # Track the temp file explicitly so cleanup never depends on the
    # (reassigned) input path or on re-checking the extension.
    temp_path = None
    try:
        # Convert every non-WAV input (not just MP3/MP4) to WAV.
        if os.path.splitext(ref_audio_path)[-1].lower() != ".wav":
            audio = AudioSegment.from_file(ref_audio_path)
            # delete=False: only the unique name is needed; the handle is
            # closed before the exporter writes to that path.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
                temp_path = temp_wav_file.name
            audio.export(temp_path, format="wav")
            ref_audio_path = temp_path  # Use converted path

        # Extract embedding
        speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
    finally:
        # Remove the converted file on success and on failure alike.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

    # Convert to NumPy for easy saving
    if isinstance(speaker_embedding, torch.Tensor):
        speaker_embedding = speaker_embedding.detach().cpu().numpy()

    return speaker_embedding
|
| 97 |
+
|
| 98 |
def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
|
| 99 |
# Validate input
|
| 100 |
if speaker_embedding is None:
|
|
|
|
| 155 |
import soundfile as sf
|
| 156 |
from transformers import AutoConfig, AutoModel
|
| 157 |
from f5_tts.infer.utils_infer import (
|
| 158 |
+
preprocess_ref_audio_text,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
AutoConfig.register("inf5", INF5Config)
|
| 162 |
AutoModel.register(INF5Config, INF5Model)
|
| 163 |
|