Spaces:
Paused
Paused
rudyByte
fix: implement robust state barge-in transitions, fix safety reset duration check, and upgrade conversational behavior guidelines
b2ed694 | from dataclasses import dataclass, field | |
| import time | |
| import asyncio | |
| from collections import Counter | |
| from typing import Optional | |
class CallStateMachine:
    """Validate and apply call-state transitions for a CallSession.

    Legal moves are listed in VALID_TRANSITIONS (current state -> set of
    states it may move to). Two shortcuts bypass the table:

    * re-entering the current state is a successful no-op;
    * "ended" is accepted from every state, so a call can always be torn down.

    "ended" itself is terminal: its entry is the empty set, so any attempt to
    leave it is rejected.
    """

    # current state -> states it may legally move to next.
    # NOTE(review): no entry lists "confirming" as a target, so "confirming"
    # cannot be entered through transition(); confirm whether e.g.
    # "processing" or "responding" should be allowed to move into it.
    VALID_TRANSITIONS = {
        "greeting": {"listening", "processing", "closing", "ended"},
        "listening": {"processing", "closing", "ended"},
        "processing": {"responding", "tool_calling", "clarifying", "closing", "ended"},
        "responding": {"listening", "processing", "closing", "ended"},
        "clarifying": {"listening", "processing", "closing", "ended"},
        "confirming": {"listening", "processing", "tool_calling", "closing", "ended"},
        "tool_calling": {"responding", "closing", "ended"},
        "closing": {"ended"},
        "ended": set(),
    }

    def transition(self, session: "CallSession", new_state: str) -> bool:
        """Move ``session.call_state`` to *new_state* if the move is legal.

        Args:
            session: object exposing mutable ``call_state`` and a ``call_sid``
                used in log lines.
            new_state: target state name.

        Returns:
            True if the session is now in *new_state* (including the no-op
            case where it already was); False if the move was rejected, in
            which case ``call_state`` is left unchanged.
        """
        old_state = session.call_state
        if new_state == old_state:
            return True

        # A hang-up can happen at any moment, so "ended" is always accepted.
        # Everything else must be listed for the current state; unknown
        # current states (and "ended", whose set is empty) allow nothing.
        if new_state != "ended":
            allowed = self.VALID_TRANSITIONS.get(old_state, set())
            if new_state not in allowed:
                print(f"[{session.call_sid}] Illegal state transition ignored: {old_state} -> {new_state}")
                return False

        session.call_state = new_state
        print(f"[{session.call_sid}] State: {old_state} -> {new_state}")
        return True
class LatencyTracker:
    """Collect named timestamps during a turn and report stage latencies.

    Timestamps are taken with ``time.time()``; every duration this class
    returns is in milliseconds, rounded to one decimal place.
    """

    def __init__(self, session: "CallSession"):
        self.session = session
        # mark name -> timestamp recorded by mark()
        self.marks = {}

    def mark(self, name: str):
        """Record the current time under *name*, replacing any earlier value."""
        stamp = time.time()
        self.marks[name] = stamp

    def elapsed_ms(self, start_mark: str, end_mark: Optional[str] = None) -> float:
        """Return the milliseconds between two marks.

        Returns 0.0 when *start_mark* has not been recorded. When *end_mark*
        is omitted/empty, or has not been recorded yet, the current time is
        used as the end point instead.
        """
        began = self.marks.get(start_mark)
        if began is None:
            return 0.0

        finished = time.time()
        if end_mark:
            finished = self.marks.get(end_mark, finished)

        return round((finished - began) * 1000, 1)

    def report_turn(self) -> dict:
        """Return the per-stage latency breakdown of the current turn in ms."""
        # (report key, start mark, end mark), in the order they appear in the report.
        stages = (
            ("stt_ms", "user_speech_end", "stt_complete"),
            ("llm_first_sentence_ms", "stt_complete", "llm_first_sentence"),
            ("tts_first_audio_ms", "llm_first_sentence", "tts_first_audio"),
            ("turnaround_ms", "user_speech_end", "tts_first_audio"),
        )
        return {key: self.elapsed_ms(begin, end) for key, begin, end in stages}
class SessionRegistry:
    """Registry of in-flight call sessions keyed by ``call_sid``.

    register/unregister/get serialize access through an ``asyncio.Lock``;
    active_count reads the size without taking the lock.
    """

    def __init__(self):
        self._sessions: dict[str, "CallSession"] = {}
        self._lock = asyncio.Lock()

    async def register(self, session: "CallSession"):
        """Store *session* under its ``call_sid`` (replacing any previous one) and log the total."""
        async with self._lock:
            sid = session.call_sid
            self._sessions[sid] = session
            print(f"Session registered: {sid} total={len(self._sessions)}")

    async def unregister(self, call_sid: str):
        """Remove the session for *call_sid*; ids that are not registered are ignored silently."""
        async with self._lock:
            removed = self._sessions.pop(call_sid, None)
            if not removed:
                return
            print(f"Session closed: {call_sid} remaining={len(self._sessions)}")

    async def get(self, call_sid: str) -> Optional["CallSession"]:
        """Return the session registered for *call_sid*, or None."""
        async with self._lock:
            found = self._sessions.get(call_sid)
        return found

    def active_count(self) -> int:
        """Return how many sessions are currently registered."""
        return len(self._sessions)
