pocket-tts/src/core/memory.py
hadadrjt's picture
Pocket TTS: Let's take this seriously.
5da0109
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import os
import gc
import time
import atexit
import threading
import torch
from config import (
TEMPORARY_FILE_LIFETIME_SECONDS,
BACKGROUND_CLEANUP_INTERVAL,
MEMORY_WARNING_THRESHOLD,
MEMORY_CRITICAL_THRESHOLD,
MEMORY_CHECK_INTERVAL,
MEMORY_IDLE_TARGET,
MAXIMUM_MEMORY_USAGE
)
from ..core.state import (
temporary_files_registry,
temporary_files_lock,
memory_enforcement_lock,
background_cleanup_thread,
background_cleanup_stop_event,
background_cleanup_trigger_event,
check_if_generation_is_currently_active,
get_text_to_speech_manager
)
def get_current_memory_usage():
    """Return the resident memory of this process in bytes, or 0 if unknown.

    Sources are tried in order, and the first one that yields a value wins:

    1. ``VmRSS`` from /proc/self/status (reported in kB).
    2. The resident page count from /proc/self/statm times the page size.
    3. ``resource.getrusage(...).ru_maxrss``. It is treated as bytes on
       Darwin and as kilobytes elsewhere.

    NOTE: ``ru_maxrss`` is the *peak* resident size, not the current one.
    On platforms that reach the last fallback, the value never decreases
    after memory is released.
    """
    # Linux: VmRSS line in /proc/self/status.
    try:
        with open('/proc/self/status', 'r') as status_handle:
            for status_line in status_handle:
                if not status_line.startswith('VmRSS:'):
                    continue
                return int(status_line.split()[1]) * 1024
    except Exception:
        pass

    # Linux: second field of /proc/self/statm is resident pages.
    try:
        with open('/proc/self/statm', 'r') as statm_handle:
            statm_fields = statm_handle.read().split()
            return int(statm_fields[1]) * os.sysconf('SC_PAGE_SIZE')
    except Exception:
        pass

    # Portable (non-Windows) fallback: peak RSS from getrusage.
    try:
        import resource
        import platform
        peak_resident = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        unit_multiplier = 1 if platform.system() == "Darwin" else 1024
        return peak_resident * unit_multiplier
    except Exception:
        pass

    return 0
def is_memory_usage_within_limit():
    """Return True while resident memory is strictly below MAXIMUM_MEMORY_USAGE."""
    return get_current_memory_usage() < MAXIMUM_MEMORY_USAGE
def is_memory_usage_approaching_limit():
    """Return True once resident memory has reached MEMORY_WARNING_THRESHOLD."""
    return get_current_memory_usage() >= MEMORY_WARNING_THRESHOLD
def is_memory_usage_critical():
    """Return True once resident memory has reached MEMORY_CRITICAL_THRESHOLD."""
    return get_current_memory_usage() >= MEMORY_CRITICAL_THRESHOLD
def is_memory_above_idle_target():
    """Return True while resident memory is strictly above MEMORY_IDLE_TARGET."""
    return get_current_memory_usage() > MEMORY_IDLE_TARGET
def force_garbage_collection():
    """Collect every GC generation, then release cached CUDA memory if possible.

    Generations are collected youngest-first (0, 1, then 2). When CUDA is
    available, the PyTorch caching allocator is asked to release its
    unoccupied cached memory, and the device is then synchronized.
    """
    for generation in range(3):
        gc.collect(generation)
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def memory_cleanup():
    """Collect garbage, ask glibc to hand free heap back to the OS, collect again.

    ``malloc_trim`` is looked up in "libc.so.6" on each call. If the library
    cannot be loaded or lacks the symbol (non-glibc systems), that step is
    silently skipped.
    """
    force_garbage_collection()
    try:
        import ctypes
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        pass
    force_garbage_collection()
