mgbam committed on
Commit
db431a4
·
verified ·
1 Parent(s): 9a4c62a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -14,19 +14,32 @@ from llm_handler import get_llm_response
14
  from tts_handler import text_to_speech_stream
15
  from tool_handler import execute_tool_call
16
 
17
- # Load environment variables
18
  load_dotenv()
19
 
 
20
  app = FastAPI()
21
 
22
- # Configuration
 
 
 
 
 
 
 
 
 
 
23
  SILENCE_THRESHOLD_SECONDS = 0.7
24
  AUDIO_RATE = 8000 # Hz for Twilio media streams
25
  AUDIO_BUFFER_SIZE = int(SILENCE_THRESHOLD_SECONDS * AUDIO_RATE)
26
 
27
- # In-memory session storage (for demonstration)
28
  sessions = {}
29
 
 
 
30
  @app.websocket("/rentbot")
31
  async def websocket_endpoint(ws: WebSocket):
32
  await ws.accept()
@@ -69,7 +82,7 @@ async def websocket_endpoint(ws: WebSocket):
69
  audio_buffer = np.append(audio_buffer, chunk_pcm)
70
 
71
  if len(audio_buffer) >= AUDIO_BUFFER_SIZE:
72
- if sessions[stream_sid]["processing_task"] and not sessions[stream_sid]["processing_task"].done():
73
  continue
74
  task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
75
  sessions[stream_sid]["processing_task"] = task
@@ -77,8 +90,8 @@ async def websocket_endpoint(ws: WebSocket):
77
 
78
  elif data['event'] == 'mark':
79
  if not stream_sid: continue
80
- if len(audio_buffer) > 1000:
81
- if not (sessions[stream_sid]["processing_task"] and not sessions[stream_sid]["processing_task"].done()):
82
  task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
83
  sessions[stream_sid]["processing_task"] = task
84
  audio_buffer = np.array([], dtype=np.int16)
@@ -90,19 +103,21 @@ async def websocket_endpoint(ws: WebSocket):
90
  except WebSocketDisconnect:
91
  print(f"WebSocket disconnected for stream {stream_sid}")
92
  except Exception as e:
93
- print(f"An error occurred: {e}")
94
  finally:
95
  if stream_sid and stream_sid in sessions:
96
- if sessions[stream_sid]["processing_task"]:
97
  sessions[stream_sid]["processing_task"].cancel()
98
  del sessions[stream_sid]
99
  print(f"Session cleaned up for stream {stream_sid}")
100
 
101
 
 
102
  async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.ndarray):
103
  """The main logic loop: STT -> LLM -> (Tool/TTS)"""
104
  print(f"[{stream_sid}] Processing audio chunk of size {len(audio_chunk)}...")
105
 
 
106
  user_text = await transcribe_audio_chunk(audio_chunk)
107
  if not user_text:
108
  print(f"[{stream_sid}] No text transcribed.")
@@ -111,26 +126,28 @@ async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.nda
111
  print(f"[{stream_sid}] User said: {user_text}")
112
  sessions[stream_sid]["messages"].append({"role": "user", "content": user_text})
113
 
 
114
  tts_queue = asyncio.Queue()
115
- async def llm_chunk_handler(chunk):
116
- await tts_queue.put(chunk)
117
-
118
  async def tts_text_iterator():
119
  while True:
120
  chunk = await tts_queue.get()
121
  if chunk is None: break
122
  yield chunk
123
 
 
124
  llm_task = asyncio.create_task(get_llm_response(sessions[stream_sid]["messages"], llm_chunk_handler))
125
  tts_task = asyncio.create_task(stream_and_send_audio(ws, stream_sid, tts_text_iterator()))
126
 
 
127
  assistant_message, tool_calls = await llm_task
128
- await tts_queue.put(None)
129
- await tts_task
130
 
131
  if assistant_message and assistant_message.get("content"):
132
  sessions[stream_sid]["messages"].append(assistant_message)
133
 
 
134
  if tool_calls:
135
  sessions[stream_sid]["messages"].append(assistant_message)
136
 
@@ -144,6 +161,7 @@ async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.nda
144
  tool_result_message = execute_tool_call(tool_call)
145
  sessions[stream_sid]["messages"].append(tool_result_message)
146
 
 
147
  final_tts_queue = asyncio.Queue()
148
  async def final_llm_chunk_handler(chunk): await final_tts_queue.put(chunk)
149
  async def final_tts_iterator():
@@ -178,9 +196,10 @@ async def stream_and_send_audio(ws: WebSocket, stream_sid: str, text_iterator):
178
  print(f"[{stream_sid}] Finished sending bot's audio turn.")
179
 
180
 
 
181
  if __name__ == "__main__":
182
  import uvicorn
183
  # Hugging Face Spaces expects the app to run on port 7860
184
  port = int(os.environ.get("PORT", 7860))
185
- print(f"Starting RentBot server on port {port}...")
186
  uvicorn.run(app, host="0.0.0.0", port=port)
 
14
  from tts_handler import text_to_speech_stream
15
  from tool_handler import execute_tool_call
16
 
17
+ # Load environment variables from .env file
18
  load_dotenv()
19
 
20
+ # Initialize FastAPI application
21
  app = FastAPI()
22
 
