import logging
from typing import Any

from models.event import Event, EventType

logger = logging.getLogger("s2s_server")


class TranscriptStabilizer:
    """Split streaming partial transcripts into a stable prefix and an unstable tail.

    Each partial transcript is compared word-by-word (case-insensitively) with
    the previous partial; the shared word prefix, minus a safety margin, is
    reported as stable. The result is sent to the client through
    ``context.websocket.send_json`` as a ``partial_transcription`` message.

    All state is cleared on ``USER_SPEECH_START`` and ``TURN_COMPLETED``.
    """

    # Number of recent partial transcripts kept in ``history``.
    HISTORY_SIZE = 10
    # Number of trailing matching words still reported as unstable.
    SAFETY_MARGIN = 2
    # A transcript with at most this many words is fully stable once it matches
    # the previous partial exactly; with SAFETY_MARGIN alone nothing in such a
    # short transcript would ever become stable.
    SHORT_TRANSCRIPT_WORDS = 2

    def __init__(self, context: Any):
        """Store *context* and subscribe the handlers on ``context.event_bus``.

        Args:
            context: object exposing ``event_bus.subscribe(...)`` and
                ``websocket.send_json(...)`` (both are used by this class).
        """
        self.context = context
        self.reset()
        # Bind events
        self.context.event_bus.subscribe(EventType.PARTIAL_TRANSCRIPT, self.on_partial_transcript)
        self.context.event_bus.subscribe(EventType.USER_SPEECH_START, self.on_user_speech_start)
        self.context.event_bus.subscribe(EventType.TURN_COMPLETED, self.on_turn_completed)

    def reset(self) -> None:
        """Forget all previous partials and the last reported stable prefix."""
        self.history: list[str] = []
        self.last_stable_text = ""

    async def on_user_speech_start(self, event: Event) -> None:
        """Clear state when the user starts speaking."""
        self.reset()

    async def on_turn_completed(self, event: Event) -> None:
        """Clear state when the turn completes."""
        self.reset()

    async def on_partial_transcript(self, event: Event) -> None:
        """Record a partial transcript and push its stable/unstable split to the client.

        Empty payloads are ignored. Send failures are logged, not raised.
        """
        text = event.payload
        if not text:
            return

        self.history.append(text)
        # Keep only the most recent HISTORY_SIZE partials.
        if len(self.history) > self.HISTORY_SIZE:
            del self.history[: -self.HISTORY_SIZE]

        # Compute stable/unstable parts
        stable, unstable = self.stabilize(text)

        # Send WS message to client (best effort: a failed send must not break
        # event handling).
        try:
            await self.context.websocket.send_json({
                "type": "partial_transcription",
                "text": text,
                "stable": stable,
                "unstable": unstable,
            })
        except Exception as e:
            logger.error("Failed to send partial transcription to client: %s", e)

    @staticmethod
    def _common_prefix_len(words_a: list[str], words_b: list[str]) -> int:
        """Return how many leading words the two lists share (case-insensitive)."""
        count = 0
        for w1, w2 in zip(words_a, words_b):
            if w1.lower() != w2.lower():
                break
            count += 1
        return count

    def stabilize(self, current_text: str) -> tuple[str, str]:
        """Split *current_text* into ``(stable, unstable)``.

        *current_text* is compared with ``self.history[-2]``, so it is expected
        to already be the last entry of ``self.history`` (as done by
        ``on_partial_transcript``). With fewer than two history entries nothing
        is stable.

        Returns:
            ``(stable, unstable)``. ``stable`` is a word prefix of
            *current_text* joined with single spaces; ``unstable`` is the
            remaining words joined with single spaces. When nothing is stable,
            ``unstable`` is *current_text* unchanged.

        The stable prefix never shrinks while the previously reported stable
        words are still a prefix of the current transcript. If the transcript
        no longer starts with them, the stable part is recomputed from the
        current text alone, so it never contains words absent from the
        current transcript.
        """
        if len(self.history) < 2:
            return "", current_text

        current_words = current_text.split()
        prev_words = self.history[-2].split()

        # Matching word prefix with the previous partial.
        common_len = self._common_prefix_len(current_words, prev_words)

        # Safety margin: keep the last matching words as unstable.
        stable_len = max(0, common_len - self.SAFETY_MARGIN)

        # If we have a very short transcript, be less aggressive with safety margin
        if common_len == len(current_words) and len(current_words) <= self.SHORT_TRANSCRIPT_WORDS:
            stable_len = common_len

        # Maintain monotone growth to prevent flickering on the frontend, but
        # only while the old stable words are still at the start of the
        # *current* transcript; otherwise reporting them would show text the
        # recogniser no longer produces.
        last_words = self.last_stable_text.split()
        if (
            len(last_words) > stable_len
            and self._common_prefix_len(current_words, last_words) == len(last_words)
        ):
            stable_len = len(last_words)

        stable_text = " ".join(current_words[:stable_len])
        # Always record the latest prefix (even empty) so a diverged prefix
        # cannot resurface on a later partial.
        self.last_stable_text = stable_text

        if not stable_text:
            return "", current_text
        # Derive the tail from the same word list as the stable part so the two
        # always line up (no character-offset slicing across differently
        # spaced or cased text).
        return stable_text, " ".join(current_words[stable_len:])