File size: 8,562 Bytes
0790420
cd82c0d
0790420
 
 
 
 
 
 
 
 
 
 
 
 
 
db431a4
0790420
 
db431a4
0790420
 
db431a4
 
 
 
 
 
 
 
 
 
 
0790420
 
 
 
db431a4
0790420
 
db431a4
 
0790420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db431a4
0790420
 
 
 
 
 
 
db431a4
 
0790420
 
 
 
 
 
 
 
 
 
 
db431a4
0790420
 
db431a4
0790420
 
 
 
9a4c62a
db431a4
cd82c0d
 
 
 
db431a4
cd82c0d
 
 
 
 
 
 
 
db431a4
34600da
db431a4
34600da
 
 
0790420
34600da
cd82c0d
db431a4
0790420
 
cd82c0d
db431a4
0790420
db431a4
 
cd82c0d
 
 
 
db431a4
cd82c0d
0790420
 
34600da
 
0790420
 
 
34600da
cd82c0d
 
 
 
db431a4
34600da
0790420
34600da
 
 
0790420
34600da
 
0790420
 
cd82c0d
34600da
 
 
cd82c0d
 
0790420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db431a4
0790420
 
9a4c62a
 
db431a4
9a4c62a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# rentbot/app.py

import os
import base64
import json
import asyncio
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv

from audio_utils import ulaw_to_pcm16
from stt_handler import transcribe_audio_chunk
from llm_handler import get_llm_response
from tts_handler import text_to_speech_stream
from tool_handler import execute_tool_call

# Load environment variables from .env file
load_dotenv()

# Initialize FastAPI application
app = FastAPI()

# --- Add a root endpoint for health checks and basic info ---
@app.get("/")
async def root():
    """Health-check endpoint for the Hugging Face Space.

    This is what a browser sees when visiting the Space URL; it confirms the
    server is up and points callers at the /rentbot WebSocket endpoint.
    """
    return {
        "status": "running",
        "message": "RentBot is active. Connect via WebSocket at the /rentbot endpoint.",
    }


# --- Global Configuration ---
# Length of buffered audio (in seconds) treated as one complete user
# utterance before it is dispatched to the STT -> LLM -> TTS pipeline.
SILENCE_THRESHOLD_SECONDS = 0.7
AUDIO_RATE = 8000  # Hz for Twilio media streams
# Number of PCM16 samples equal to SILENCE_THRESHOLD_SECONDS of audio.
AUDIO_BUFFER_SIZE = int(SILENCE_THRESHOLD_SECONDS * AUDIO_RATE)

# In-memory session storage (for demonstration). In production, use Redis or a database.
# Keyed by Twilio streamSid -> {"messages": [...], "processing_task": Task | None}.
sessions = {}


# --- Main WebSocket Endpoint for Twilio ---
@app.websocket("/rentbot")
async def websocket_endpoint(ws: WebSocket):
    """Handle one Twilio Media Streams WebSocket connection.

    Twilio sends JSON messages whose 'event' key is one of:
      * 'start' -- call began: create a session and speak an initial greeting.
      * 'media' -- base64 u-law audio frame: buffer it; once enough samples
                   accumulate, dispatch them to process_user_audio().
      * 'mark'  -- playback marker: used here to flush leftover buffered audio.
      * 'stop'  -- call ended: break out and clean up the session.
    """
    await ws.accept()
    stream_sid = None
    # Collect incoming PCM16 chunks in a list and concatenate only when
    # dispatching. (np.append per frame re-copies the whole buffer each time,
    # which is quadratic over the life of a call.)
    pcm_chunks = []
    buffered_samples = 0

    def _drain_buffer() -> np.ndarray:
        """Concatenate all pending chunks and reset the buffer."""
        nonlocal buffered_samples
        audio = np.concatenate(pcm_chunks) if pcm_chunks else np.array([], dtype=np.int16)
        pcm_chunks.clear()
        buffered_samples = 0
        return audio

    def _busy(session) -> bool:
        """True while a previous processing task for this session is still running."""
        task = session.get("processing_task")
        return bool(task) and not task.done()

    try:
        async for message in ws.iter_text():
            data = json.loads(message)
            event = data.get('event')

            if event == 'start':
                stream_sid = data['start']['streamSid']
                sessions[stream_sid] = {
                    "messages": [{"role": "system", "content": os.getenv("SYSTEM_PROMPT")}],
                    "processing_task": None,
                }
                print(f"New stream started: {stream_sid}")

                # Send an initial greeting
                initial_greeting = "Hi! I'm RentBot, your leasing assistant. How can I help you today?"
                sessions[stream_sid]["messages"].append({"role": "assistant", "content": initial_greeting})

                async def send_initial_greeting(sid=stream_sid, text=initial_greeting):
                    tts_iterator = text_to_speech_stream(iter([text]))
                    async for audio_chunk in tts_iterator:
                        payload = base64.b64encode(audio_chunk).decode('utf-8')
                        await ws.send_json({
                            "event": "media",
                            "streamSid": sid,
                            "media": {"payload": payload}
                        })
                    await ws.send_json({"event": "mark", "streamSid": sid, "mark": {"name": "bot_turn_end"}})

                # Hold a reference to the task: asyncio only keeps weak refs to
                # tasks, so an un-referenced task may be garbage-collected
                # before it finishes.
                sessions[stream_sid]["greeting_task"] = asyncio.create_task(send_initial_greeting())

            elif event == 'media':
                session = sessions.get(stream_sid)
                if session is None:
                    continue  # media before 'start' or after cleanup
                chunk_ulaw = base64.b64decode(data['media']['payload'])
                chunk_pcm = np.asarray(ulaw_to_pcm16(chunk_ulaw), dtype=np.int16)
                pcm_chunks.append(chunk_pcm)
                buffered_samples += len(chunk_pcm)

                if buffered_samples >= AUDIO_BUFFER_SIZE:
                    if _busy(session):
                        continue  # previous turn still in flight; keep buffering
                    session["processing_task"] = asyncio.create_task(
                        process_user_audio(ws, stream_sid, _drain_buffer())
                    )

            elif event == 'mark':
                session = sessions.get(stream_sid)
                if session is None:
                    continue
                # Heuristic to process leftover audio on pause
                if buffered_samples > 1000 and not _busy(session):
                    session["processing_task"] = asyncio.create_task(
                        process_user_audio(ws, stream_sid, _drain_buffer())
                    )

            elif event == 'stop':
                print(f"Stream stopped: {stream_sid}")
                break

    except WebSocketDisconnect:
        print(f"WebSocket disconnected for stream {stream_sid}")
    except Exception as e:
        print(f"An error occurred in websocket_endpoint: {e}")
    finally:
        # Cancel any in-flight turn and drop the session record.
        session = sessions.pop(stream_sid, None) if stream_sid else None
        if session and session.get("processing_task"):
            session["processing_task"].cancel()
        print(f"Session cleaned up for stream {stream_sid}")


