tts / src /accelerator /client.py
hadadrjt's picture
Pocket TTS: Implement safe and efficient processing mechanisms.
02b5975
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import os
import socket
import struct
import subprocess
import tempfile
import threading
import time
import sys
from typing import Optional, Tuple, Dict, Any
from config import (
ACCELERATOR_SOCKET_PATH,
ACCELERATOR_BINARY_PATH,
ACCELERATOR_WORKER_THREADS,
ACCELERATOR_MEMORY_POOL_MB,
ACCELERATOR_LOG_PREFIX
)
from ..core.state import (
accelerator_log_lock,
accelerator_log_thread,
accelerator_log_stop_event
)
# --- Wire protocol -----------------------------------------------------------
# First field of every request and response header; 0x50545453 is the ASCII
# byte sequence "PTTS". Responses with a different value are rejected by
# AcceleratorClient._send_command.
PROTOCOL_MAGIC_NUMBER = 0x50545453
# Command codes placed in the "command type" field of the request header.
COMMAND_PING = 0
COMMAND_PROCESS_AUDIO = 1
COMMAND_CONVERT_TO_MONO = 2
COMMAND_CONVERT_TO_PCM = 3
COMMAND_RESAMPLE_AUDIO = 4
COMMAND_GET_MEMORY_STATS = 5
COMMAND_CLEAR_MEMORY_POOL = 6
COMMAND_SHUTDOWN = 7
# Status codes read from the "status" field of the response header.
RESPONSE_SUCCESS = 0
RESPONSE_ERROR_INVALID_COMMAND = 1
RESPONSE_ERROR_FILE_NOT_FOUND = 2
RESPONSE_ERROR_PROCESSING_FAILED = 3
RESPONSE_ERROR_MEMORY_ALLOCATION = 4
RESPONSE_ERROR_INTERNAL = 5
# Struct layouts. "=" selects native byte order with standard sizes and no
# alignment padding; "I" is an unsigned 32-bit integer.
# Request header, as packed by _send_command: magic, command type,
# payload length, request id.
REQUEST_HEADER_FORMAT = "=IIII"
# Response header, as unpacked by _send_command: magic, status code,
# payload size, request id.
RESPONSE_HEADER_FORMAT = "=IIII"
REQUEST_HEADER_SIZE = struct.calcsize(REQUEST_HEADER_FORMAT)
RESPONSE_HEADER_SIZE = struct.calcsize(RESPONSE_HEADER_FORMAT)
# Payload of the file-based commands, as packed by
# _pack_process_audio_request: two 512-byte NUL-padded path fields (input,
# output) followed by the target sample rate and the option flags.
PROCESS_AUDIO_REQUEST_FORMAT = "=512s512sII"
PROCESS_AUDIO_REQUEST_SIZE = struct.calcsize(PROCESS_AUDIO_REQUEST_FORMAT)
# Payload of the memory-stats response, as unpacked by get_memory_stats:
# three unsigned 64-bit integers (total allocated, total used, block count).
MEMORY_STATS_RESPONSE_FORMAT = "=QQQ"
MEMORY_STATS_RESPONSE_SIZE = struct.calcsize(MEMORY_STATS_RESPONSE_FORMAT)
# --- Module state ------------------------------------------------------------
# Popen handle of the daemon started by start_accelerator_daemon (None when no
# daemon has been started from this module); accessed under
# accelerator_process_lock by the start/stop functions.
accelerator_process_handle = None
accelerator_process_lock = threading.Lock()
# Counter behind the request id header field; incremented under
# request_id_lock by AcceleratorClient._get_next_request_id.
request_id_counter = 0
request_id_lock = threading.Lock()
def log_accelerator_message(message: str):
    """Print ``message`` to stdout behind the accelerator log prefix.

    The print happens while holding ``accelerator_log_lock`` and stdout is
    flushed immediately.
    """
    formatted_line = f"{ACCELERATOR_LOG_PREFIX} {message}"
    with accelerator_log_lock:
        print(formatted_line, flush=True)