def perform_memory_cleanup():
    """Run a light cleanup pass.

    The pass collects garbage, evicts least-recently-used voice states (when
    a TTS manager exists), and then runs memory_cleanup().
    """
    force_garbage_collection()
    manager = get_text_to_speech_manager()
    if manager is not None:
        manager.evict_least_recently_used_voice_states()
    memory_cleanup()
def cleanup_expired_temporary_files():
    """Delete registered temporary files older than TEMPORARY_FILE_LIFETIME_SECONDS.

    Age is measured against a single timestamp taken on entry. A file's
    registry entry is dropped once the file is gone, either because it was
    deleted here or because it was already missing. If deletion fails for any
    other reason (for example a permission error), the entry is kept so a
    later cycle can retry.
    """
    current_timestamp = time.time()
    with temporary_files_lock:
        expired_files = [
            file_path
            for file_path, creation_timestamp in temporary_files_registry.items()
            if current_timestamp - creation_timestamp > TEMPORARY_FILE_LIFETIME_SECONDS
        ]
        for file_path in expired_files:
            try:
                # Remove directly rather than exists()-then-remove(). That pair
                # races with anything else deleting the file in between, which
                # would make remove() fail and leave a stale registry entry.
                os.remove(file_path)
            except (FileNotFoundError, ValueError):
                # Nothing on disk to delete: already removed, or the path
                # cannot name a file at all. Just forget the entry.
                pass
            except Exception:
                # Best effort: keep the entry and retry on a later cycle.
                continue
            temporary_files_registry.pop(file_path, None)
def cleanup_all_temporary_files_immediately():
    """Delete every registered temporary file, regardless of its age.

    A registry entry is dropped once its file is gone (deleted here or
    already missing). If deletion fails for any other reason, the entry is
    kept so a later cleanup can retry.
    """
    with temporary_files_lock:
        for file_path in list(temporary_files_registry):
            try:
                # EAFP: avoids the exists()/remove() race, in which a file
                # deleted concurrently would keep its registry entry.
                os.remove(file_path)
            except (FileNotFoundError, ValueError):
                # Already gone, or the path cannot name a file: forget it.
                pass
            except Exception:
                # Best effort: keep the entry for a later retry.
                continue
            temporary_files_registry.pop(file_path, None)
def has_temporary_files_pending_cleanup():
    """Return True if any registered temporary file has outlived its lifetime."""
    with temporary_files_lock:
        now = time.time()
        return any(
            now - created_at > TEMPORARY_FILE_LIFETIME_SECONDS
            for created_at in temporary_files_registry.values()
        )
def has_any_temporary_files_registered():
    """Return True if the temporary-file registry holds at least one entry."""
    with temporary_files_lock:
        return bool(temporary_files_registry)
def calculate_time_until_next_file_expiration():
    """Return the number of seconds until the soonest registered temporary file expires.

    Returns None when no files are registered. Returns 0 when at least one
    file has already expired. Otherwise returns the remaining lifetime of the
    file that will expire first.
    """
    with temporary_files_lock:
        if not temporary_files_registry:
            return None
        now = time.time()
        soonest = min(
            TEMPORARY_FILE_LIFETIME_SECONDS - (now - created_at)
            for created_at in temporary_files_registry.values()
        )
    return 0 if soonest <= 0 else soonest
