maya-voice-agent / src /session.py
rudyByte
fix: implement robust state barge-in transitions, fix safety reset duration check, and upgrade conversational behavior guidelines
b2ed694
from dataclasses import dataclass, field
import time
import asyncio
from collections import Counter
from typing import Optional
class CallStateMachine:
    """Guard changes to ``session.call_state`` with a fixed transition table.

    ``transition`` assigns ``session.call_state`` only when the move is allowed
    by ``VALID_TRANSITIONS`` (or the target is "ended"), prints a log line for
    every applied or rejected change, and reports the outcome as a bool.
    """

    # current state -> set of states it may move to. "ended" is terminal.
    # NOTE(review): no entry lists "confirming" as a target, so that state is
    # only reachable if call_state is assigned directly elsewhere — confirm.
    VALID_TRANSITIONS = {
        "greeting": {"listening", "processing", "closing", "ended"},
        "listening": {"processing", "closing", "ended"},
        "processing": {"responding", "tool_calling", "clarifying", "closing", "ended"},
        "responding": {"listening", "processing", "closing", "ended"},
        "clarifying": {"listening", "processing", "closing", "ended"},
        "confirming": {"listening", "processing", "tool_calling", "closing", "ended"},
        "tool_calling": {"responding", "closing", "ended"},
        "closing": {"ended"},
        "ended": set(),
    }

    def transition(self, session: "CallSession", new_state: str) -> bool:
        """Move ``session`` to ``new_state`` if the transition is legal.

        Args:
            session: object with mutable ``call_state`` and a ``call_sid``
                used in log lines.
            new_state: name of the target state.

        Returns:
            True when the session is now in ``new_state`` (including the
            silent no-op where it already was), False when the transition was
            rejected and ``call_state`` was left unchanged.
        """
        old_state = session.call_state
        if new_state == old_state:
            return True

        if new_state == "ended":
            # Hang-up may happen from any live (or even unknown) state.
            allowed = True
        else:
            # Unknown source states allow nothing; "ended" maps to an empty
            # set, so a finished call can never be revived.
            allowed = new_state in self.VALID_TRANSITIONS.get(old_state, set())

        if not allowed:
            print(f"[{session.call_sid}] Illegal state transition ignored: {old_state} -> {new_state}")
            return False

        session.call_state = new_state
        print(f"[{session.call_sid}] State: {old_state} -> {new_state}")
        return True
class LatencyTracker:
    """Collect named ``time.time()`` timestamps and derive latency figures.

    Marks live in a plain dict keyed by name; marking a name again overwrites
    its previous timestamp.
    """

    def __init__(self, session: "CallSession"):
        self.session = session  # owning session; not read by this class
        self.marks = {}

    def mark(self, name: str):
        """Store the current wall-clock time under ``name``."""
        now = time.time()
        self.marks[name] = now

    def elapsed_ms(self, start_mark: str, end_mark: Optional[str] = None) -> float:
        """Milliseconds from ``start_mark`` to ``end_mark``, rounded to 0.1.

        Returns 0.0 if ``start_mark`` was never recorded. If ``end_mark`` is
        omitted, empty, or not yet recorded, "now" is used as the end point.
        """
        began = self.marks.get(start_mark)
        if began is None:
            return 0.0
        has_recorded_end = bool(end_mark) and end_mark in self.marks
        finished = self.marks[end_mark] if has_recorded_end else time.time()
        return round((finished - began) * 1000, 1)

    def report_turn(self) -> dict:
        """Per-stage latencies (ms) for the current turn, keyed by metric name."""
        stages = (
            ("stt_ms", "user_speech_end", "stt_complete"),
            ("llm_first_sentence_ms", "stt_complete", "llm_first_sentence"),
            ("tts_first_audio_ms", "llm_first_sentence", "tts_first_audio"),
            ("turnaround_ms", "user_speech_end", "tts_first_audio"),
        )
        return {metric: self.elapsed_ms(begin, end) for metric, begin, end in stages}
class SessionRegistry:
    """In-memory index of live call sessions keyed by ``call_sid``.

    register / unregister / get serialize on an ``asyncio.Lock``;
    ``active_count`` reads the size without taking it.
    """

    def __init__(self):
        self._sessions: dict[str, CallSession] = {}
        self._lock = asyncio.Lock()

    async def register(self, session: "CallSession"):
        """Store ``session`` under its ``call_sid`` (replacing any existing one) and log the total."""
        async with self._lock:
            sid = session.call_sid
            self._sessions[sid] = session
            total = len(self._sessions)
            print(f"Session registered: {sid} total={total}")

    async def unregister(self, call_sid: str):
        """Drop the session for ``call_sid``; logs only when one was actually removed."""
        async with self._lock:
            removed = self._sessions.pop(call_sid, None)
            if not removed:
                return
            print(f"Session closed: {call_sid} remaining={len(self._sessions)}")

    async def get(self, call_sid: str) -> Optional["CallSession"]:
        """Return the session registered for ``call_sid``, or None."""
        async with self._lock:
            found = self._sessions.get(call_sid)
        return found

    def active_count(self) -> int:
        """Number of currently registered sessions."""
        count = len(self._sessions)
        return count