def _stream_accelerator_pipe(pipe, line_prefix: str, error_label: str):
    """Forward lines from one daemon pipe to the accelerator log.

    Reads ``pipe`` line by line until end-of-file is reached or
    ``accelerator_log_stop_event`` is set. Each non-empty line is decoded as
    UTF-8 (undecodable bytes are replaced), stripped of trailing whitespace
    and logged with ``line_prefix`` in front of it.

    Args:
        pipe: Binary file object to read from, or ``None`` when the process
            was started without that pipe (nothing is done in that case).
        line_prefix: Text put in front of every forwarded line.
        error_label: Label used in the log message if reading fails.
    """
    if pipe is None:
        # No pipe to read: return instead of spinning in the loop below.
        return
    try:
        while not accelerator_log_stop_event.is_set():
            raw_line = pipe.readline()
            if not raw_line:
                # End-of-file: the write end is closed, nothing more will
                # arrive. Stopping here (rather than on process exit) lets
                # output still buffered when the process died be logged and
                # avoids busy-looping on a closed pipe.
                break
            decoded_line = raw_line.decode("utf-8", errors="replace").rstrip()
            if decoded_line:
                log_accelerator_message(f"{line_prefix}{decoded_line}")
    except Exception as stream_error:
        # Best effort: a failing log reader must never take the caller down.
        log_accelerator_message(f"{error_label}: {str(stream_error)}")


def stream_accelerator_output(process_handle: subprocess.Popen):
    """Relay the daemon's stdout to the accelerator log (thread target).

    Args:
        process_handle: The daemon process whose ``stdout`` pipe is read.
    """
    _stream_accelerator_pipe(process_handle.stdout, "", "Log stream error")


def stream_accelerator_stderr(process_handle: subprocess.Popen):
    """Relay the daemon's stderr to the accelerator log (thread target).

    Lines are tagged with ``[STDERR]``.

    Args:
        process_handle: The daemon process whose ``stderr`` pipe is read.
    """
    _stream_accelerator_pipe(
        process_handle.stderr, "[STDERR] ", "Stderr stream error"
    )