def enforce_memory_limit_if_exceeded():
    """Reclaim memory in escalating steps until usage is back under the thresholds.

    Each check below returns as soon as its threshold is met:

    * below MEMORY_WARNING_THRESHOLD on entry: return True;
    * after a full garbage collection, below MEMORY_WARNING_THRESHOLD:
      return True;
    * after evicting LRU voice states and memory_cleanup(), below
      MEMORY_CRITICAL_THRESHOLD: return True;
    * after clearing the voice-state cache, deleting all temporary files and
      memory_cleanup(), below MAXIMUM_MEMORY_USAGE: return True;
    * otherwise, if a generation was active on entry or is active now:
      return False without unloading;
    * otherwise, unload the model, run memory_cleanup(), and return whether
      usage is now below MAXIMUM_MEMORY_USAGE.

    Runs under memory_enforcement_lock.
    """
    with memory_enforcement_lock:
        generation_was_active = check_if_generation_is_currently_active()
        if get_current_memory_usage() < MEMORY_WARNING_THRESHOLD:
            return True

        force_garbage_collection()
        if get_current_memory_usage() < MEMORY_WARNING_THRESHOLD:
            return True

        tts_manager = get_text_to_speech_manager()
        if tts_manager is not None:
            tts_manager.evict_least_recently_used_voice_states()
        memory_cleanup()
        if get_current_memory_usage() < MEMORY_CRITICAL_THRESHOLD:
            return True

        if tts_manager is not None:
            tts_manager.clear_voice_state_cache_completely()
        cleanup_all_temporary_files_immediately()
        memory_cleanup()
        if get_current_memory_usage() < MAXIMUM_MEMORY_USAGE:
            return True

        # Usage is still >= MAXIMUM_MEMORY_USAGE here. Never unload the model
        # under a running generation. Re-check as well as using the entry
        # snapshot: the steps above can take a while, and a generation may
        # have started in the meantime.
        if generation_was_active or check_if_generation_is_currently_active():
            return False

        if tts_manager is not None:
            tts_manager.unload_model_completely()
        memory_cleanup()
        return get_current_memory_usage() < MAXIMUM_MEMORY_USAGE
def perform_idle_memory_reduction():
    """Shrink memory toward MEMORY_IDLE_TARGET while no generation is running.

    The reduction escalates one step at a time. It stops as soon as usage is
    at or below the idle target, or as soon as a generation becomes active:

    1. full garbage collection;
    2. evict least-recently-used voice states, then memory_cleanup();
    3. clear the voice-state cache completely, then memory_cleanup();
    4. unload the model completely, then memory_cleanup().

    Steps 2-4 only call the TTS manager when one exists; memory_cleanup()
    runs either way. The work is done under memory_enforcement_lock.
    """
    if check_if_generation_is_currently_active():
        return

    def should_stop_reducing():
        # Stop when the goal is reached, or when a generation has started
        # and must not be disturbed.
        if get_current_memory_usage() <= MEMORY_IDLE_TARGET:
            return True
        return check_if_generation_is_currently_active()

    with memory_enforcement_lock:
        if get_current_memory_usage() <= MEMORY_IDLE_TARGET:
            return
        force_garbage_collection()
        if should_stop_reducing():
            return
        manager = get_text_to_speech_manager()
        reduction_steps = (
            lambda m: m.evict_least_recently_used_voice_states(),
            lambda m: m.clear_voice_state_cache_completely(),
            lambda m: m.unload_model_completely(),
        )
        for step_number, reduce_step in enumerate(reduction_steps):
            # The first step runs unconditionally (the check just happened).
            if step_number > 0 and should_stop_reducing():
                return
            if manager is not None:
                reduce_step(manager)
            memory_cleanup()
