File size: 7,303 Bytes
4cf98c9
 
12ac7b8
4cf98c9
12ac7b8
 
 
 
 
b2ed694
12ac7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cf98c9
 
 
 
 
 
 
5138295
4cf98c9
5138295
4cf98c9
 
ec821d3
 
 
 
 
 
 
4cf98c9
12ac7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5138295
ec821d3
 
 
0db2958
 
 
 
ec821d3
 
 
4cf98c9
ec821d3
5138295
 
 
 
 
 
 
b061af2
4cf98c9
b061af2
 
 
 
4cf98c9
b061af2
5138295
 
 
 
 
 
4cf98c9
b061af2
 
5138295
4cf98c9
5138295
 
2e4c13e
 
 
 
 
 
4cf98c9
 
 
5138295
4cf98c9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from dataclasses import dataclass, field
import time
import asyncio
from collections import Counter
from typing import Optional


class CallStateMachine:
    """Validate and apply call-state transitions on a ``CallSession``.

    ``VALID_TRANSITIONS`` maps each state to the set of states it may move to.
    Moving to ``"ended"`` is permitted from any state that is not already
    ``"ended"``; ``"ended"`` itself is terminal (its allowed set is empty), so
    every request to leave it is rejected.
    """

    VALID_TRANSITIONS = {
        "greeting": {"listening", "processing", "closing", "ended"},
        "listening": {"processing", "closing", "ended"},
        "processing": {"responding", "tool_calling", "clarifying", "closing", "ended"},
        "responding": {"listening", "processing", "closing", "ended"},
        "clarifying": {"listening", "processing", "closing", "ended"},
        "confirming": {"listening", "processing", "tool_calling", "closing", "ended"},
        "tool_calling": {"responding", "closing", "ended"},
        "closing": {"ended"},
        "ended": set(),
    }

    def transition(self, session: "CallSession", new_state: str) -> bool:
        """Move ``session`` to ``new_state`` if the transition is allowed.

        Args:
            session: object whose ``call_state`` is read and updated and whose
                ``call_sid`` is used in log lines.
            new_state: the requested state name.

        Returns:
            True if the session is now in ``new_state`` (including the no-op
            case where it already was); False if the request was rejected, in
            which case ``call_state`` is left unchanged.
        """
        old_state = session.call_state
        if new_state == old_state:
            return True

        if old_state == "ended":
            # Terminal state: an ended call must never be resurrected. Rejected
            # without a log line to avoid noise from requests arriving after
            # the call has already ended.
            return False

        # "ended" is reachable from every non-terminal state; everything else
        # must be listed in the transition table.
        if new_state != "ended":
            allowed = self.VALID_TRANSITIONS.get(old_state, set())
            if new_state not in allowed:
                print(f"[{session.call_sid}] Illegal state transition ignored: {old_state} -> {new_state}")
                return False

        session.call_state = new_state
        print(f"[{session.call_sid}] State: {old_state} -> {new_state}")
        return True


class LatencyTracker:
    """Record named ``time.time()`` stamps and report per-turn latencies in ms."""

    def __init__(self, session: "CallSession"):
        # Owning session; kept as an attribute, not read by this class.
        self.session = session
        # Mark name -> wall-clock timestamp captured by mark().
        self.marks = {}

    def mark(self, name: str):
        """Stamp ``name`` with the current wall-clock time, replacing any earlier stamp."""
        now = time.time()
        self.marks[name] = now

    def elapsed_ms(self, start_mark: str, end_mark: Optional[str] = None) -> float:
        """Return milliseconds between two marks, rounded to 0.1 ms.

        Returns 0.0 when ``start_mark`` has not been recorded. When ``end_mark``
        is omitted or has not been recorded, the interval runs up to now.
        """
        began = self.marks.get(start_mark)
        if began is None:
            return 0.0

        if end_mark and end_mark in self.marks:
            finished = self.marks[end_mark]
        else:
            finished = time.time()

        return round((finished - began) * 1000, 1)

    def report_turn(self) -> dict:
        """Summarise the stage latencies of the current turn as a dict of ms values."""
        spans = (
            ("stt_ms", "user_speech_end", "stt_complete"),
            ("llm_first_sentence_ms", "stt_complete", "llm_first_sentence"),
            ("tts_first_audio_ms", "llm_first_sentence", "tts_first_audio"),
            ("turnaround_ms", "user_speech_end", "tts_first_audio"),
        )
        return {label: self.elapsed_ms(start, end) for label, start, end in spans}