23
+ # --- Add a root endpoint for health checks and basic info ---
24
+ @app.get("/")
25
+ async def root():
26
+ """
27
+ A simple GET endpoint to confirm the server is running and provide info.
28
+ This is what you see when you visit the Hugging Face Space URL in a browser.
29
+ """
30
+ return {"status": "running", "message": "RentBot is active. Connect via WebSocket at the /rentbot endpoint."}
31
+
32
+
33
+ # --- Global Configuration ---
34
  SILENCE_THRESHOLD_SECONDS = 0.7
35
  AUDIO_RATE = 8000 # Hz for Twilio media streams
36
  AUDIO_BUFFER_SIZE = int(SILENCE_THRESHOLD_SECONDS * AUDIO_RATE)
37
 
38
+ # In-memory session storage (for demonstration). In production, use Redis or a database.
39
  sessions = {}
40
 
41
+
42
+ # --- Main WebSocket Endpoint for Twilio ---
43
  @app.websocket("/rentbot")
44
  async def websocket_endpoint(ws: WebSocket):
45
  await ws.accept()
 
82
  audio_buffer = np.append(audio_buffer, chunk_pcm)
83
 
84
  if len(audio_buffer) >= AUDIO_BUFFER_SIZE:
85
+ if sessions[stream_sid].get("processing_task") and not sessions[stream_sid]["processing_task"].done():
86
  continue
87
  task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
88
  sessions[stream_sid]["processing_task"] = task
 
90
 
91
  elif data['event'] == 'mark':
92
  if not stream_sid: continue
93
+ if len(audio_buffer) > 1000: # Heuristic to process leftover audio on pause
94
+ if not (sessions[stream_sid].get("processing_task") and not sessions[stream_sid]["processing_task"].done()):
95
  task = asyncio.create_task(process_user_audio(ws, stream_sid, audio_buffer))
96
  sessions[stream_sid]["processing_task"] = task
97
  audio_buffer = np.array([], dtype=np.int16)
 
103
  except WebSocketDisconnect:
104
  print(f"WebSocket disconnected for stream {stream_sid}")
105
  except Exception as e:
106
+ print(f"An error occurred in websocket_endpoint: {e}")
107
  finally:
108
  if stream_sid and stream_sid in sessions:
109
+ if sessions[stream_sid].get("processing_task"):
110
  sessions[stream_sid]["processing_task"].cancel()
111
  del sessions[stream_sid]
112
  print(f"Session cleaned up for stream {stream_sid}")
113
 
114
 
115
+ # --- Core Logic Functions ---
116
  async def process_user_audio(ws: WebSocket, stream_sid: str, audio_chunk: np.ndarray):
117
  """The main logic loop: STT -> LLM -> (Tool/TTS)"""
118
  print(f"[{stream_sid}] Processing audio chunk of size {len(audio_chunk)}...")
119
 
120
+ # 1. Speech-to-Text
121
  user_text = await transcribe_audio_chunk(audio_chunk)
122
  if not user_text:
123
  print(f"[{stream_sid}] No text transcribed.")
 
126
  print(f"[{stream_sid}] User said: {user_text}")
127
  sessions[stream_sid]["messages"].append({"role": "user", "content": user_text})
128
 
129
+ # Queue to pass text from LLM to TTS
130
  tts_queue = asyncio.Queue()
131
+ async def llm_chunk_handler(chunk): await tts_queue.put(chunk)
 
 
132
  async def tts_text_iterator():
133
  while True:
134
  chunk = await tts_queue.get()
135
  if chunk is None: break
136
  yield chunk
137
 
138
+ # 2. Start LLM and TTS tasks concurrently for low latency
139
  llm_task = asyncio.create_task(get_llm_response(sessions[stream_sid]["messages"], llm_chunk_handler))
140
  tts_task = asyncio.create_task(stream_and_send_audio(ws, stream_sid, tts_text_iterator()))
141
 
142
+ # Wait for LLM to finish and get final message object
143
  assistant_message, tool_calls = await llm_task
144
+ await tts_queue.put(None) # Signal TTS to end
145
+ await tts_task # Wait for TTS to finish sending audio
146
 
147
  if assistant_message and assistant_message.get("content"):
148
  sessions[stream_sid]["messages"].append(assistant_message)
149
 
150
+ # 3. Handle Tool Calls if any
151
  if tool_calls:
152
  sessions[stream_sid]["messages"].append(assistant_message)
153
 
 
161
  tool_result_message = execute_tool_call(tool_call)
162
  sessions[stream_sid]["messages"].append(tool_result_message)
163
 
164
+ # 4. Get a final response from the LLM after executing the tool
165
  final_tts_queue = asyncio.Queue()
166
  async def final_llm_chunk_handler(chunk): await final_tts_queue.put(chunk)
167
  async def final_tts_iterator():
 
196
  print(f"[{stream_sid}] Finished sending bot's audio turn.")
197
 
198
 
199
+ # --- Application Entry Point ---
200
  if __name__ == "__main__":
201
  import uvicorn
202
  # Hugging Face Spaces expects the app to run on port 7860
203
  port = int(os.environ.get("PORT", 7860))
204
+ print(f"Starting RentBot server on host 0.0.0.0 and port {port}...")
205
  uvicorn.run(app, host="0.0.0.0", port=port)