# SPDX-FileCopyrightText: Hadad
# SPDX-License-Identifier: Apache-2.0

import time
import tempfile
import threading

import torch
import scipy.io.wavfile
from pocket_tts import TTSModel

from config import (
    AVAILABLE_VOICES,
    DEFAULT_VOICE,
    DEFAULT_MODEL_VARIANT,
    DEFAULT_TEMPERATURE,
    DEFAULT_LSD_DECODE_STEPS,
    DEFAULT_EOS_THRESHOLD,
    VOICE_STATE_CACHE_MAXIMUM_SIZE,
    VOICE_STATE_CACHE_CLEANUP_THRESHOLD,
)
from ..core.state import (
    temporary_files_registry,
    temporary_files_lock,
    set_text_to_speech_manager,
)
from ..core.memory import (
    force_garbage_collection,
    memory_cleanup,
    perform_memory_cleanup,
    trigger_background_cleanup_check,
    is_memory_usage_approaching_limit,
)


class TextToSpeechManager:
    """Owns the single loaded TTS model plus an LRU cache of voice states.

    Thread-safety: ``model_lock`` guards ``loaded_model`` and
    ``current_configuration``; ``voice_state_cache_lock`` guards
    ``voice_state_cache`` and its access-timestamp bookkeeping.
    Lock-acquisition order is always model_lock -> voice_state_cache_lock
    (never the reverse), so the two locks cannot deadlock each other.
    """

    def __init__(self):
        # The currently loaded TTSModel instance, or None when unloaded.
        self.loaded_model = None
        # The exact kwargs the current model was loaded with, used to decide
        # whether a reload is needed.
        self.current_configuration = {}
        # voice name -> precomputed voice state (model-specific; cleared on
        # every model reload).
        self.voice_state_cache = {}
        # voice name -> time.time() of last access, for LRU eviction.
        self.voice_state_cache_access_timestamps = {}
        self.voice_state_cache_lock = threading.Lock()
        self.model_lock = threading.Lock()

    def is_model_loaded(self):
        """Return True when a TTS model is currently resident in memory."""
        with self.model_lock:
            return self.loaded_model is not None

    def unload_model_completely(self):
        """Drop the model, its configuration, and every cached voice state."""
        with self.model_lock:
            self.clear_voice_state_cache_completely()
            if self.loaded_model is not None:
                del self.loaded_model
                self.loaded_model = None
            self.current_configuration = {}
        memory_cleanup()

    def load_or_get_model(
        self,
        model_variant,
        temperature,
        lsd_decode_steps,
        noise_clamp,
        eos_threshold,
    ):
        """Return a model matching the requested settings, reloading if needed.

        ``None`` values fall back to the configured defaults; a falsy or
        non-positive ``noise_clamp`` is treated as "no clamp".
        """
        perform_memory_cleanup()
        processed_variant = str(model_variant or DEFAULT_MODEL_VARIANT).strip()
        processed_temperature = (
            float(temperature) if temperature is not None else DEFAULT_TEMPERATURE
        )
        processed_lsd_steps = (
            int(lsd_decode_steps)
            if lsd_decode_steps is not None
            else DEFAULT_LSD_DECODE_STEPS
        )
        processed_noise_clamp = (
            float(noise_clamp) if noise_clamp and float(noise_clamp) > 0 else None
        )
        processed_eos_threshold = (
            float(eos_threshold) if eos_threshold is not None else DEFAULT_EOS_THRESHOLD
        )
        requested_configuration = {
            "variant": processed_variant,
            "temp": processed_temperature,
            "lsd_decode_steps": processed_lsd_steps,
            "noise_clamp": processed_noise_clamp,
            "eos_threshold": processed_eos_threshold,
        }
        with self.model_lock:
            if (
                self.loaded_model is None
                or self.current_configuration != requested_configuration
            ):
                if self.loaded_model is not None:
                    # Cached voice states belong to the old model; drop them
                    # before releasing the model itself.
                    self.clear_voice_state_cache_completely()
                    del self.loaded_model
                    self.loaded_model = None
                    memory_cleanup()
                self.loaded_model = TTSModel.load_model(**requested_configuration)
                self.current_configuration = requested_configuration
                # FIX: reset the cache AND its timestamps under the cache
                # lock.  The original reassigned only the cache, unguarded,
                # leaving stale timestamp entries behind for LRU eviction
                # to sort over.
                with self.voice_state_cache_lock:
                    self.voice_state_cache = {}
                    self.voice_state_cache_access_timestamps = {}
            return self.loaded_model

    def clear_voice_state_cache_completely(self):
        """Drop every cached voice state and all timestamps, then collect."""
        with self.voice_state_cache_lock:
            for voice_name in list(self.voice_state_cache.keys()):
                voice_state_tensor = self.voice_state_cache.pop(voice_name, None)
                if voice_state_tensor is not None:
                    del voice_state_tensor
            self.voice_state_cache.clear()
            self.voice_state_cache_access_timestamps.clear()
        force_garbage_collection()

    def _evict_oldest_locked(self, number_of_entries_to_remove):
        """Remove the N least-recently-used cache entries, then collect.

        Caller MUST already hold ``voice_state_cache_lock``.
        """
        sorted_voice_names_by_access_time = sorted(
            self.voice_state_cache_access_timestamps.keys(),
            key=lambda voice_name: self.voice_state_cache_access_timestamps[voice_name],
        )
        for voice_name_to_remove in sorted_voice_names_by_access_time[
            :number_of_entries_to_remove
        ]:
            voice_state_tensor = self.voice_state_cache.pop(voice_name_to_remove, None)
            self.voice_state_cache_access_timestamps.pop(voice_name_to_remove, None)
            if voice_state_tensor is not None:
                del voice_state_tensor
        force_garbage_collection()

    def evict_least_recently_used_voice_states(self):
        """Shrink the voice-state cache, oldest entries first.

        At or below the cleanup threshold (the memory-pressure path) half of
        the entries — at least one — are dropped; above the threshold the
        cache is trimmed back down to the threshold.
        """
        with self.voice_state_cache_lock:
            cache_size = len(self.voice_state_cache)
            if cache_size <= VOICE_STATE_CACHE_CLEANUP_THRESHOLD:
                if cache_size > 0:
                    self._evict_oldest_locked(max(1, cache_size // 2))
                return
            self._evict_oldest_locked(
                cache_size - VOICE_STATE_CACHE_CLEANUP_THRESHOLD
            )

    def get_voice_state_for_preset(self, voice_name):
        """Return the voice state for a preset voice, computing it on a miss.

        Unknown voice names silently fall back to ``DEFAULT_VOICE``.

        Raises:
            RuntimeError: if no model is loaded when a state must be computed.
        """
        validated_voice = (
            voice_name if voice_name in AVAILABLE_VOICES else DEFAULT_VOICE
        )
        with self.voice_state_cache_lock:
            if validated_voice in self.voice_state_cache:
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                return self.voice_state_cache[validated_voice]
        # Eviction acquires the cache lock itself, so it must run OUTSIDE the
        # block above — threading.Lock is not reentrant.
        if is_memory_usage_approaching_limit():
            self.evict_least_recently_used_voice_states()
        if len(self.voice_state_cache) >= VOICE_STATE_CACHE_MAXIMUM_SIZE:
            self.evict_least_recently_used_voice_states()
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            # FIX: re-check and read under the cache lock.  The original
            # inspected the cache here without holding it, racing with
            # concurrent insertions/evictions.
            with self.voice_state_cache_lock:
                cached_state = self.voice_state_cache.get(validated_voice)
            if cached_state is None:
                computed_voice_state = self.loaded_model.get_state_for_audio_prompt(
                    audio_conditioning=validated_voice,
                    truncate=False,
                )
                with self.voice_state_cache_lock:
                    self.voice_state_cache[validated_voice] = computed_voice_state
                    self.voice_state_cache_access_timestamps[validated_voice] = (
                        time.time()
                    )
            with self.voice_state_cache_lock:
                return self.voice_state_cache[validated_voice]

    def get_voice_state_for_clone(self, audio_file_path, prepared_audio_path=None):
        """Compute (uncached) a voice state from a reference audio file.

        ``prepared_audio_path`` takes precedence over ``audio_file_path``
        when provided.

        Raises:
            RuntimeError: if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            audio_path_to_use = (
                prepared_audio_path
                if prepared_audio_path is not None
                else audio_file_path
            )
            return self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=audio_path_to_use,
                truncate=False,
            )

    def generate_audio(
        self, text_content, voice_state, frames_after_eos, enable_custom_frames
    ):
        """Synthesize audio for ``text_content`` using ``voice_state``.

        ``frames_after_eos`` is honored only when ``enable_custom_frames``
        is truthy.

        Raises:
            RuntimeError: if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            # FIX: also require a non-None value — the original called
            # int(None) and raised TypeError when custom frames were enabled
            # without a frame count.
            processed_frames = (
                int(frames_after_eos)
                if enable_custom_frames and frames_after_eos is not None
                else None
            )
            generated_audio = self.loaded_model.generate_audio(
                model_state=voice_state,
                text_to_generate=text_content,
                frames_after_eos=processed_frames,
                copy_state=True,
            )
            force_garbage_collection()
            return generated_audio

    def save_audio_to_file(self, audio_tensor):
        """Write ``audio_tensor`` to a temporary .wav file and return its path.

        The path is registered for deferred cleanup by the background
        temporary-file reaper.

        Raises:
            RuntimeError: if no model is loaded (sample rate unavailable).
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Cannot save audio.")
            audio_sample_rate = self.loaded_model.sample_rate
            audio_numpy_data = audio_tensor.numpy()
            # FIX: close the handle immediately — the original leaked one
            # open file descriptor per call.  scipy writes by path, so the
            # handle is only needed to reserve a unique name.
            output_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            output_path = output_file.name
            output_file.close()
            scipy.io.wavfile.write(output_path, audio_sample_rate, audio_numpy_data)
            with temporary_files_lock:
                temporary_files_registry[output_path] = time.time()
            trigger_background_cleanup_check()
            return output_path


# Module-level singleton, registered with the shared application state so
# other modules can retrieve it.
text_to_speech_manager = TextToSpeechManager()
set_text_to_speech_manager(text_to_speech_manager)