Spaces:
Sleeping
Sleeping
| # | |
| # SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org> | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| import time | |
| import tempfile | |
| import threading | |
| import torch | |
| import scipy.io.wavfile | |
| from pocket_tts import TTSModel | |
| from config import ( | |
| AVAILABLE_VOICES, | |
| DEFAULT_VOICE, | |
| DEFAULT_MODEL_VARIANT, | |
| DEFAULT_TEMPERATURE, | |
| DEFAULT_LSD_DECODE_STEPS, | |
| DEFAULT_EOS_THRESHOLD, | |
| VOICE_STATE_CACHE_MAXIMUM_SIZE, | |
| VOICE_STATE_CACHE_CLEANUP_THRESHOLD | |
| ) | |
| from ..core.state import ( | |
| temporary_files_registry, | |
| temporary_files_lock, | |
| set_text_to_speech_manager | |
| ) | |
| from ..core.memory import ( | |
| force_garbage_collection, | |
| memory_cleanup, | |
| perform_memory_cleanup, | |
| trigger_background_cleanup_check, | |
| is_memory_usage_approaching_limit | |
| ) | |
class TextToSpeechManager:
    """Thread-safe owner of a single loaded TTS model plus an LRU cache of
    per-voice conditioning states.

    Locking convention: ``model_lock`` guards ``loaded_model`` and
    ``current_configuration``; ``voice_state_cache_lock`` guards
    ``voice_state_cache`` and ``voice_state_cache_access_timestamps``.
    When both locks are needed, ``model_lock`` is acquired first.
    """

    def __init__(self):
        # Lazily created TTSModel instance; None until load_or_get_model runs.
        self.loaded_model = None
        # Configuration dict the currently loaded model was built with.
        self.current_configuration = {}
        # voice name -> conditioning state, with last-access times for LRU eviction.
        self.voice_state_cache = {}
        self.voice_state_cache_access_timestamps = {}
        self.voice_state_cache_lock = threading.Lock()
        self.model_lock = threading.Lock()

    def is_model_loaded(self):
        """Return True if a TTS model is currently resident in memory."""
        with self.model_lock:
            return self.loaded_model is not None

    def unload_model_completely(self):
        """Drop the model and all cached voice states, then reclaim memory."""
        with self.model_lock:
            self.clear_voice_state_cache_completely()
            if self.loaded_model is not None:
                del self.loaded_model
                self.loaded_model = None
            self.current_configuration = {}
            memory_cleanup()

    def load_or_get_model(
        self,
        model_variant,
        temperature,
        lsd_decode_steps,
        noise_clamp,
        eos_threshold
    ):
        """Return a model matching the requested configuration, (re)loading it
        only when no model is loaded or the configuration changed.

        ``None`` arguments fall back to the config-module defaults; a
        non-positive ``noise_clamp`` is normalized to ``None``.
        """
        perform_memory_cleanup()
        processed_variant = str(model_variant or DEFAULT_MODEL_VARIANT).strip()
        processed_temperature = float(temperature) if temperature is not None else DEFAULT_TEMPERATURE
        processed_lsd_steps = int(lsd_decode_steps) if lsd_decode_steps is not None else DEFAULT_LSD_DECODE_STEPS
        processed_noise_clamp = float(noise_clamp) if noise_clamp and float(noise_clamp) > 0 else None
        processed_eos_threshold = float(eos_threshold) if eos_threshold is not None else DEFAULT_EOS_THRESHOLD
        requested_configuration = {
            "variant": processed_variant,
            "temp": processed_temperature,
            "lsd_decode_steps": processed_lsd_steps,
            "noise_clamp": processed_noise_clamp,
            "eos_threshold": processed_eos_threshold
        }
        with self.model_lock:
            if self.loaded_model is None or self.current_configuration != requested_configuration:
                if self.loaded_model is not None:
                    # Configuration changed: release the old model and its
                    # cached voice states before loading the replacement.
                    self.clear_voice_state_cache_completely()
                    del self.loaded_model
                    self.loaded_model = None
                    memory_cleanup()
                self.loaded_model = TTSModel.load_model(**requested_configuration)
                self.current_configuration = requested_configuration
                # BUGFIX: previously the cache dict was replaced without the
                # cache lock and the access-timestamp dict was never cleared,
                # leaving stale LRU entries behind. Clear both, under the lock.
                with self.voice_state_cache_lock:
                    self.voice_state_cache.clear()
                    self.voice_state_cache_access_timestamps.clear()
            return self.loaded_model

    def clear_voice_state_cache_completely(self):
        """Remove every cached voice state and force garbage collection."""
        with self.voice_state_cache_lock:
            self.voice_state_cache.clear()
            self.voice_state_cache_access_timestamps.clear()
            force_garbage_collection()

    def evict_least_recently_used_voice_states(self):
        """Evict cached voice states in least-recently-used order.

        At or below the cleanup threshold (the memory-pressure call path)
        roughly half of the entries are dropped; above it, entries are
        dropped until the cache is back at the threshold.
        """
        with self.voice_state_cache_lock:
            cache_size = len(self.voice_state_cache)
            if cache_size == 0:
                return
            if cache_size <= VOICE_STATE_CACHE_CLEANUP_THRESHOLD:
                number_of_entries_to_remove = max(1, cache_size // 2)
            else:
                number_of_entries_to_remove = cache_size - VOICE_STATE_CACHE_CLEANUP_THRESHOLD
            least_recently_used_first = sorted(
                self.voice_state_cache_access_timestamps,
                key=self.voice_state_cache_access_timestamps.get
            )
            # BUGFIX: slice instead of indexing by a count derived from the
            # cache dict, so a size mismatch between the two dicts can no
            # longer raise IndexError.
            for voice_name_to_remove in least_recently_used_first[:number_of_entries_to_remove]:
                self.voice_state_cache.pop(voice_name_to_remove, None)
                self.voice_state_cache_access_timestamps.pop(voice_name_to_remove, None)
            force_garbage_collection()

    def get_voice_state_for_preset(self, voice_name):
        """Return the (cached) conditioning state for a preset voice.

        Unknown voice names fall back to DEFAULT_VOICE. Raises RuntimeError
        if no model is loaded.
        """
        validated_voice = voice_name if voice_name in AVAILABLE_VOICES else DEFAULT_VOICE
        # Fast path: cache hit — refresh the LRU timestamp and return.
        with self.voice_state_cache_lock:
            if validated_voice in self.voice_state_cache:
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                return self.voice_state_cache[validated_voice]
        # Relieve cache pressure before computing a new state.
        if is_memory_usage_approaching_limit():
            self.evict_least_recently_used_voice_states()
        if len(self.voice_state_cache) >= VOICE_STATE_CACHE_MAXIMUM_SIZE:
            self.evict_least_recently_used_voice_states()
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            # Re-check under the cache lock in case another thread computed
            # this voice's state while we were waiting for the model lock.
            with self.voice_state_cache_lock:
                voice_state = self.voice_state_cache.get(validated_voice)
            if voice_state is None:
                voice_state = self.loaded_model.get_state_for_audio_prompt(
                    audio_conditioning=validated_voice,
                    truncate=False
                )
            with self.voice_state_cache_lock:
                self.voice_state_cache[validated_voice] = voice_state
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
            # BUGFIX: return the state we already hold instead of re-reading
            # the cache without the lock, which could KeyError if a concurrent
            # eviction removed the entry between insert and return.
            return voice_state

    def get_voice_state_for_clone(self, audio_file_path, prepared_audio_path=None):
        """Compute a conditioning state from a user-supplied audio prompt.

        ``prepared_audio_path`` (pre-processed audio) takes precedence over
        the raw ``audio_file_path`` when provided. The result is NOT cached.
        Raises RuntimeError if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            audio_path_to_use = prepared_audio_path if prepared_audio_path is not None else audio_file_path
            return self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=audio_path_to_use,
                truncate=False
            )

    def generate_audio(self, text_content, voice_state, frames_after_eos, enable_custom_frames):
        """Synthesize ``text_content`` using ``voice_state`` and return the
        generated audio tensor.

        ``frames_after_eos`` is only honored when ``enable_custom_frames`` is
        truthy; otherwise the model default is used. ``copy_state=True`` keeps
        the (possibly cached) voice state reusable across calls.
        Raises RuntimeError if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            processed_frames = int(frames_after_eos) if enable_custom_frames else None
            generated_audio = self.loaded_model.generate_audio(
                model_state=voice_state,
                text_to_generate=text_content,
                frames_after_eos=processed_frames,
                copy_state=True
            )
            force_garbage_collection()
            return generated_audio

    def save_audio_to_file(self, audio_tensor):
        """Write ``audio_tensor`` to a temporary WAV file, register the path
        for delayed cleanup, and return the path.

        Raises RuntimeError if no model is loaded (the model supplies the
        sample rate).
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Cannot save audio.")
            audio_sample_rate = self.loaded_model.sample_rate
            audio_numpy_data = audio_tensor.numpy()
            output_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            # BUGFIX: close the handle before writing by name — avoids a file
            # descriptor leak and is required on Windows, where an open file
            # cannot be reopened by path.
            output_file.close()
            scipy.io.wavfile.write(output_file.name, audio_sample_rate, audio_numpy_data)
            with temporary_files_lock:
                temporary_files_registry[output_file.name] = time.time()
            trigger_background_cleanup_check()
            return output_file.name
# Module-level singleton: created at import time and registered with the shared
# application state so other modules can fetch the manager via core.state
# instead of importing this module directly.
text_to_speech_manager = TextToSpeechManager()
set_text_to_speech_manager(text_to_speech_manager)