#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import time
import tempfile
import threading
import torch
import scipy.io.wavfile
from pocket_tts import TTSModel
from config import (
AVAILABLE_VOICES,
DEFAULT_VOICE,
DEFAULT_MODEL_VARIANT,
DEFAULT_TEMPERATURE,
DEFAULT_LSD_DECODE_STEPS,
DEFAULT_EOS_THRESHOLD,
VOICE_STATE_CACHE_MAXIMUM_SIZE,
VOICE_STATE_CACHE_CLEANUP_THRESHOLD
)
from ..core.state import (
temporary_files_registry,
temporary_files_lock,
set_text_to_speech_manager
)
from ..core.memory import (
force_garbage_collection,
memory_cleanup,
perform_memory_cleanup,
trigger_background_cleanup_check,
is_memory_usage_approaching_limit
)
class TextToSpeechManager:
    """Owns the loaded pocket_tts model and an LRU cache of voice states.

    Concurrency model: ``model_lock`` guards the model instance and its
    configuration; ``voice_state_cache_lock`` guards the voice-state cache
    and its access-timestamp map. Both are non-reentrant ``threading.Lock``
    objects, so no method may call a lock-acquiring sibling while already
    holding the same lock.
    """

    def __init__(self):
        # The TTSModel instance, or None while no model is loaded.
        self.loaded_model = None
        # Settings the current model was loaded with; compared against the
        # requested settings to decide whether a reload is required.
        self.current_configuration = {}
        # voice name -> precomputed model state for that voice prompt.
        self.voice_state_cache = {}
        # voice name -> last-access time.time(); drives LRU eviction.
        self.voice_state_cache_access_timestamps = {}
        self.voice_state_cache_lock = threading.Lock()
        self.model_lock = threading.Lock()

    def is_model_loaded(self):
        """Return True when a model is currently loaded."""
        with self.model_lock:
            return self.loaded_model is not None

    def unload_model_completely(self):
        """Drop the model and every cached voice state, then reclaim memory."""
        with self.model_lock:
            self.clear_voice_state_cache_completely()
            if self.loaded_model is not None:
                del self.loaded_model
            self.loaded_model = None
            self.current_configuration = {}
            memory_cleanup()

    def load_or_get_model(
        self,
        model_variant,
        temperature,
        lsd_decode_steps,
        noise_clamp,
        eos_threshold
    ):
        """Return a model matching the requested settings, (re)loading if needed.

        Any argument that is None falls back to the configured default.
        A reload invalidates every cached voice state, since states are tied
        to the model instance that produced them.
        """
        perform_memory_cleanup()
        processed_variant = str(model_variant or DEFAULT_MODEL_VARIANT).strip()
        processed_temperature = (
            float(temperature) if temperature is not None else DEFAULT_TEMPERATURE
        )
        processed_lsd_steps = (
            int(lsd_decode_steps) if lsd_decode_steps is not None
            else DEFAULT_LSD_DECODE_STEPS
        )
        # A missing, zero, or negative clamp is treated as "no clamping".
        processed_noise_clamp = (
            float(noise_clamp) if noise_clamp and float(noise_clamp) > 0 else None
        )
        processed_eos_threshold = (
            float(eos_threshold) if eos_threshold is not None else DEFAULT_EOS_THRESHOLD
        )
        requested_configuration = {
            "variant": processed_variant,
            "temp": processed_temperature,
            "lsd_decode_steps": processed_lsd_steps,
            "noise_clamp": processed_noise_clamp,
            "eos_threshold": processed_eos_threshold,
        }
        with self.model_lock:
            if self.loaded_model is None or self.current_configuration != requested_configuration:
                if self.loaded_model is not None:
                    # Cached voice states belong to the old model instance.
                    self.clear_voice_state_cache_completely()
                    del self.loaded_model
                    self.loaded_model = None
                    memory_cleanup()
                self.loaded_model = TTSModel.load_model(**requested_configuration)
                self.current_configuration = requested_configuration
                # BUGFIX: clear both maps under the cache lock instead of
                # rebinding the cache dict unlocked, which left stale entries
                # in voice_state_cache_access_timestamps after a reload.
                with self.voice_state_cache_lock:
                    self.voice_state_cache.clear()
                    self.voice_state_cache_access_timestamps.clear()
            return self.loaded_model

    def clear_voice_state_cache_completely(self):
        """Drop every cached voice state, then force a GC pass."""
        with self.voice_state_cache_lock:
            self.voice_state_cache.clear()
            self.voice_state_cache_access_timestamps.clear()
        # GC outside the lock: it can be slow and needs no cache access.
        force_garbage_collection()

    def evict_least_recently_used_voice_states(self):
        """Evict least-recently-used voice states.

        At or below VOICE_STATE_CACHE_CLEANUP_THRESHOLD entries (the
        memory-pressure case), evict half the cache (at least one entry);
        above it, evict down to the threshold.
        """
        with self.voice_state_cache_lock:
            cache_size = len(self.voice_state_cache)
            if cache_size == 0:
                return
            if cache_size <= VOICE_STATE_CACHE_CLEANUP_THRESHOLD:
                number_of_entries_to_remove = max(1, cache_size // 2)
            else:
                number_of_entries_to_remove = (
                    cache_size - VOICE_STATE_CACHE_CLEANUP_THRESHOLD
                )
            # Oldest access time first = least recently used first.
            least_recently_used_order = sorted(
                self.voice_state_cache_access_timestamps,
                key=self.voice_state_cache_access_timestamps.get
            )
            for voice_name in least_recently_used_order[:number_of_entries_to_remove]:
                self.voice_state_cache.pop(voice_name, None)
                self.voice_state_cache_access_timestamps.pop(voice_name, None)
        force_garbage_collection()

    def get_voice_state_for_preset(self, voice_name):
        """Return (computing and caching if needed) the state for a preset voice.

        Unknown voice names fall back to DEFAULT_VOICE.

        Raises:
            RuntimeError: if no model is loaded.
        """
        validated_voice = voice_name if voice_name in AVAILABLE_VOICES else DEFAULT_VOICE
        with self.voice_state_cache_lock:
            if validated_voice in self.voice_state_cache:
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                return self.voice_state_cache[validated_voice]
            # BUGFIX: the original read len(self.voice_state_cache) without
            # holding the cache lock; snapshot it here instead.
            cache_is_full = len(self.voice_state_cache) >= VOICE_STATE_CACHE_MAXIMUM_SIZE
        # Evict outside the cache lock: the lock is non-reentrant and the
        # eviction helper acquires it itself.
        if is_memory_usage_approaching_limit() or cache_is_full:
            self.evict_least_recently_used_voice_states()
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            with self.voice_state_cache_lock:
                # Double-check: another thread may have cached it meanwhile.
                if validated_voice in self.voice_state_cache:
                    self.voice_state_cache_access_timestamps[validated_voice] = time.time()
                    return self.voice_state_cache[validated_voice]
            computed_voice_state = self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=validated_voice,
                truncate=False
            )
            with self.voice_state_cache_lock:
                self.voice_state_cache[validated_voice] = computed_voice_state
                self.voice_state_cache_access_timestamps[validated_voice] = time.time()
            # BUGFIX: return the computed state directly; re-reading the cache
            # could KeyError if a concurrent eviction removed the entry.
            return computed_voice_state

    def get_voice_state_for_clone(self, audio_file_path, prepared_audio_path=None):
        """Compute (uncached) a voice state from a user-supplied audio prompt.

        Prefers prepared_audio_path when given, otherwise audio_file_path.

        Raises:
            RuntimeError: if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            audio_path_to_use = (
                prepared_audio_path if prepared_audio_path is not None else audio_file_path
            )
            return self.loaded_model.get_state_for_audio_prompt(
                audio_conditioning=audio_path_to_use,
                truncate=False
            )

    def generate_audio(self, text_content, voice_state, frames_after_eos, enable_custom_frames):
        """Synthesize audio for text_content using a precomputed voice state.

        frames_after_eos is honored only when enable_custom_frames is truthy.

        Raises:
            RuntimeError: if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Please try again.")
            processed_frames = int(frames_after_eos) if enable_custom_frames else None
            generated_audio = self.loaded_model.generate_audio(
                model_state=voice_state,
                text_to_generate=text_content,
                frames_after_eos=processed_frames,
                copy_state=True
            )
            force_garbage_collection()
            return generated_audio

    def save_audio_to_file(self, audio_tensor):
        """Write audio_tensor to a temporary .wav file and return its path.

        The file is registered in temporary_files_registry so the background
        cleanup task can delete it later.

        Raises:
            RuntimeError: if no model is loaded.
        """
        with self.model_lock:
            if self.loaded_model is None:
                raise RuntimeError("TTS model is not loaded. Cannot save audio.")
            audio_sample_rate = self.loaded_model.sample_rate
        # NOTE(review): assumes audio_tensor is a CPU tensor — .numpy() raises
        # on CUDA tensors; confirm against callers.
        audio_numpy_data = audio_tensor.numpy()
        # BUGFIX: the original never closed the NamedTemporaryFile handle,
        # leaking a file descriptor per call. Reserve the name, close the
        # handle, then let scipy write to the path (delete=False keeps it).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as output_file:
            output_path = output_file.name
        scipy.io.wavfile.write(output_path, audio_sample_rate, audio_numpy_data)
        with temporary_files_lock:
            temporary_files_registry[output_path] = time.time()
        trigger_background_cleanup_check()
        return output_path
# Module-level singleton shared by the application; registered with the core
# state module so other components can retrieve it without importing this
# module directly.
text_to_speech_manager = TextToSpeechManager()
set_text_to_speech_manager(text_to_speech_manager)