def perform_background_cleanup_cycle():
    """Run the background cleanup thread's main loop until the stop event is set.

    Each iteration first decides how to wait:

    * Temporary files are registered: sleep on the stop event. The wait is 1 s
      if a file has already expired. Otherwise it lasts until the next file
      expires plus 1 s, capped by MEMORY_CHECK_INTERVAL and
      BACKGROUND_CLEANUP_INTERVAL.
    * No files, memory above the idle target, and no generation running:
      sleep on the stop event for MEMORY_CHECK_INTERVAL.
    * Otherwise: block on the trigger event for up to
      BACKGROUND_CLEANUP_INTERVAL. When triggered, re-evaluate immediately.
      On timeout, run idle memory reduction if no generation is running.

    After a timed sleep, the loop deletes expired temporary files. At most once
    per MEMORY_CHECK_INTERVAL, and only while no generation is running, it then
    either enforces the memory limit (critical usage) or performs idle
    reduction (usage above the idle target).
    """
    last_memory_check_timestamp = 0
    while not background_cleanup_stop_event.is_set():
        time_until_next_expiration = calculate_time_until_next_file_expiration()
        if time_until_next_expiration is not None:
            if time_until_next_expiration <= 0:
                wait_duration = 1
            else:
                wait_duration = min(
                    time_until_next_expiration + 1,
                    MEMORY_CHECK_INTERVAL,
                    BACKGROUND_CLEANUP_INTERVAL,
                )
        elif is_memory_above_idle_target() and not check_if_generation_is_currently_active():
            wait_duration = MEMORY_CHECK_INTERVAL
        else:
            triggered = background_cleanup_trigger_event.wait(timeout=BACKGROUND_CLEANUP_INTERVAL)
            # Consume the trigger only *after* waking. Clearing it before the
            # wait would silently drop a trigger fired between the checks above
            # and the wait, delaying the reaction by up to
            # BACKGROUND_CLEANUP_INTERVAL. Any state change signalled before
            # this clear is picked up by the re-evaluation at the loop top.
            background_cleanup_trigger_event.clear()
            if background_cleanup_stop_event.is_set():
                break
            if not triggered and not check_if_generation_is_currently_active():
                perform_idle_memory_reduction()
            continue

        background_cleanup_stop_event.wait(timeout=wait_duration)
        if background_cleanup_stop_event.is_set():
            break

        if has_temporary_files_pending_cleanup():
            cleanup_expired_temporary_files()

        current_timestamp = time.time()
        if current_timestamp - last_memory_check_timestamp >= MEMORY_CHECK_INTERVAL:
            if not check_if_generation_is_currently_active():
                if is_memory_usage_critical():
                    enforce_memory_limit_if_exceeded()
                elif is_memory_above_idle_target():
                    perform_idle_memory_reduction()
            last_memory_check_timestamp = current_timestamp
def trigger_background_cleanup_check():
    """Set the background cleanup trigger event.

    If the cleanup thread is currently blocked waiting on this event, it
    wakes up and re-evaluates its pending work.
    """
    trigger_event = background_cleanup_trigger_event
    trigger_event.set()
def start_background_cleanup_thread():
    """Start the background cleanup daemon thread unless one is already alive.

    Both cleanup events are cleared before a new thread starts, so a previous
    stop request does not end the new thread immediately.

    The thread handle is read and written as an attribute of the shared state
    module. The ``background_cleanup_thread`` name imported at the top of this
    module is a snapshot bound at import time and is never updated. The former
    ``global background_cleanup_thread`` declaration here was a no-op, because
    the function never assigned that name.
    """
    from ..core import state as global_state
    existing_thread = global_state.background_cleanup_thread
    if existing_thread is not None and existing_thread.is_alive():
        return
    background_cleanup_stop_event.clear()
    background_cleanup_trigger_event.clear()
    worker = threading.Thread(
        target=perform_background_cleanup_cycle,
        daemon=True,
        name="BackgroundCleanupThread",
    )
    global_state.background_cleanup_thread = worker
    worker.start()
def stop_background_cleanup_thread():
    """Ask the background cleanup thread to exit and wait up to 5 seconds for it.

    Both events are set. The stop event ends the thread's loop, and the
    trigger event wakes the thread if it is blocked waiting for a trigger.
    Registered with atexit so it runs at interpreter shutdown.
    """
    from ..core import state as global_state
    background_cleanup_stop_event.set()
    background_cleanup_trigger_event.set()
    worker = global_state.background_cleanup_thread
    if worker is None or not worker.is_alive():
        return
    worker.join(timeout=5)


atexit.register(stop_background_cleanup_thread)