Update model.py
Browse files
model.py
CHANGED
|
@@ -32,6 +32,37 @@ class INF5Config(PretrainedConfig):
|
|
| 32 |
self.speed = speed
|
| 33 |
self.remove_sil = remove_sil
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
class INF5Model(PreTrainedModel):
|
| 36 |
config_class = INF5Config
|
| 37 |
|
|
|
|
| 32 |
self.speed = speed
|
| 33 |
self.remove_sil = remove_sil
|
| 34 |
|
| 35 |
+
def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
    """Run reference audio and text through ``preprocess_ref_audio_text``.

    If the file extension is anything other than ``.wav`` (compared
    case-insensitively), the audio is first decoded with
    ``AudioSegment.from_file`` and exported to a temporary WAV file, which
    is what gets handed to the preprocessing helper. The temporary file is
    always removed afterwards, including when conversion or preprocessing
    raises.

    Args:
        ref_audio_path: Path to the reference audio file.
        ref_text: Text passed through to ``preprocess_ref_audio_text``
            together with the audio path.

    Returns:
        The first element of the pair returned by
        ``preprocess_ref_audio_text``. If that value is a ``torch.Tensor``
        it is detached, moved to CPU and converted to a NumPy array;
        any other type is returned unchanged.

    Raises:
        FileNotFoundError: If ``ref_audio_path`` does not exist.
    """
    if not os.path.exists(ref_audio_path):
        raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")

    ext = os.path.splitext(ref_audio_path)[-1].lower()

    # Path of the temporary WAV we create (if any); tracked separately so
    # cleanup never depends on re-deriving it from the extension.
    temp_path = None

    try:
        # Any non-WAV input is converted, not only MP3/MP4.
        if ext != ".wav":
            audio = AudioSegment.from_file(ref_audio_path)
            # Only reserve a unique name here; the handle is closed before
            # exporting so the write also works on platforms that forbid
            # reopening a file that is still open.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
                temp_path = temp_wav_file.name
            audio.export(temp_path, format="wav")
            ref_audio_path = temp_path  # Use converted path

        # NOTE(review): the helper's return type is not visible in this file.
        # If it returns a file path rather than features, confirm it never
        # points at the temporary WAV that is deleted below.
        speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
    finally:
        # Remove the temp file on success and on failure alike.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass

    # Convert to NumPy for easy saving
    if isinstance(speaker_embedding, torch.Tensor):
        speaker_embedding = speaker_embedding.detach().cpu().numpy()

    return speaker_embedding
|
| 65 |
+
|
| 66 |
class INF5Model(PreTrainedModel):
|
| 67 |
config_class = INF5Config
|
| 68 |
|