| import asyncio |
| import base64 |
| import json |
| import os |
| import openai |
| from collections import deque |
| from typing import Dict |
|
|
| import dotenv |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.responses import HTMLResponse |
| from twilio.rest import Client |
| from twilio.twiml.voice_response import Connect, VoiceResponse |
|
|
| from logger_config import get_logger |
| from services.call_context import CallContext |
| from services.llm_service import LLMFactory |
| from services.stream_service import StreamService |
| from services.transcription_service import TranscriptionService |
| from services.tts_service import TTSFactory |
| from openai import Client |
|
|
| dotenv.load_dotenv() |
| app = FastAPI() |
| logger = get_logger("App") |
|
|
| |
| global call_contexts |
| call_contexts = {} |
|
|
| def get_server_url(): |
| """Get the appropriate server URL based on the environment""" |
| return "fleet-loved-tahr.ngrok-free.app" |
|
|
| def get_twilio_client(): |
| """Get authenticated Twilio client""" |
| return Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN")) |
|
|
| @app.post("/incoming") |
| async def incoming_call() -> HTMLResponse: |
| """Handle incoming Twilio calls""" |
| server = get_server_url() |
| response = VoiceResponse() |
| connect = Connect() |
| connect.stream(url=f"wss://{server}/connection") |
| response.append(connect) |
| return HTMLResponse(content=str(response), status_code=200) |
|
|
| @app.get("/call_recording/{call_sid}") |
| async def get_call_recording(call_sid: str): |
| """Get the recording URL for a specific call.""" |
| recording = get_twilio_client().calls(call_sid).recordings.list() |
| if recording: |
| return {"recording_url": f"https://api.twilio.com/{recording[0].uri}"} |
| return {"error": "Recording not found"} |
|
|
| @app.websocket("/connection") |
| async def websocket_endpoint(websocket: WebSocket): |
| """Handle WebSocket connection for media streaming""" |
| await websocket.accept() |
|
|
| llm_service_name = os.getenv("LLM_SERVICE", "openai") |
| tts_service_name = os.getenv("TTS_SERVICE", "deepgram") |
|
|
| logger.info(f"Using LLM service: {llm_service_name}") |
| logger.info(f"Using TTS service: {tts_service_name}") |
|
|
| llm_service = LLMFactory.get_llm_service(llm_service_name, CallContext()) |
| stream_service = StreamService(websocket) |
| transcription_service = TranscriptionService() |
| tts_service = TTSFactory.get_tts_service(tts_service_name) |
| |
| marks = deque() |
| interaction_count = 0 |
|
|
| await transcription_service.connect() |
|
|
| async def process_media(msg): |
| await transcription_service.send(base64.b64decode(msg['media']['payload'])) |
|
|
| async def handle_transcription(text): |
| nonlocal interaction_count |
| if not text: |
| return |
| logger.info(f"Interaction {interaction_count} – STT -> LLM: {text}") |
| await llm_service.completion(text, interaction_count) |
| interaction_count += 1 |
|
|
| async def handle_llm_reply(llm_reply, icount): |
| logger.info(f"Interaction {icount}: LLM -> TTS: {llm_reply['partialResponse']}") |
| await tts_service.generate(llm_reply, icount) |
|
|
| async def handle_speech(response_index, audio, label, icount): |
| logger.info(f"Interaction {icount}: TTS -> TWILIO: {label}") |
| await stream_service.buffer(response_index, audio) |
|
|
| async def handle_audio_sent(mark_label): |
| marks.append(mark_label) |
|
|
| async def handle_utterance(text, stream_sid): |
| try: |
| if len(marks) > 0 and text.strip(): |
| logger.info("Interruption detected, clearing system.") |
| await websocket.send_json({ |
| "streamSid": stream_sid, |
| "event": "clear" |
| }) |
| |
| |
| stream_service.reset() |
| llm_service.reset() |
| except Exception as e: |
| logger.error(f"Error while handling utterance: {e}") |
| raise |
|
|
| transcription_service.on('utterance', handle_utterance) |
| transcription_service.on('transcription', handle_transcription) |
| llm_service.on('llmreply', handle_llm_reply) |
| tts_service.on('speech', handle_speech) |
| stream_service.on('audiosent', handle_audio_sent) |
|
|
| |
| message_queue = asyncio.Queue() |
|
|
| async def websocket_listener(): |
| try: |
| while True: |
| data = await websocket.receive_text() |
| await message_queue.put(json.loads(data)) |
| except WebSocketDisconnect: |
| logger.info("WebSocket disconnected") |
|
|
| async def message_processor(): |
| while True: |
| msg = await message_queue.get() |
| if msg['event'] == 'start': |
| stream_sid = msg['start']['streamSid'] |
| call_sid = msg['start']['callSid'] |
|
|
| call_context = CallContext() |
|
|
| if os.getenv("RECORD_CALLS") == "true": |
| get_twilio_client().calls(call_sid).recordings.create({"recordingChannels": "dual"}) |
|
|
| |
| if call_sid not in call_contexts: |
| |
| call_context.system_message = os.environ.get("SYSTEM_MESSAGE") |
| call_context.initial_message = os.environ.get("INITIAL_MESSAGE") |
| call_context.call_sid = call_sid |
| call_contexts[call_sid] = call_context |
| else: |
| |
| call_context = call_contexts[call_sid] |
| |
| llm_service.set_call_context(call_context) |
| stream_service.set_stream_sid(stream_sid) |
| transcription_service.set_stream_sid(stream_sid) |
|
|
| logger.info(f"Twilio -> Starting Media Stream for {stream_sid}") |
| await tts_service.generate({ |
| "partialResponseIndex": None, |
| "partialResponse": call_context.initial_message |
| }, 1) |
| elif msg['event'] == 'media': |
| asyncio.create_task(process_media(msg)) |
| elif msg['event'] == 'mark': |
| label = msg['mark']['name'] |
| if label in marks: |
| marks.remove(label) |
| elif msg['event'] == 'stop': |
| logger.info("Media stream ended.") |
| break |
| message_queue.task_done() |
|
|
| try: |
| listener_task = asyncio.create_task(websocket_listener()) |
| processor_task = asyncio.create_task(message_processor()) |
| await asyncio.gather(listener_task, processor_task) |
| except asyncio.CancelledError: |
| logger.info("Tasks cancelled") |
| finally: |
| await transcription_service.disconnect() |
|
|
| @app.post("/start_call") |
| async def start_call(request: Dict[str, str]): |
| """Initiate a call using Twilio with optional system and initial messages.""" |
| to_number = request.get("to_number") |
| system_message = request.get("system_message") |
| initial_message = request.get("initial_message") |
| logger.info(f"Initiating call to {to_number}") |
|
|
| server = get_server_url() |
| service_url = f"https://{server}/incoming" |
|
|
| if not to_number: |
| return {"error": "Missing 'to_number' in request"} |
|
|
| try: |
| client = get_twilio_client() |
| logger.info(f"Initiating call to {to_number} via {service_url}") |
| call = client.calls.create( |
| to=to_number, |
| from_=os.getenv("APP_NUMBER"), |
| url=service_url |
| ) |
| call_sid = call.sid |
|
|
| call_context = CallContext() |
| call_contexts[call_sid] = call_context |
| |
| |
| call_context.system_message = system_message or os.getenv("SYSTEM_MESSAGE") |
| call_context.initial_message = initial_message or os.getenv("INITIAL_MESSAGE") |
| call_context.call_sid = call_sid |
|
|
| return {"call_sid": call_sid} |
| except Exception as e: |
| logger.error(f"Error initiating call: {str(e)}") |
| return {"error": f"Failed to initiate call: {str(e)}"} |
|
|
| @app.get("/call_status/{call_sid}") |
| async def get_call_status(call_sid: str): |
| """Get the status of a call.""" |
| try: |
| client = get_twilio_client() |
| call = client.calls(call_sid).fetch() |
| return {"status": call.status} |
| except Exception as e: |
| logger.error(f"Error fetching call status: {str(e)}") |
| return {"error": f"Failed to fetch call status: {str(e)}"} |
|
|
| @app.post("/end_call") |
| async def end_call(request: Dict[str, str]): |
| """End a specific call.""" |
| try: |
| call_sid = request.get("call_sid") |
| client = get_twilio_client() |
| client.calls(call_sid).update(status='completed') |
| return {"status": "success"} |
| except Exception as e: |
| logger.error(f"Error ending call {str(e)}") |
| return {"error": f"Failed to end requested call: {str(e)}"} |
|
|
| @app.get("/transcript/{call_sid}") |
| async def get_transcript(call_sid: str): |
| """Get the entire transcript for a specific call.""" |
| call_context = call_contexts.get(call_sid) |
| if not call_context: |
| logger.info(f"[GET] Call not found for call SID: {call_sid}") |
| return {"error": "Call not found"} |
| return {"transcript": call_context.user_context} |
|
|
| @app.get("/all_transcripts") |
| async def get_all_transcripts(): |
| """Get a list of all current call transcripts.""" |
| try: |
| transcript_list = [] |
| for call_sid, context in call_contexts.items(): |
| transcript_list.append({ |
| "call_sid": call_sid, |
| "transcript": context.user_context, |
| }) |
| return {"transcripts": transcript_list} |
| except Exception as e: |
| logger.error(f"Error fetching all transcripts: {str(e)}") |
| return {"error": f"Failed to fetch all transcripts: {str(e)}"} |
|
|
| if __name__ == "__main__": |
| import uvicorn |
| logger.info("Starting server...") |
| logger.info(f"Backend server address set to: {os.getenv('SERVER')}") |
| port = int(os.getenv("PORT", 8000)) |
| uvicorn.run(app, host="0.0.0.0", port=port) |