class AcceleratorClient:
    """Client for the accelerator daemon's Unix-domain-socket protocol.

    Every command opens a fresh connection, sends one request (a fixed
    header followed by an optional payload), reads one response and closes
    the connection again.
    """

    # Size of each path field; must match the "512s" entries in
    # PROCESS_AUDIO_REQUEST_FORMAT.
    _PATH_FIELD_SIZE = 512

    def __init__(self, socket_path: str = ACCELERATOR_SOCKET_PATH):
        """Store the socket path and the default timeouts (in seconds)."""
        self.socket_path = socket_path
        self.connection_timeout = 5.0  # applies while connecting and sending
        self.read_timeout = 30.0  # applies while waiting for the response

    def is_connected(self) -> bool:
        """Return True when a ping is answered with a payload starting with b"PONG"."""
        try:
            response = self.send_ping()
            return response is not None and response.startswith(b"PONG")
        except Exception:
            return False

    def send_ping(self) -> Optional[bytes]:
        """Send a ping command and return the raw response payload, or None on failure."""
        return self._send_command(COMMAND_PING, b"")

    def process_audio(
        self,
        input_file_path: str,
        output_file_path: str,
        target_sample_rate: int = 0,
        options_flags: int = 0
    ) -> Tuple[bool, str]:
        """Send a process-audio command for the given input and output files.

        Returns:
            ``(success, message)`` where ``message`` is the daemon's reply
            with the ``SUCCESS:`` / ``ERROR:`` prefix removed, or a
            description of the communication failure.
        """
        payload = self._pack_process_audio_request(
            input_file_path,
            output_file_path,
            target_sample_rate,
            options_flags
        )
        log_accelerator_message(f"Processing audio: {input_file_path} -> {output_file_path}")
        return self._run_file_command(
            COMMAND_PROCESS_AUDIO, payload, "process_audio", "Audio processing"
        )

    def convert_to_mono(
        self,
        input_file_path: str,
        output_file_path: str
    ) -> Tuple[bool, str]:
        """Send a convert-to-mono command; returns ``(success, message)``."""
        payload = self._pack_process_audio_request(
            input_file_path,
            output_file_path,
            0,
            0
        )
        log_accelerator_message(f"Converting to mono: {input_file_path} -> {output_file_path}")
        return self._run_file_command(
            COMMAND_CONVERT_TO_MONO, payload, "convert_to_mono", "Mono conversion"
        )

    def convert_to_pcm(
        self,
        input_file_path: str,
        output_file_path: str
    ) -> Tuple[bool, str]:
        """Send a convert-to-PCM command; returns ``(success, message)``."""
        payload = self._pack_process_audio_request(
            input_file_path,
            output_file_path,
            0,
            0
        )
        log_accelerator_message(f"Converting to PCM: {input_file_path} -> {output_file_path}")
        return self._run_file_command(
            COMMAND_CONVERT_TO_PCM, payload, "convert_to_pcm", "PCM conversion"
        )

    def resample_audio(
        self,
        input_file_path: str,
        output_file_path: str,
        target_sample_rate: int
    ) -> Tuple[bool, str]:
        """Send a resample command for ``target_sample_rate``; returns ``(success, message)``."""
        payload = self._pack_process_audio_request(
            input_file_path,
            output_file_path,
            target_sample_rate,
            0
        )
        log_accelerator_message(f"Resampling audio to {target_sample_rate}Hz: {input_file_path} -> {output_file_path}")
        return self._run_file_command(
            COMMAND_RESAMPLE_AUDIO, payload, "resample_audio", "Resampling"
        )

    def get_memory_stats(self) -> Optional[Dict[str, int]]:
        """Query the daemon's memory statistics.

        Returns:
            A dict with ``total_allocated_bytes``, ``total_used_bytes`` and
            ``block_count``, or None when the request fails, the daemon
            reports an error status, or the payload is too short.
        """
        result = self._exchange(COMMAND_GET_MEMORY_STATS, b"")
        # An error response may carry a message payload; never parse that
        # as statistics.
        if (
            result is None
            or result[0] != RESPONSE_SUCCESS
            or len(result[1]) < MEMORY_STATS_RESPONSE_SIZE
        ):
            log_accelerator_message("Failed to get memory stats from accelerator")
            return None
        total_allocated, total_used, block_count = struct.unpack(
            MEMORY_STATS_RESPONSE_FORMAT,
            result[1][:MEMORY_STATS_RESPONSE_SIZE]
        )
        stats = {
            "total_allocated_bytes": total_allocated,
            "total_used_bytes": total_used,
            "block_count": block_count
        }
        log_accelerator_message(f"Memory stats: allocated={total_allocated}, used={total_used}, blocks={block_count}")
        return stats

    def clear_memory_pool(self) -> bool:
        """Ask the daemon to clear its memory pool; True only on a success status."""
        log_accelerator_message("Clearing accelerator memory pool")
        result = self._exchange(COMMAND_CLEAR_MEMORY_POOL, b"")
        # Check the status code: an error response with a payload must not
        # be reported as success.
        success = result is not None and result[0] == RESPONSE_SUCCESS
        if success:
            log_accelerator_message("Memory pool cleared successfully")
        else:
            log_accelerator_message("Failed to clear memory pool")
        return success

    def shutdown_accelerator(self) -> bool:
        """Send the shutdown command; True only when it is acknowledged with a success status."""
        log_accelerator_message("Sending shutdown command to accelerator")
        result = self._exchange(COMMAND_SHUTDOWN, b"")
        return result is not None and result[0] == RESPONSE_SUCCESS

    def _run_file_command(
        self,
        command_type: int,
        payload: bytes,
        operation_name: str,
        log_label: str
    ) -> Tuple[bool, str]:
        """Send a file-based command and interpret its textual response.

        Args:
            command_type: Command code to send.
            payload: Packed request payload.
            operation_name: Method name used in the communication-failure log line.
            log_label: Human-readable label used in the result log lines.

        Returns:
            ``(True, detail)`` for a ``SUCCESS:`` reply, ``(False, detail)``
            for an ``ERROR:`` reply, ``(False, raw_text)`` for anything else
            and ``(False, "Failed to communicate with accelerator")`` when no
            response was obtained.
        """
        response = self._send_command(command_type, payload)
        if response is None:
            log_accelerator_message(f"Failed to communicate with accelerator for {operation_name}")
            return False, "Failed to communicate with accelerator"
        response_string = response.decode("utf-8", errors="ignore")
        if response_string.startswith("SUCCESS:"):
            detail = response_string[len("SUCCESS:"):]
            log_accelerator_message(f"{log_label} succeeded: {detail}")
            return True, detail
        if response_string.startswith("ERROR:"):
            detail = response_string[len("ERROR:"):]
            log_accelerator_message(f"{log_label} failed: {detail}")
            return False, detail
        log_accelerator_message(f"{log_label} unknown response: {response_string}")
        return False, response_string

    def _get_next_request_id(self) -> int:
        """Return the next request id from the module-wide counter.

        The counter wraps at 2**32 so the id always fits the unsigned 32-bit
        header field instead of making ``struct.pack`` raise.
        """
        global request_id_counter
        with request_id_lock:
            request_id_counter = (request_id_counter + 1) & 0xFFFFFFFF
            return request_id_counter

    def _encode_path(self, path: str, field_name: str) -> bytes:
        """UTF-8 encode ``path``, truncated so a terminating NUL still fits the field."""
        encoded = path.encode("utf-8")
        limit = self._PATH_FIELD_SIZE - 1
        if len(encoded) > limit:
            # Truncation is kept for compatibility, but it makes the daemon
            # see a different path, so at least make it visible in the log.
            log_accelerator_message(
                f"Warning: {field_name} path exceeds {limit} bytes and will be truncated"
            )
            encoded = encoded[:limit]
        return encoded

    def _pack_process_audio_request(
        self,
        input_path: str,
        output_path: str,
        target_sample_rate: int,
        options_flags: int
    ) -> bytes:
        """Pack the payload shared by all file-based commands.

        Layout (PROCESS_AUDIO_REQUEST_FORMAT): two 512-byte NUL-padded path
        fields followed by the sample rate and the option flags.
        """
        # struct's "s" format pads short values with NUL bytes, which yields
        # the NUL-terminated, NUL-padded 512-byte fields.
        return struct.pack(
            PROCESS_AUDIO_REQUEST_FORMAT,
            self._encode_path(input_path, "input"),
            self._encode_path(output_path, "output"),
            target_sample_rate,
            options_flags
        )

    def _send_command(
        self,
        command_type: int,
        payload: bytes
    ) -> Optional[bytes]:
        """Send one command and return the response payload.

        Returns:
            The payload bytes (possibly empty) for a success status; the
            payload for an error status if it is non-empty; None for an
            error status without payload or any communication failure.
        """
        result = self._exchange(command_type, payload)
        if result is None:
            return None
        status_code, response_payload = result
        if status_code != RESPONSE_SUCCESS and not response_payload:
            return None
        return response_payload

    def _exchange(
        self,
        command_type: int,
        payload: bytes
    ) -> Optional[Tuple[int, bytes]]:
        """Perform one request/response round trip.

        Returns:
            ``(status_code, payload)`` from the response, or None when the
            connection fails, times out, or the response header is invalid
            (wrong magic number or mismatching request id).
        """
        try:
            # The context manager closes the socket on every exit path,
            # including exceptions raised by connect/send/receive.
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client_socket:
                client_socket.settimeout(self.connection_timeout)
                client_socket.connect(self.socket_path)
                request_id = self._get_next_request_id()
                request_header = struct.pack(
                    REQUEST_HEADER_FORMAT,
                    PROTOCOL_MAGIC_NUMBER,
                    command_type,
                    len(payload),
                    request_id
                )
                client_socket.sendall(request_header + payload)
                client_socket.settimeout(self.read_timeout)
                response_header_data = self._receive_exactly(client_socket, RESPONSE_HEADER_SIZE)
                if response_header_data is None:
                    return None
                magic_number, status_code, payload_size, response_request_id = struct.unpack(
                    RESPONSE_HEADER_FORMAT,
                    response_header_data
                )
                if magic_number != PROTOCOL_MAGIC_NUMBER:
                    log_accelerator_message(f"Invalid magic number in response: {magic_number}")
                    return None
                if response_request_id != request_id:
                    log_accelerator_message(f"Request ID mismatch: expected {request_id}, got {response_request_id}")
                    return None
                response_payload = b""
                if payload_size > 0:
                    response_payload = self._receive_exactly(client_socket, payload_size)
                    if response_payload is None:
                        return None
                return status_code, response_payload
        except socket.timeout:
            log_accelerator_message("Socket timeout while communicating with accelerator")
            return None
        except socket.error as socket_err:
            log_accelerator_message(f"Socket error: {str(socket_err)}")
            return None
        except Exception as general_error:
            log_accelerator_message(f"Unexpected error: {str(general_error)}")
            return None

    def _receive_exactly(
        self,
        client_socket: socket.socket,
        num_bytes: int
    ) -> Optional[bytes]:
        """Read exactly ``num_bytes`` from the socket.

        Returns:
            The bytes read, or None if the peer closes the connection early,
            the read times out, or a socket error occurs.
        """
        chunks = []
        remaining_bytes = num_bytes
        while remaining_bytes > 0:
            try:
                chunk = client_socket.recv(remaining_bytes)
            except (socket.timeout, socket.error):
                return None
            if not chunk:
                # Empty read: the peer closed before sending everything.
                return None
            chunks.append(chunk)
            remaining_bytes -= len(chunk)
        # Join once at the end instead of repeatedly copying a growing bytes object.
        return b"".join(chunks)
