import base64
import json
import os
import re
import socket
import struct
import time
import zlib
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

import msgpack

from aligner import align_words, setup_aligner
from util import (
    calculate_duration_from_bytes,
    load_yaml,
    update_motion_generator_duration,
)

config = load_yaml()
HOST = config["HOST"]
PORT = config["PORT"]
print(f"Connecting to {HOST}:{PORT}")

# Frame-header magic number shared with the server; used to detect
# desynchronized or corrupt streams.
MAGIC = 0x2333

# Wire format of the 12-byte frame header: magic, uncompressed length,
# compressed length — all little-endian uint32.
# NOTE(review): reconstructed from the 12-byte read and the three-value
# unpack in recv_frame; confirm against the server's framing code.
_HEADER_FMT = "<III"
_HEADER_SIZE = struct.calcsize(_HEADER_FMT)  # == 12


def patch_socket_keepalive(sock: socket.socket) -> None:
    """Set keepalive + long timeout to prevent halts on idle."""
    sock.settimeout(None)  # Never timeout on recv
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    # Platform-specific tuning: these TCP_* options only exist on some
    # platforms (e.g. Linux), so probe for each before setting it.
    if hasattr(socket, 'TCP_KEEPIDLE'):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 10)
    if hasattr(socket, 'TCP_KEEPINTVL'):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5)
    if hasattr(socket, 'TCP_KEEPCNT'):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)


def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly *n* bytes from *sock*.

    Raises:
        EOFError: if the peer closes the connection before *n* bytes arrive.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("Connection closed prematurely")
        buf.extend(chunk)
    return bytes(buf)


def send_frame(sock: socket.socket, event: str, payload: Any) -> None:
    """Serialize and send one framed message.

    The body is msgpack-encoded (instead of JSON) then zlib-compressed;
    the 12-byte header carries the magic plus both lengths so the peer
    can validate the stream.
    """
    raw = msgpack.packb({"event": event, "payload": payload}, use_bin_type=True)
    comp = zlib.compress(raw)
    header = struct.pack(_HEADER_FMT, MAGIC, len(raw), len(comp))
    sock.sendall(header + comp)


def recv_frame(sock: socket.socket) -> Dict[str, Any]:
    """Receive one framed message and return the decoded dict.

    Raises:
        EOFError: if the connection closes mid-frame.
        ValueError: if the magic number or decompressed length is wrong.
    """
    header = recv_exact(sock, _HEADER_SIZE)
    magic, raw_len, comp_len = struct.unpack(_HEADER_FMT, header)
    if magic != MAGIC:
        raise ValueError(f"Bad frame magic: {magic:#x}")
    comp = recv_exact(sock, comp_len)
    raw = zlib.decompress(comp)
    if len(raw) != raw_len:
        raise ValueError(
            f"Decompressed length mismatch: expected {raw_len}, got {len(raw)}"
        )
    return msgpack.unpackb(raw, raw=False)


def strip_tags(text: str) -> str:
    """Remove <...> markup and keep only word-like tokens (letters and ')."""
    no_tags = re.sub(r"<[^>]+>", "", text)
    words = re.findall(r"\b[a-zA-Z']+\b", no_tags)
    return " ".join(words).strip()


def align_audio(audio_bytes: bytes, scene_text: str) -> Tuple:
    """Run word alignment for a single scene's audio.

    Intended to be executed inside a worker thread. (An earlier dummy
    implementation read 'output_0.wav' and stripped tags here; alignment
    now receives the text as-is.)
    """
    alignment = align_words(audio_bytes, scene_text)
    return alignment


def generate_audio(scene: Dict[str, Any]) -> Tuple[bytes, str]:
    """Produce (audio_bytes, audio_base64) for one scene.

    Placeholder implementation: reads a pre-rendered dummy WAV from disk.
    In a real scenario this would call the TTS engine, e.g.:

        audio_bytes, audio_base64 = synthesize_for_scene(
            prompt=scene["txt"],
            voice=scene.get("voice", "miko"),
            temperature=scene.get("temperature", 0.6),
            top_p=scene.get("top_p", 0.8),
            repetition_penalty=scene.get("repetition_penalty", 1.3),
            max_tokens=scene.get("max_tokens", 1200),
        )

    Raises:
        FileNotFoundError: if the dummy WAV is missing.
    """
    dummy_path = "output_0.wav"
    if not os.path.exists(dummy_path):
        raise FileNotFoundError("Dummy file 'output_0.wav' not found.")
    with open(dummy_path, "rb") as f:
        audio_bytes = f.read()
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return audio_bytes, audio_base64


