# pocket-tts / src/tts/manager.py
# hadadrjt — [3/?] Pocket TTS: Handle multiple format extensions for voice cloning.
# commit dae9fa5
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import time
import tempfile
import threading
import torch
import scipy.io.wavfile
from pocket_tts import TTSModel
from config import (
AVAILABLE_VOICES,
DEFAULT_VOICE,
DEFAULT_MODEL_VARIANT,
DEFAULT_TEMPERATURE,
DEFAULT_LSD_DECODE_STEPS,
DEFAULT_EOS_THRESHOLD,
VOICE_STATE_CACHE_MAXIMUM_SIZE,
VOICE_STATE_CACHE_CLEANUP_THRESHOLD
)
from ..core.state import (
temporary_files_registry,
temporary_files_lock,
set_text_to_speech_manager
)
from ..core.memory import (
force_garbage_collection,
memory_cleanup,
perform_memory_cleanup,
trigger_background_cleanup_check,
is_memory_usage_approaching_limit
)
class TextToSpeechManager:
    """Thread-safe owner of the Pocket TTS model and an LRU cache of
    per-voice conditioning states.

    ``model_lock`` guards ``loaded_model`` / ``current_configuration``;
    ``voice_state_cache_lock`` guards the two cache dictionaries. When both
    locks are needed, ``model_lock`` is acquired first, so the ordering is
    consistent across methods. ``threading.Lock`` is not reentrant, so no
    method acquires the same lock twice on one call path.
    """

    def __init__(self):
        # Lazily loaded TTSModel instance; None until load_or_get_model runs.
        self.loaded_model = None
        # Keyword arguments the currently loaded model was created with.
        self.current_configuration = {}
        # voice name -> precomputed conditioning state.
        self.voice_state_cache = {}
        # voice name -> last-access time.time(); drives LRU eviction.
        self.voice_state_cache_access_timestamps = {}
        self.voice_state_cache_lock = threading.Lock()
        self.model_lock = threading.Lock()

    def is_model_loaded(self):
        """Return True when a TTS model is currently resident in memory."""
        with self.model_lock:
            return self.loaded_model is not None

    def unload_model_completely(self):
        """Drop the model and every cached voice state, then force cleanup."""
        with self.model_lock:
            self.clear_voice_state_cache_completely()
            if self.loaded_model is not None:
                del self.loaded_model
                self.loaded_model = None
                self.current_configuration = {}
                memory_cleanup()

    def load_or_get_model(
        self,
        model_variant,
        temperature,
        lsd_decode_steps,
        noise_clamp,
        eos_threshold
    ):
        """Return a TTSModel matching the requested configuration.

        None arguments fall back to the configured defaults. If the current
        model's configuration differs, it is unloaded (with a full voice-state
        cache reset) and a new model is loaded.

        Raises whatever ``TTSModel.load_model`` raises on a failed load.
        """
        perform_memory_cleanup()
        processed_variant = str(model_variant or DEFAULT_MODEL_VARIANT).strip()
        processed_temperature = float(temperature) if temperature is not None else DEFAULT_TEMPERATURE
        processed_lsd_steps = int(lsd_decode_steps) if lsd_decode_steps is not None else DEFAULT_LSD_DECODE_STEPS
        # Convert once; falsy or non-positive values disable noise clamping.
        processed_noise_clamp = None
        if noise_clamp:
            clamp_candidate = float(noise_clamp)
            if clamp_candidate > 0:
                processed_noise_clamp = clamp_candidate
        processed_eos_threshold = float(eos_threshold) if eos_threshold is not None else DEFAULT_EOS_THRESHOLD
        requested_configuration = {
            "variant": processed_variant,
            "temp": processed_temperature,
            "lsd_decode_steps": processed_lsd_steps,
            "noise_clamp": processed_noise_clamp,
            "eos_threshold": processed_eos_threshold
        }
        with self.model_lock:
            if self.loaded_model is None or self.current_configuration != requested_configuration:
                if self.loaded_model is not None:
                    self.clear_voice_state_cache_completely()
                    del self.loaded_model
                    self.loaded_model = None
                    memory_cleanup()
                self.loaded_model = TTSModel.load_model(**requested_configuration)
                self.current_configuration = requested_configuration
                # Reset BOTH cache dicts under the cache lock. The previous
                # code rebound voice_state_cache without the lock and left
                # stale access timestamps behind, which broke LRU ordering.
                with self.voice_state_cache_lock:
                    self.voice_state_cache = {}
                    self.voice_state_cache_access_timestamps = {}
            return self.loaded_model

    def clear_voice_state_cache_completely(self):
        """Empty both cache dictionaries and force garbage collection.

        Clearing the dicts drops the last references to the cached state
        objects; no per-entry ``del`` is needed.
        """
        with self.voice_state_cache_lock:
            self.voice_state_cache.clear()
            self.voice_state_cache_access_timestamps.clear()
            force_garbage_collection()

    def _evict_oldest_voice_states_locked(self, number_to_remove):
        # Remove the N least-recently-used entries. Caller MUST hold
        # voice_state_cache_lock. Slicing tolerates number_to_remove larger
        # than the timestamp list.
        oldest_first = sorted(
            self.voice_state_cache_access_timestamps,
            key=self.voice_state_cache_access_timestamps.get
        )
        for voice_name in oldest_first[:number_to_remove]:
            self.voice_state_cache.pop(voice_name, None)
            self.voice_state_cache_access_timestamps.pop(voice_name, None)

    def evict_least_recently_used_voice_states(self):
        """Evict least-recently-used cached voice states.

        If the cache is at or below the cleanup threshold (the caller signals
        memory pressure), half of the entries (at least one) are dropped;
        otherwise the cache is trimmed back down to the threshold.
        """
        with self.voice_state_cache_lock:
            cache_size = len(self.voice_state_cache)
            if cache_size == 0:
                return
            if cache_size <= VOICE_STATE_CACHE_CLEANUP_THRESHOLD:
                number_to_remove = max(1, cache_size // 2)
            else:
                number_to_remove = cache_size - VOICE_STATE_CACHE_CLEANUP_THRESHOLD
            self._evict_oldest_voice_states_locked(number_to_remove)
            force_garbage_collection()

    def get_voice_state_for_preset(self, voice_name):
        """Return the (cached) conditioning state for a preset voice.

        Unknown voice names fall back to DEFAULT_VOICE.

        Raises:
            RuntimeError: if no model is currently loaded.
        """
        validated_voice = voice_name if voice_name in AVAILABLE_VOICES else DEFAULT_VOICE
        # Fast path: cached entry. The lock is released before eviction below,
        # because eviction re-acquires it and threading.Lock is not reentrant.
        with self.voice_state_cache_lock:
            if validated_voice in self.voice_state_cache:
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                return self.voice_state_cache[validated_voice]
        if is_memory_usage_approaching_limit():
            self.evict_least_recently_used_voice_states()
        if len(self.voice_state_cache) >= VOICE_STATE_CACHE_MAXIMUM_SIZE:
            self.evict_least_recently_used_voice_states()
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            # Re-check under the cache lock: another request may have computed
            # this voice's state while we were between locks.
            with self.voice_state_cache_lock:
                if validated_voice in self.voice_state_cache:
                    self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                    return self.voice_state_cache[validated_voice]
            computed_voice_state = self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=validated_voice,
                truncate=False
            )
            with self.voice_state_cache_lock:
                self.voice_state_cache[validated_voice] = computed_voice_state
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                return self.voice_state_cache[validated_voice]

    def get_voice_state_for_clone(self, audio_file_path, prepared_audio_path=None):
        """Compute a conditioning state from a user-supplied audio file.

        ``prepared_audio_path`` (e.g. a format-converted copy) takes
        precedence over ``audio_file_path`` when provided. Clone states are
        intentionally not cached.

        Raises:
            RuntimeError: if no model is currently loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            audio_path_to_use = prepared_audio_path if prepared_audio_path is not None else audio_file_path
            return self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=audio_path_to_use,
                truncate=False
            )

    def generate_audio(self, text_content, voice_state, frames_after_eos, enable_custom_frames):
        """Synthesize audio for ``text_content`` using ``voice_state``.

        ``frames_after_eos`` is honored only when ``enable_custom_frames`` is
        truthy; a None value now falls back to the model default instead of
        raising TypeError via ``int(None)``.

        Raises:
            RuntimeError: if no model is currently loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            processed_frames = (
                int(frames_after_eos)
                if enable_custom_frames and frames_after_eos is not None
                else None
            )
            generated_audio = self.loaded_model.generate_audio(
                model_state=voice_state,
                text_to_generate=text_content,
                frames_after_eos=processed_frames,
                copy_state=True
            )
            force_garbage_collection()
            return generated_audio

    def save_audio_to_file(self, audio_tensor):
        """Write ``audio_tensor`` to a temporary WAV file and return its path.

        The file is registered for background cleanup.

        Raises:
            RuntimeError: if no model is loaded (the sample rate comes from it).
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Cannot save audio.")
            audio_sample_rate = self.loaded_model.sample_rate
            audio_numpy_data = audio_tensor.numpy()
            output_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            # Close the handle immediately: the previous code leaked one file
            # descriptor per save, and on Windows the open handle would block
            # scipy from reopening the file by name.
            output_file.close()
            scipy.io.wavfile.write(output_file.name, audio_sample_rate, audio_numpy_data)
            with temporary_files_lock:
                temporary_files_registry[output_file.name] = time.time()
            trigger_background_cleanup_check()
            return output_file.name
# Module-level singleton used by the rest of the application; it is also
# registered with the shared core.state module so other components can look
# it up there instead of importing this module directly.
text_to_speech_manager = TextToSpeechManager()
set_text_to_speech_manager(text_to_speech_manager)