def is_accelerator_available() -> bool:
    """Report whether the accelerator daemon can be reached.

    The socket path must exist on disk and a client created with the default
    socket path must report itself as connected.
    """
    socket_exists = os.path.exists(ACCELERATOR_SOCKET_PATH)
    return AcceleratorClient().is_connected() if socket_exists else False
def start_accelerator_daemon() -> bool:
    """Start the accelerator daemon unless this module already runs one.

    Launches the binary with the configured socket path, worker thread count
    and memory pool size, starts two daemon threads that relay its stdout and
    stderr to the log, then polls every 0.1 s (up to 50 times) until the
    daemon answers on its socket.

    Returns:
        True when a previously started daemon is still running or the new
        one responds; False when the binary is missing, the process exits
        with a non-zero code during startup, it never responds, or starting
        it raises.
    """
    global accelerator_process_handle
    from ..core import state as global_state
    with accelerator_process_lock:
        if (
            accelerator_process_handle is not None
            and accelerator_process_handle.poll() is None
        ):
            return True
        if not os.path.exists(ACCELERATOR_BINARY_PATH):
            log_accelerator_message(f"Accelerator binary not found: {ACCELERATOR_BINARY_PATH}")
            return False
        try:
            log_accelerator_message("Starting accelerator daemon...")
            global_state.accelerator_log_stop_event.clear()
            accelerator_process_handle = subprocess.Popen(
                [
                    ACCELERATOR_BINARY_PATH,
                    "--socket", ACCELERATOR_SOCKET_PATH,
                    "--threads", str(ACCELERATOR_WORKER_THREADS),
                    "--memory", str(ACCELERATOR_MEMORY_POOL_MB)
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                start_new_session=True
            )
            stdout_thread = threading.Thread(
                target=stream_accelerator_output,
                args=(accelerator_process_handle,),
                daemon=True,
                name="AcceleratorStdoutThread"
            )
            stdout_thread.start()
            stderr_thread = threading.Thread(
                target=stream_accelerator_stderr,
                args=(accelerator_process_handle,),
                daemon=True,
                name="AcceleratorStderrThread"
            )
            stderr_thread.start()
            for _ in range(50):
                time.sleep(0.1)
                if is_accelerator_available():
                    log_accelerator_message("Accelerator daemon started and responding")
                    return True
                exit_code = accelerator_process_handle.poll()
                if exit_code not in (None, 0):
                    # The process died with an error: it will never answer,
                    # so fail now instead of polling for the remaining time.
                    # A zero exit code is not treated as fatal in case the
                    # launcher hands over to a background process.
                    log_accelerator_message(
                        f"Accelerator daemon exited during startup with code {exit_code}"
                    )
                    return False
            # One last check after the polling window has elapsed.
            available = is_accelerator_available()
            if available:
                log_accelerator_message("Accelerator daemon started successfully")
            else:
                log_accelerator_message("Accelerator daemon started but not responding")
            return available
        except Exception as start_error:
            log_accelerator_message(f"Failed to start accelerator daemon: {str(start_error)}")
            return False
def stop_accelerator_daemon() -> bool:
    """Stop the accelerator daemon and forget its process handle.

    Sets the log stop event, sends a shutdown command if the daemon is
    reachable (then waits 0.5 s), and finally terminates the process started
    by this module if it is still running, killing it when it does not exit
    within 5 seconds.

    Returns:
        Always True.
    """
    global accelerator_process_handle
    from ..core import state as global_state

    with accelerator_process_lock:
        global_state.accelerator_log_stop_event.set()

        # Step 1: ask politely over the socket, if anyone is listening.
        if is_accelerator_available():
            try:
                log_accelerator_message("Sending shutdown command to accelerator...")
                AcceleratorClient().shutdown_accelerator()
                time.sleep(0.5)
            except Exception as shutdown_error:
                log_accelerator_message(f"Error during shutdown command: {str(shutdown_error)}")

        # Step 2: make sure the process we spawned is really gone.
        daemon_process = accelerator_process_handle
        if daemon_process is None:
            return True

        still_running = daemon_process.poll() is None
        if still_running:
            try:
                log_accelerator_message("Terminating accelerator process...")
                daemon_process.terminate()
                daemon_process.wait(timeout=5)
                log_accelerator_message("Accelerator process terminated")
            except subprocess.TimeoutExpired:
                log_accelerator_message("Accelerator process did not terminate, killing...")
                daemon_process.kill()
                daemon_process.wait()
                log_accelerator_message("Accelerator process killed")

        accelerator_process_handle = None
        return True
def process_audio_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the process-audio command through a default client.

    Returns ``(False, "Accelerator not available")`` when the daemon cannot
    be reached; otherwise the result of ``AcceleratorClient.process_audio``.
    """
    if is_accelerator_available():
        return AcceleratorClient().process_audio(input_file_path, output_file_path)
    return False, "Accelerator not available"
def convert_to_mono_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the convert-to-mono command through a default client.

    Returns ``(False, "Accelerator not available")`` when the daemon cannot
    be reached; otherwise the result of ``AcceleratorClient.convert_to_mono``.
    """
    if is_accelerator_available():
        return AcceleratorClient().convert_to_mono(input_file_path, output_file_path)
    return False, "Accelerator not available"
def convert_to_pcm_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the convert-to-PCM command through a default client.

    Returns ``(False, "Accelerator not available")`` when the daemon cannot
    be reached; otherwise the result of ``AcceleratorClient.convert_to_pcm``.
    """
    if is_accelerator_available():
        return AcceleratorClient().convert_to_pcm(input_file_path, output_file_path)
    return False, "Accelerator not available"
def resample_audio_with_accelerator(
    input_file_path: str,
    output_file_path: str,
    target_sample_rate: int
) -> Tuple[bool, str]:
    """Run the resample command through a default client.

    Returns ``(False, "Accelerator not available")`` when the daemon cannot
    be reached; otherwise the result of ``AcceleratorClient.resample_audio``.
    """
    if is_accelerator_available():
        return AcceleratorClient().resample_audio(
            input_file_path, output_file_path, target_sample_rate
        )
    return False, "Accelerator not available"
def get_accelerator_memory_stats() -> Optional[Dict[str, int]]:
    """Fetch memory statistics through a default client.

    Returns None when the daemon cannot be reached; otherwise the result of
    ``AcceleratorClient.get_memory_stats``.
    """
    if is_accelerator_available():
        return AcceleratorClient().get_memory_stats()
    return None
def clear_accelerator_memory_pool() -> bool:
    """Clear the daemon's memory pool through a default client.

    Returns False when the daemon cannot be reached; otherwise the result of
    ``AcceleratorClient.clear_memory_pool``.
    """
    if is_accelerator_available():
        return AcceleratorClient().clear_memory_pool()
    return False