def handle_connection(sock: socket.socket) -> None:
    """Main per-connection loop: announce role, then serve generate-voice events.

    Pipeline per request:
      Stage 1 (fast): generate audio for every scene in parallel and push
      durations to the motion generator as soon as each finishes.
      Stage 2 (slow): run word alignment in parallel, then send one
      'voice-generated' response covering all scenes.
    """
    send_frame(sock, "hello", {"role": "tts"})
    print("→ hello (role=tts) sent")

    while True:
        try:
            frame = recv_frame(sock)
        except EOFError:
            print('[ "Connection closed by the other side" ]')
            break

        event = frame.get("event")
        payload = frame.get("payload")
        if event != "generate-voice":
            print(f"⚠️ unknown event {event}, ignored")
            continue

        scenes: List[dict] = payload.get("scenes", [])

        # --- STAGE 1: FAST Audio Generation & Duration Notification ---
        # The goal here is to get durations to the motion generator ASAP.
        generated_audio_data = []
        print("")
        print("--- Generating Audios Thread ---")
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Submit all the FAST audio generation tasks
            future_to_scene = {
                executor.submit(generate_audio, scene): scene
                for scene in scenes
                if scene.get("txt")
            }
            # As each FAST audio generation task completes...
            for future in as_completed(future_to_scene):
                scene = future_to_scene[future]
                try:
                    scene_id = scene["sceneId"]
                    motion_index = scene.get("motionIndex", 0)

                    # 1. Get the generated audio
                    audio_bytes, audio_base64 = future.result()
                    print("")
                    print(f'[ "Generated Audio {scene_id}, Motion: {motion_index}" ]')

                    # 2. Calculate duration instantly
                    duration = calculate_duration_from_bytes(audio_bytes)

                    # 3. Notify motion generator IMMEDIATELY
                    if duration > 0:
                        update_motion_generator_duration(
                            scene["sceneId"], scene.get("motionIndex", 0), duration
                        )

                    # 4. Store the results to be used in the next (slow) stage
                    generated_audio_data.append({
                        "scene": scene,
                        "audio_bytes": audio_bytes,
                        "audio_base64": audio_base64
                    })
                except Exception as e:
                    # .get() here: the failure may itself be a missing key.
                    print(f"Error during audio generation for {scene.get('sceneId', '?')}: {e}")

        # --- STAGE 2: SLOW Word Alignment in Parallel ---
        # Now that all notifications are sent, we can perform the slow
        # alignment work.
        response_by_scene: Dict[str, Any] = {}
        print("")
        print("--- Word Alignments Thread ---")
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Use the data from Stage 1 to submit SLOW alignment tasks.
            # We call `align_words` directly (the `align_audio` wrapper is
            # not needed here).
            future_to_data = {
                executor.submit(
                    align_words,
                    data["audio_bytes"],
                    strip_tags(data["scene"]["txt"]),
                ): data
                for data in generated_audio_data
            }
            # As each SLOW alignment task completes...
            for future in as_completed(future_to_data):
                data = future_to_data[future]
                scene = data["scene"]
                scene_id = scene["sceneId"]
                motion_index = scene.get("motionIndex", 0)
                try:
                    # 1. Get the alignment result
                    alignment = future.result()
                    print("")
                    print(f'[ "Aligned {scene_id}, Motion: {motion_index}" ]')

                    # 2. Now, build the final response object with all the data
                    voice_audio = {
                        "motion": motion_index,
                        "audio_base64": data["audio_base64"],  # From Stage 1
                        "alignment": alignment,                # From Stage 2
                    }
                    if scene_id not in response_by_scene:
                        response_by_scene[scene_id] = {
                            "sceneId": scene_id,
                            "audioEvents": [],
                        }
                    response_by_scene[scene_id]["audioEvents"].append(voice_audio)
                except Exception as e:
                    print(f"Error during alignment for scene {scene_id}: {e}")

        if response_by_scene:
            send_frame(sock, "voice-generated", list(response_by_scene.values()))
            print("")
            print(f"[ ← Audios ({len(response_by_scene)}) sent ]")


def main() -> None:
    """Connect to the server and serve requests forever, reconnecting on error."""
    # Setup the Orpheus TTS model on startup.
    # setup_model()

    # Setup the aligner (does nothing for aeneas, but keeps pattern consistent)
    setup_aligner()

    while True:
        try:
            with socket.create_connection((HOST, PORT), timeout=60) as sock:
                patch_socket_keepalive(sock)
                print(f'["Connected to server at {HOST}:{PORT}"]')
                handle_connection(sock)
        except (ConnectionRefusedError, OSError) as e:
            print(f"Connection error: {e}, retrying in 5s")
        except Exception as e:
            print(f"Unhandled error: {e}, reconnecting in 5s")
        finally:
            time.sleep(5)


if __name__ == "__main__":
    main()