@dataclass
class CallSession:
    """Mutable per-call state for one voice-agent call.

    Holds call identity and tenant config, the rolling language-detection
    window, the trimmed conversation history, extracted lead fields,
    turn/silence/error counters, and the call's state machine.
    """

    call_sid: str
    tenant_id: str
    tenant_config: dict
    start_time: float
    current_language: str  # 'hindi', 'gujarati', 'english'
    lang_detection_buffer: list = field(default_factory=list)
    conversation_history: list = field(default_factory=list)
    turn_count: int = 0
    latencies: list = field(default_factory=list)
    lead_data: dict = field(default_factory=lambda: {
        "name": "Unknown",
        "time": "Not set",
        "service": "General Inquiry",
        "phone": "N/A",
        "notes": ""
    })
    blocked: bool = False
    caller_name: str = "Unknown"
    language: str = ""  # falls back to current_language in __post_init__
    is_agent_speaking: bool = False
    agent_speaking_until: float = 0.0
    last_user_speech_at: float = 0.0  # falls back to start_time in __post_init__
    silence_warning_sent: bool = False
    consecutive_empty_stt: int = 0
    consecutive_errors: int = 0
    call_state: str = "greeting"
    barge_in_allowed: bool = False
    filler_task: Optional[asyncio.Task] = None
    stream_sid: Optional[str] = None
    websocket: object = None
    call_context: dict = field(default_factory=dict)
    state_machine: CallStateMachine = field(default_factory=CallStateMachine)

    def __post_init__(self):
        """Fill derived defaults that depend on other constructor arguments."""
        if not self.language:
            self.language = self.current_language
        if not self.last_user_speech_at:
            self.last_user_speech_at = self.start_time

    def transition(self, new_state: str) -> bool:
        """Delegate a call-state change to this session's state machine."""
        return self.state_machine.transition(self, new_state)

    def update_lead_data(self, new_data: dict):
        """Merge newly extracted data into the persistent lead state.

        Only informative values overwrite existing fields, so an extraction
        that lacks a value cannot erase one captured earlier:
        - strings are kept unless blank or a placeholder ("unknown", "n/a",
          "not set"), compared case-insensitively and ignoring surrounding
          whitespace;
        - non-string values are kept only when truthy, so None, False, 0
          and empty containers are skipped.
        The merged state is printed after every call.
        """
        # "not set" mirrors this class's own default for lead_data["time"].
        placeholders = {"unknown", "n/a", "not set"}
        for key, value in new_data.items():
            if isinstance(value, str):
                normalized = value.strip().lower()
                if normalized and normalized not in placeholders:
                    self.lead_data[key] = value
            elif value:
                # Accept truthy non-string values (e.g. True, non-zero ints) directly
                self.lead_data[key] = value
        print(f"[{self.call_sid}] Updated Lead State: {self.lead_data}")

    def update_language(self, detected: str, confidence: float):
        """Feed one language detection into the rolling switch decision.

        Never permanently locks: the active language may change at any point
        in the call (Hindi -> Gujarati -> English etc.).

        Args:
            detected: short language code; only 'gu', 'hi' and 'en' are
                recognised, anything else is ignored.
            confidence: detector confidence; readings below 0.6 are ignored.

        ``current_language`` switches once at least 2 of the last 3 accepted
        detections agree on a language other than the current one.
        """
        if confidence < 0.6:
            return  # Low confidence — ignore
        lang_map = {'gu': 'gujarati', 'hi': 'hindi', 'en': 'english'}
        detected_full = lang_map.get(detected)
        if not detected_full:
            return  # Unknown language code — ignore

        # Keep a rolling window of the last 3 accepted detections.
        self.lang_detection_buffer.append(detected_full)
        if len(self.lang_detection_buffer) > 3:
            self.lang_detection_buffer = self.lang_detection_buffer[-3:]

        if len(self.lang_detection_buffer) >= 2:
            most_common, most_common_count = Counter(self.lang_detection_buffer).most_common(1)[0]
            if most_common_count >= 2 and most_common != self.current_language:
                prev = self.current_language
                self.current_language = most_common
                print(f"[{self.call_sid}] Language switched: {prev} -> {self.current_language}")

    def add_turn(self, role: str, content: str, latency_ms: int = 0,
                 tool_call_id: Optional[str] = None, name: Optional[str] = None):
        """Append one chat message to the history and update turn counters.

        Args:
            role: message role, stored verbatim.
            content: message text.
            latency_ms: appended to ``latencies`` only when positive.
            tool_call_id: added to the message under "tool_call_id" when truthy.
            name: added to the message under "name" when truthy.

        Every call increments ``turn_count``. History is trimmed to the most
        recent 20 messages to bound prompt size.
        """
        msg = {"role": role, "content": content}
        if tool_call_id:
            msg["tool_call_id"] = tool_call_id
        if name:
            msg["name"] = name
        self.conversation_history.append(msg)
        if latency_ms > 0:
            self.latencies.append(latency_ms)
        self.turn_count += 1

        # Keep context window small to maintain speed and save tokens
        if len(self.conversation_history) > 20:
            trimmed = self.conversation_history[-20:]
            # OpenAI-style chat APIs reject a 'tool' message that does not
            # follow the assistant message carrying its tool_calls; trimming
            # can cut that message off, so drop any leading orphaned results.
            while trimmed and trimmed[0].get("role") == "tool":
                trimmed.pop(0)
            self.conversation_history = trimmed