class SessionRegistry:
    """Map of call SID -> CallSession, with mutations and lookups guarded by an asyncio.Lock."""

    def __init__(self):
        self._sessions: dict[str, "CallSession"] = {}
        # Serialises register / unregister / get.
        self._lock = asyncio.Lock()

    async def register(self, session: "CallSession"):
        """Store ``session`` under its ``call_sid`` (replacing any previous entry) and log the total."""
        sid = session.call_sid
        async with self._lock:
            self._sessions[sid] = session
            total = len(self._sessions)
            print(f"Session registered: {sid} total={total}")

    async def unregister(self, call_sid: str):
        """Remove ``call_sid`` if present; logs only when a session was actually removed."""
        async with self._lock:
            removed = self._sessions.pop(call_sid, None)
            if not removed:
                return
            print(f"Session closed: {call_sid} remaining={len(self._sessions)}")

    async def get(self, call_sid: str) -> Optional["CallSession"]:
        """Return the session registered under ``call_sid``, or None."""
        async with self._lock:
            found = self._sessions.get(call_sid)
        return found

    def active_count(self) -> int:
        """Number of registered sessions (read without taking the lock)."""
        return len(self._sessions)

@dataclass
class CallSession:
    """Mutable per-call state for one live call.

    Holds tenant configuration, the chat history, extracted lead data,
    language-detection state and the current call state (changed through
    ``transition()``, which delegates to ``state_machine``).
    """

    call_sid: str
    tenant_id: str
    tenant_config: dict
    start_time: float
    current_language: str           # 'hindi', 'gujarati', 'english'
    # Rolling window (at most 3) of recent confident language detections.
    lang_detection_buffer: list = field(default_factory=list)
    # Chat messages ({"role": ..., "content": ..., ...}); capped by add_turn().
    conversation_history: list = field(default_factory=list)
    turn_count: int = 0
    # Positive per-turn latencies (ms) passed to add_turn().
    latencies: list = field(default_factory=list)
    lead_data: dict = field(default_factory=lambda: {
        "name": "Unknown",
        "time": "Not set",
        "service": "General Inquiry",
        "phone": "N/A",
        "notes": ""
    })
    blocked: bool = False
    caller_name: str = "Unknown"
    # Falls back to current_language in __post_init__ when left empty.
    language: str = ""
    is_agent_speaking: bool = False
    agent_speaking_until: float = 0.0
    # Falls back to start_time in __post_init__ when left at 0.
    last_user_speech_at: float = 0.0
    silence_warning_sent: bool = False
    consecutive_empty_stt: int = 0
    consecutive_errors: int = 0
    call_state: str = "greeting"
    barge_in_allowed: bool = False
    filler_task: Optional[asyncio.Task] = None
    stream_sid: Optional[str] = None
    websocket: object = None
    call_context: dict = field(default_factory=dict)
    state_machine: CallStateMachine = field(default_factory=CallStateMachine)

    def __post_init__(self):
        """Fill derived defaults that depend on other fields."""
        if not self.language:
            self.language = self.current_language
        if not self.last_user_speech_at:
            self.last_user_speech_at = self.start_time

    def transition(self, new_state: str) -> bool:
        """Request a call-state change; returns whether it was applied."""
        return self.state_machine.transition(self, new_state)

    def update_lead_data(self, new_data: dict):
        """Merge newly extracted data into the persistent lead state.

        A value overwrites the stored one only if it carries information:

        * ``None`` is skipped.
        * Strings are skipped when empty/whitespace-only or when they equal a
          placeholder ("unknown", "n/a"), compared case-insensitively and
          ignoring surrounding whitespace.
        * Booleans and numbers are always stored, so an explicit ``False`` or
          ``0`` can replace an earlier value.
        * Any other value is stored only when truthy.
        """
        placeholders = {"unknown", "n/a"}
        for key, value in new_data.items():
            if value is None:
                continue
            if isinstance(value, str):
                cleaned = value.strip()
                if not cleaned or cleaned.lower() in placeholders:
                    continue
            elif not isinstance(value, (bool, int, float)) and not value:
                # e.g. empty list/dict: nothing extracted
                continue
            self.lead_data[key] = value
        print(f"[{self.call_sid}] Updated Lead State: {self.lead_data}")

    def update_language(self, detected: str, confidence: float):
        """
        Continuously update language based on detected speech.
        Never permanently locks — allows real-time language switching
        at any point in the call (Hindi → Gujarati → English etc.)

        Args:
            detected: short language code; only 'gu', 'hi' and 'en' are used.
            confidence: detections below 0.6 are ignored.

        Switches ``current_language`` once at least 2 of the last (up to) 3
        accepted detections agree on a language different from the current one.
        """
        if confidence < 0.6:
            return  # Low confidence — ignore

        lang_map = {'gu': 'gujarati', 'hi': 'hindi', 'en': 'english'}
        detected_full = lang_map.get(detected)
        if not detected_full:
            return  # Unknown language code — ignore

        self.lang_detection_buffer.append(detected_full)

        # Keep rolling window of last 3 detections
        if len(self.lang_detection_buffer) > 3:
            self.lang_detection_buffer = self.lang_detection_buffer[-3:]

        # Switch if 2 out of last 3 detections agree on a new language
        if len(self.lang_detection_buffer) >= 2:
            most_common, most_common_count = Counter(self.lang_detection_buffer).most_common(1)[0]
            if most_common_count >= 2 and most_common != self.current_language:
                prev = self.current_language
                self.current_language = most_common
                print(f"[{self.call_sid}] Language switched: {prev} -> {self.current_language}")

    def add_turn(self, role: str, content: str, latency_ms: int = 0,
                 tool_call_id: Optional[str] = None, name: Optional[str] = None):
        """Append one chat message to the history and update turn counters.

        Args:
            role: message role (e.g. "user", "assistant", "tool").
            content: message text.
            latency_ms: turn latency; recorded in ``latencies`` only when > 0.
            tool_call_id: set as ``msg["tool_call_id"]`` when truthy.
            name: set as ``msg["name"]`` when truthy.

        The history is capped at the most recent 20 messages. After trimming,
        leading ``"tool"`` messages are dropped as well, because the assistant
        message that requested them was cut off. The ``tool_call_id`` field
        suggests OpenAI-style chat messages, and such APIs reject a tool result
        that does not follow its requesting message.
        """
        msg = {"role": role, "content": content}
        if tool_call_id:
            msg["tool_call_id"] = tool_call_id
        if name:
            msg["name"] = name

        self.conversation_history.append(msg)
        if latency_ms > 0:
            self.latencies.append(latency_ms)
        self.turn_count += 1

        # Keep context window small to maintain speed and save tokens
        if len(self.conversation_history) > 20:
            recent = self.conversation_history[-20:]
            # Skip orphaned tool results at the head of the window. Non-dict
            # entries (if any are appended elsewhere) are left alone.
            drop = 0
            while (
                drop < len(recent)
                and isinstance(recent[drop], dict)
                and recent[drop].get("role") == "tool"
            ):
                drop += 1
            self.conversation_history = recent[drop:]