@dataclass
class CallSession:
    """Mutable per-call state for a single live call.

    Tracks the call's identity and tenant config, the language in use, the
    rolling conversation history, extracted lead fields, speaking/silence
    bookkeeping, and the call state managed by a ``CallStateMachine``.
    """

    call_sid: str
    tenant_id: str
    tenant_config: dict
    start_time: float
    current_language: str  # 'hindi', 'gujarati', 'english'
    # Most recent (at most 3) languages accepted by update_language().
    lang_detection_buffer: list = field(default_factory=list)
    # Message dicts ({"role", "content", ...}); add_turn() keeps the last 20.
    conversation_history: list = field(default_factory=list)
    turn_count: int = 0
    # Positive latency values passed to add_turn().
    latencies: list = field(default_factory=list)
    # Persistent lead fields; the placeholders below are overwritten by
    # update_lead_data() once real values arrive.
    lead_data: dict = field(default_factory=lambda: {
        "name": "Unknown",
        "time": "Not set",
        "service": "General Inquiry",
        "phone": "N/A",
        "notes": "",
    })
    blocked: bool = False
    caller_name: str = "Unknown"
    # Falls back to current_language in __post_init__.
    # NOTE(review): update_language() only updates current_language, so this
    # field is not kept in sync afterwards — confirm which one callers read.
    language: str = ""
    is_agent_speaking: bool = False
    agent_speaking_until: float = 0.0
    # Falls back to start_time in __post_init__ when left at 0.
    last_user_speech_at: float = 0.0
    silence_warning_sent: bool = False
    consecutive_empty_stt: int = 0
    consecutive_errors: int = 0
    call_state: str = "greeting"
    barge_in_allowed: bool = False
    filler_task: Optional[asyncio.Task] = None
    stream_sid: Optional[str] = None
    websocket: object = None
    call_context: dict = field(default_factory=dict)
    state_machine: CallStateMachine = field(default_factory=CallStateMachine)

    def __post_init__(self):
        """Fill defaults that depend on other constructor arguments."""
        if not self.language:
            self.language = self.current_language
        if not self.last_user_speech_at:
            self.last_user_speech_at = self.start_time

    def transition(self, new_state: str) -> bool:
        """Ask the state machine to move this call to *new_state*.

        Returns True if the call is now in *new_state*, False if the move
        was rejected as illegal.
        """
        return self.state_machine.transition(self, new_state)

    def update_lead_data(self, new_data: dict):
        """Merge newly extracted data into the persistent lead state.

        A string value is stored only if it is non-empty, is not "unknown"
        (case-insensitive) and is not "N/A", so placeholder values never
        overwrite data already captured. Truthy non-string values (e.g.
        True, non-zero ints) are stored as-is; other falsy values are ignored.
        """
        for key, value in new_data.items():
            if value and isinstance(value, str) and value.lower() != "unknown" and value != "N/A":
                self.lead_data[key] = value
            elif value and not isinstance(value, str):
                # Accept non-string values (e.g. booleans, ints) directly
                self.lead_data[key] = value
        print(f"[{self.call_sid}] Updated Lead State: {self.lead_data}")

    def update_language(self, detected: str, confidence: float):
        """
        Continuously update language based on detected speech.

        Never permanently locks — allows real-time language switching
        at any point in the call (Hindi → Gujarati → English etc.).
        Detections with confidence below 0.6, or with a code other than
        'gu'/'hi'/'en', are ignored. The current language switches once
        2 of the last 3 accepted detections agree on a different language.
        """
        if confidence < 0.6:
            return  # Low confidence — ignore
        lang_map = {'gu': 'gujarati', 'hi': 'hindi', 'en': 'english'}
        detected_full = lang_map.get(detected)
        if not detected_full:
            return  # Unknown language code — ignore
        self.lang_detection_buffer.append(detected_full)
        # Keep rolling window of last 3 detections
        if len(self.lang_detection_buffer) > 3:
            self.lang_detection_buffer = self.lang_detection_buffer[-3:]
        # Switch if 2 out of last 3 detections agree on a new language
        if len(self.lang_detection_buffer) >= 2:
            most_common, most_common_count = Counter(self.lang_detection_buffer).most_common(1)[0]
            if most_common_count >= 2 and most_common != self.current_language:
                prev = self.current_language
                self.current_language = most_common
                print(f"[{self.call_sid}] Language switched: {prev} → {self.current_language}")

    def add_turn(self, role: str, content: str, latency_ms: int = 0,
                 tool_call_id: Optional[str] = None, name: Optional[str] = None):
        """Append a message to the conversation history and update counters.

        Args:
            role: message role (e.g. "user", "assistant", "tool").
            content: message text.
            latency_ms: recorded in ``latencies`` only when positive.
            tool_call_id: added to the message when given (tool results).
            name: added to the message when given.

        The history is capped at the most recent 20 messages.
        """
        msg = {"role": role, "content": content}
        if tool_call_id:
            msg["tool_call_id"] = tool_call_id
        if name:
            msg["name"] = name
        self.conversation_history.append(msg)
        if latency_ms > 0:
            self.latencies.append(latency_ms)
        self.turn_count += 1
        # Keep context window small to maintain speed and save tokens
        if len(self.conversation_history) > 20:
            trimmed = self.conversation_history[-20:]
            # Slicing can cut a tool result off from the assistant message
            # carrying the matching tool_calls. OpenAI-style chat APIs reject
            # a "tool" message with no preceding tool_calls, so drop any
            # orphaned tool results left at the front of the window.
            while trimmed and isinstance(trimmed[0], dict) and trimmed[0].get("role") == "tool":
                trimmed.pop(0)
            self.conversation_history = trimmed