# --- Core Logic Functions ---
async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.ndarray):
    """The main logic loop for one user turn: STT -> LLM -> (Tool/TTS).

    Transcribes the buffered audio, streams an LLM reply through TTS back to
    the caller, and -- when the LLM requests tool calls -- executes them and
    streams a follow-up LLM response.
    """
    print(f"[{stream_sid}] Processing audio chunk of size {len(audio_chunk)}...")

    # 1. Speech-to-Text
    user_text = await transcribe_audio_chunk(audio_chunk)
    if not user_text:
        print(f"[{stream_sid}] No text transcribed.")
        return

    print(f"[{stream_sid}] User said: {user_text}")
    sessions[stream_sid]["messages"].append({"role": "user", "content": user_text})

    # Queue bridging LLM text chunks into the TTS stream as they arrive.
    tts_queue = asyncio.Queue()

    async def llm_chunk_handler(chunk):
        await tts_queue.put(chunk)

    async def tts_text_iterator():
        # Yield chunks until the None sentinel marks the end of the LLM turn.
        while True:
            chunk = await tts_queue.get()
            if chunk is None:
                break
            yield chunk

    # 2. Start LLM and TTS tasks concurrently for low latency
    llm_task = asyncio.create_task(get_llm_response(sessions[stream_sid]["messages"], llm_chunk_handler))
    tts_task = asyncio.create_task(stream_and_send_audio(ws, stream_sid, tts_text_iterator()))

    # Wait for LLM to finish and get final message object
    assistant_message, tool_calls = await llm_task
    await tts_queue.put(None)  # Signal TTS to end
    await tts_task  # Wait for TTS to finish sending audio

    # Record the assistant turn exactly once. (The original code appended the
    # same message a second time inside the tool-call branch, duplicating
    # conversation history.) When tool calls are present the message must be
    # kept even without text content so the follow-up LLM call sees the
    # tool-call context.
    if assistant_message and (assistant_message.get("content") or tool_calls):
        sessions[stream_sid]["messages"].append(assistant_message)

    # 3. Handle Tool Calls if any
    if tool_calls:
        for tool_call_data in tool_calls:
            # Wrap the raw dicts in attribute-style objects because
            # execute_tool_call accesses .id and .function.name.
            tool_call = type('ToolCall', (), {
                'id': tool_call_data.get('id'),
                'function': type('Function', (), tool_call_data.get('function'))
            })()

            print(f"[{stream_sid}] Executing tool: {tool_call.function.name}")
            # NOTE(review): execute_tool_call is synchronous; if it performs
            # slow I/O it will block the event loop -- confirm it is cheap.
            tool_result_message = execute_tool_call(tool_call)
            sessions[stream_sid]["messages"].append(tool_result_message)

        # 4. Get a final response from the LLM after executing the tool
        final_tts_queue = asyncio.Queue()

        async def final_llm_chunk_handler(chunk):
            await final_tts_queue.put(chunk)

        async def final_tts_iterator():
            while True:
                chunk = await final_tts_queue.get()
                if chunk is None:
                    break
                yield chunk

        final_llm_task = asyncio.create_task(get_llm_response(sessions[stream_sid]["messages"], final_llm_chunk_handler))
        final_tts_task = asyncio.create_task(stream_and_send_audio(ws, stream_sid, final_tts_iterator()))

        final_assistant_message, _ = await final_llm_task
        await final_tts_queue.put(None)
        await final_tts_task

        if final_assistant_message:
            sessions[stream_sid]["messages"].append(final_assistant_message)


async def stream_and_send_audio(ws: WebSocket, stream_sid: str, text_iterator):
    """Pipe a text stream through TTS and relay the audio frames to Twilio.

    Sends one 'media' message per non-empty audio chunk, then a single
    'mark' message ("bot_turn_end") so the caller knows the bot's turn is done.
    """
    async for chunk in text_to_speech_stream(text_iterator):
        if not chunk:
            continue  # skip empty frames from the TTS stream
        await ws.send_json({
            "event": "media",
            "streamSid": stream_sid,
            "media": {"payload": base64.b64encode(chunk).decode('utf-8')},
        })

    await ws.send_json({"event": "mark", "streamSid": stream_sid, "mark": {"name": "bot_turn_end"}})
    print(f"[{stream_sid}] Finished sending bot's audio turn.")


# --- Application Entry Point ---
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces expects the app to run on port 7860; PORT may
    # override it in other deployments.
    listen_port = int(os.environ.get("PORT", 7860))
    print(f"Starting RentBot server on host 0.0.0.0 and port {listen_port}...")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)