# NOTE: removed Hugging Face Spaces page banner ("Spaces: / Sleeping / Sleeping")
# that was captured with the source; it is not Python and breaks parsing.
| #!/usr/bin/env python3 | |
| """ | |
| Medicare AI Voice Agent - Production Backend | |
| Optimized for RunPod GPU deployment with VICIdial integration | |
| Version: 1.0.0 | |
| """ | |
| import os | |
| import asyncio | |
| import json | |
| import base64 | |
| import re | |
| from datetime import datetime | |
| from typing import Optional, List, Dict, Sequence | |
| from collections import deque | |
| from typing_extensions import Annotated, TypedDict | |
| # FastAPI | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| # ML/AI | |
| import torch | |
| import numpy as np | |
| from transformers import pipeline | |
| from transformers.utils import is_flash_attn_2_available | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | |
| from langchain_core.tools import tool | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| # Database | |
| from sqlalchemy import create_engine, Column, String, Integer, DateTime, Text, JSON, Boolean | |
| from sqlalchemy.orm import sessionmaker, declarative_base, Session | |
| # Audio processing | |
| import soundfile as sf | |
| from pydub import AudioSegment | |
| import scipy.signal | |
| import io | |
| import audioop | |
| # ============================================================================ | |
| # CONFIGURATION | |
| # ============================================================================ | |
# Gemini API key is mandatory: fail fast at import time rather than on the
# first model call.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable not set!")
# Re-export so libraries that read the environment directly can find it.
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
# Model / voice selection.
WHISPER_MODEL = "openai/whisper-large-v3-turbo"
KOKORO_VOICE = "af_heart"  # Kokoro voice id
KOKORO_LANG = "a"  # Kokoro language code -- presumably American English; TODO confirm against Kokoro docs
VAD_AVAILABLE = False  # voice-activity detection is not wired up anywhere in this file
# Database
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./medicare_agent.db")
engine = create_engine(DATABASE_URL, pool_pre_ping=True)  # pre-ping drops stale pooled connections
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
# FastAPI app
app = FastAPI(title="Medicare AI Voice Agent", version="1.0.0")
# CORS
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests; lock origins down for prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
print("β Configuration loaded")
| # ============================================================================ | |
| # DATABASE MODEL | |
| # ============================================================================ | |
class Conversation(Base):
    """SQLAlchemy model persisting one complete screening call."""
    __tablename__ = "conversations"
    session_id = Column(String, primary_key=True)
    phone_number = Column(String, index=True)
    # NOTE(review): datetime.utcnow is naive and deprecated in Python 3.12+;
    # consider timezone-aware timestamps.
    started_at = Column(DateTime, default=datetime.utcnow)
    ended_at = Column(DateTime, nullable=True)
    # Set from the end-call reason (e.g. "completed", "not_interested").
    status = Column(String)
    # Patient Data
    customer_name = Column(String, nullable=True)
    ethnicity = Column(String, nullable=True)
    height = Column(String, nullable=True)
    weight = Column(String, nullable=True)
    immune_conditions = Column(JSON, nullable=True)
    neuro_conditions = Column(JSON, nullable=True)
    cancer_history = Column(JSON, nullable=True)
    last_visit_date = Column(String, nullable=True)
    can_make_decisions = Column(Boolean, nullable=True)
    interested = Column(Boolean, nullable=True)
    # Conversation
    conversation_json = Column(JSON, nullable=True)  # full turn-by-turn transcript
    total_turns = Column(Integer, default=0)
    greeting = Column(Text, nullable=True)  # first agent utterance
    first_user_response = Column(Text, nullable=True)  # first user utterance
# Create tables on import if they do not already exist.
Base.metadata.create_all(engine)
print("β Database initialized")
| # ============================================================================ | |
| # AGENT LOGIC - LANGGRAPH WITH STATE MANAGEMENT | |
| # ============================================================================ | |
class PatientInfoExtraction(BaseModel):
    """Patient information schema.

    Every field defaults to None; a None value means "not yet collected" and
    is what get_pending_questions() uses to drive the interview order.
    """
    customer_name: Optional[str] = Field(None, description="Patient's full name")
    ethnicity: Optional[str] = Field(None, description="Patient's ethnic background")
    height: Optional[str] = Field(None, description="Patient's height")
    weight: Optional[str] = Field(None, description="Patient's weight")
    immune_conditions: Optional[List[str]] = Field(None, description="List of immune-related conditions")
    neuro_conditions: Optional[List[str]] = Field(None, description="List of neurological conditions")
    cancer_history: Optional[List[str]] = Field(None, description="List of any cancer diagnoses")
    last_visit_date: Optional[str] = Field(None, description="When patient last visited their doctor")
    can_make_decisions: Optional[bool] = Field(None, description="Whether patient is capable of making own medical decisions")
    interested: Optional[bool] = Field(None, description="Whether patient is interested in moving forward")
class InterviewState(TypedDict):
    """Graph state definition.

    messages: chat history; the Annotated reducer (lambda x, y: x + y)
        concatenates each node's returned messages onto the existing list.
    patient_info: latest PatientInfoExtraction snapshot, replaced wholesale
        whenever a node returns a new one.
    """
    messages: Annotated[Sequence[object], lambda x, y: x + y]
    patient_info: PatientInfoExtraction
def update_patient_info(
    customer_name: Optional[str] = None,
    ethnicity: Optional[str] = None,
    height: Optional[str] = None,
    weight: Optional[str] = None,
    immune_conditions: Optional[List[str]] = None,
    neuro_conditions: Optional[List[str]] = None,
    cancer_history: Optional[List[str]] = None,
    last_visit_date: Optional[str] = None,
    can_make_decisions: Optional[bool] = None,
    interested: Optional[bool] = None,
):
    """Update patient information with new data.

    Returns a JSON object (as a string) containing only the arguments the
    caller actually supplied, i.e. every parameter that is not None.
    """
    # Snapshot the parameter bindings before introducing any other locals.
    provided = dict(locals())
    return json.dumps({name: value for name, value in provided.items() if value is not None})
def end_call(reason: str):
    """End the conversation.

    Returns a human-readable confirmation echoing the supplied reason.
    """
    return "Call ended with reason: {}".format(reason)
def forward_call_to_human(reason: str):
    """Forward call to a human agent.

    Returns a confirmation string echoing the supplied reason.
    """
    return "Call forwarded to human with reason: " + reason
def get_pending_questions(patient_info: PatientInfoExtraction) -> List[str]:
    """Return the names of the fields that are still unanswered (value is None)."""
    missing = []
    for field_name, field_value in patient_info.model_dump().items():
        if field_value is None:
            missing.append(field_name)
    return missing
def agent_node(state: InterviewState):
    """Main agent reasoning node.

    Builds a situation-specific system prompt from the interview progress,
    then invokes Gemini (with tools bound) on the prompt plus the most recent
    conversation history. Returns {"messages": [reply]} for the reducer.
    """
    messages = state["messages"]
    # First turn: no patient_info in state yet, start from an empty snapshot.
    patient_info = state.get("patient_info") or PatientInfoExtraction()
    pending_questions = get_pending_questions(patient_info)
    # Keep the prompt small: only the last six messages are replayed.
    relevant_history = messages[-6:]
    if not messages:
        # Cold-call opening: scripted greeting, then ask for the name.
        system_prompt = """You are Jane, a friendly agent from Nationwide Screening.
Your first task is to greet the patient, introduce yourself and the purpose of the call,
and then ask for their name. This is a cold call.
Script: "Hi, this is Jane calling from Nationwide Screening. The reason I'm reaching out is because you've been approved through your Medicare benefits to receive a no-cost genetic saliva test that checks for hidden risks related to autoimmune conditions, neurological disorders, and hereditary cancers. I'm calling today to see if you'd like to take advantage of this benefit. Before we go over the details, may I please have your name?"
"""
    else:
        # Branch on (form complete?, interested?) to pick the right playbook.
        form_complete = len(pending_questions) == 0
        customer_interested = patient_info.interested == True
        customer_not_interested = patient_info.interested == False
        if form_complete and customer_interested:
            # Done and interested: hand off to a human.
            system_prompt = """
## SITUATION:
You have collected ALL required information and the customer IS INTERESTED.
## IMMEDIATE ACTION REQUIRED:
You MUST call the `forward_call_to_human` tool RIGHT NOW with reason: "interested_customer_ready".
## YOUR RESPONSE:
Say: "Thank you so much for your time! I have all the information I need. Let me connect you with a specialist who can help you schedule your test. Please hold for just a moment."
Then IMMEDIATELY call `forward_call_to_human`.
"""
        elif form_complete and customer_not_interested:
            # Done but declined: wrap up politely.
            system_prompt = """
## SITUATION:
You have collected all information but the customer is NOT interested.
## IMMEDIATE ACTION:
Call `end_call` with reason: "not_interested".
## YOUR RESPONSE:
Say: "I understand. Thank you for your time today. Have a great day!"
Then call `end_call`.
"""
        elif customer_interested and not form_complete:
            # Interested but the form has gaps: keep collecting.
            system_prompt = f"""
## SITUATION:
The customer IS INTERESTED but you still need some information.
## CURRENT PROGRESS:
{patient_info.model_dump_json(indent=2)}
## MISSING INFORMATION:
{', '.join(pending_questions)}
## YOUR TASK:
1. Acknowledge their interest warmly
2. Explain you just need a couple more details
3. Ask ONLY for the next missing item: {pending_questions[0]}
Keep it brief, natural, and conversational.
"""
        elif not form_complete and patient_info.interested is None:
            # Mid-interview, interest not yet established.
            system_prompt = f"""
You are Jane, a friendly medicare screening agent collecting patient information.
## YOUR PROGRESS:
{patient_info.model_dump_json(indent=2)}
## PENDING QUESTIONS (ask in order):
{', '.join(pending_questions)}
## CRITICAL RULES:
1. **Extract Information:** If patient provides ANY info, call `update_patient_info` tool IMMEDIATELY
2. **One Question at a Time:** Ask about ONLY ONE field: {pending_questions[0] if pending_questions else 'none'}
3. **Never Repeat:** NEVER ask for information you already have
4. **Be Natural:** Respond conversationally to what they said, then ask next question
5. **Handle Negativity:** If rude/frustrated, call `end_call` with reason "customer_upset"
## WHAT TO DO RIGHT NOW:
- If they answered your last question β call `update_patient_info`
- Then ask about: {pending_questions[0] if pending_questions else 'all info collected'}
Remember: Natural speech only. No special characters.
"""
        elif form_complete and patient_info.interested is None:
            # Everything collected except the closing interest question.
            system_prompt = """
## SITUATION:
You have collected ALL patient information EXCEPT their interest level.
## YOUR FINAL QUESTION:
Ask clearly: "Great! I have all your information. Are you interested in moving forward with this free genetic screening test?"
## WHAT HAPPENS NEXT:
- If they say YES β call `update_patient_info` with interested=True, then I will forward them
- If they say NO β call `update_patient_info` with interested=False, then end call
Ask the question naturally and wait for their response.
"""
        else:
            # Fallback: generic interviewing instructions.
            system_prompt = f"""
You are Jane, a friendly medicare screening agent.
## YOUR PROGRESS:
{patient_info.model_dump_json(indent=2)}
## PENDING QUESTIONS:
{', '.join(pending_questions)}
## INSTRUCTIONS:
1. Respond naturally to what the patient just said
2. If they provided info β call `update_patient_info`
3. Ask about the next pending item: {pending_questions[0] if pending_questions else 'none'}
4. If patient is rude or frustrated β call `end_call`
5. If explicitly asks for human β call `forward_call_to_human`
Keep it conversational and natural.
"""
    # NOTE(review): these are plain functions (no @tool decorator); bind_tools
    # can convert callables to schemas, but tool_node must not call .invoke()
    # on them -- see tool_node.
    tools = [update_patient_info, end_call, forward_call_to_human]
    # A new model client is constructed on every node invocation.
    model = ChatGoogleGenerativeAI(
        temperature=0.7,
        model="gemini-2.5-flash-lite",
        api_key=os.getenv("GOOGLE_API_KEY")
    ).bind_tools(tools)
    # NOTE(review): system_prompt is passed as a bare string, which LangChain
    # coerces to a HumanMessage rather than a SystemMessage -- confirm intended.
    response = model.invoke([system_prompt] + relevant_history)
    return {"messages": [response]}
def tool_node(state: InterviewState):
    """Execute the tool calls requested by the last AI message.

    Returns ToolMessages (for the messages reducer) plus the possibly-updated
    patient_info snapshot.

    Fix: update_patient_info / end_call / forward_call_to_human are plain
    Python functions, not LangChain tools, so they have no .invoke() method;
    the previous code would raise AttributeError on the first tool call.
    They are now called directly with the parsed argument dicts.
    """
    last_message = state["messages"][-1]
    new_info_obj = state.get('patient_info') or PatientInfoExtraction()
    tool_messages = []
    for tool_call in last_message.tool_calls:
        tool_name = tool_call["name"]
        if tool_name == "update_patient_info":
            tool_output_json = update_patient_info(**tool_call["args"])
            tool_messages.append(ToolMessage(content=tool_output_json, tool_call_id=tool_call["id"]))
            new_data_dict = json.loads(tool_output_json)
            # Merge newly extracted fields into the running snapshot.
            new_info_obj = new_info_obj.model_copy(update=new_data_dict)
            print(f"[Tool] Updated: {new_data_dict}")
        elif tool_name == "end_call":
            tool_output = end_call(**tool_call["args"])
            tool_messages.append(ToolMessage(content=tool_output, tool_call_id=tool_call["id"]))
            print(f"[Tool] Ending call: {tool_call['args']}")
        elif tool_name == "forward_call_to_human":
            tool_output = forward_call_to_human(**tool_call["args"])
            tool_messages.append(ToolMessage(content=tool_output, tool_call_id=tool_call["id"]))
            print(f"[Tool] Forwarding: {tool_call['args']}")
    return {"messages": tool_messages, "patient_info": new_info_obj}
def should_call_tool(state: InterviewState):
    """Route to the tool node when the newest message requests tool calls."""
    return "tool_node" if state["messages"][-1].tool_calls else END
def after_tool(state: InterviewState):
    """Decide the next step after tool execution.

    Fix: tool_node appends ToolMessages, so after it runs the LAST message is
    a ToolMessage, never an AIMessage -- the old `state["messages"][-1]`
    isinstance check could not succeed, so end_call / forward_call_to_human
    never terminated the graph. We now scan backwards for the AIMessage that
    actually requested the tools: if it asked for update_patient_info, loop
    back to the agent; any other tool (end/forward) ends the graph.
    """
    for message in reversed(state["messages"]):
        if isinstance(message, AIMessage) and message.tool_calls:
            if message.tool_calls[0]["name"] == "update_patient_info":
                return "agent_node"
            return END
    # No tool-calling AIMessage found: hand control back to the agent.
    return "agent_node"
def create_agent_graph():
    """Assemble and compile the LangGraph state machine.

    agent_node -> (tool_node | END) -> (agent_node | END), with an in-memory
    checkpointer so each thread_id keeps its conversation state.
    """
    graph = StateGraph(InterviewState)
    graph.add_node("agent_node", agent_node)
    graph.add_node("tool_node", tool_node)
    graph.set_entry_point("agent_node")
    graph.add_conditional_edges("agent_node", should_call_tool, {"tool_node": "tool_node", END: END})
    graph.add_conditional_edges("tool_node", after_tool, {"agent_node": "agent_node", END: END})
    return graph.compile(checkpointer=MemorySaver())
class MedicareAgent:
    """Agent with buffered database writes and state management.

    Turns accumulate in an in-memory per-session buffer and are persisted to
    the database exactly once, when the call ends.
    """
    def __init__(self):
        self.app = create_agent_graph()
        # session_id -> {"turns": [...], "started_at": datetime,
        #                "caller_id": ..., "patient_info": ...}
        self._call_buffers = {}

    def process_message(self, session_id: str, user_message: str) -> str:
        """Run one conversational turn through the graph and buffer it.

        An empty user_message triggers the scripted greeting (first turn).
        Returns the agent's spoken reply text.
        """
        config = {"configurable": {"thread_id": session_id}}
        buffer = self._call_buffers.setdefault(session_id, {
            "turns": [],
            "started_at": datetime.utcnow(),
            # NOTE(review): caller_id is never populated anywhere in this
            # file, so phone_number is always saved as None.
            "caller_id": None
        })
        if user_message:
            buffer["turns"].append({
                "timestamp": datetime.utcnow(),
                "role": "user",
                "content": user_message
            })
        input_data = {"messages": [HumanMessage(content=user_message)] if user_message else []}
        result = self.app.invoke(input_data, config)
        # Pick the newest AI message that actually carries text.
        agent_response = "I'm processing your information..."
        for msg in reversed(result.get("messages", [])):
            # Fix: use isinstance instead of comparing __class__.__name__
            # strings, and guard against non-string content (the model can
            # return a parts list, which would crash .strip()).
            if isinstance(msg, AIMessage) and isinstance(msg.content, str) and msg.content.strip():
                agent_response = msg.content
                break
        buffer["turns"].append({
            "timestamp": datetime.utcnow(),
            "role": "agent",
            "content": agent_response
        })
        patient_info = result.get("patient_info")
        if patient_info:
            buffer["patient_info"] = patient_info
        return agent_response

    def end_call(self, session_id: str, reason: str = "completed"):
        """Persist the buffered conversation and drop the buffer.

        No-op for unknown sessions. Database errors are logged and rolled
        back; the buffer is discarded either way (fix: the buffer is popped
        up-front so it can never leak if persistence raises).
        """
        buffer = self._call_buffers.pop(session_id, None)
        if buffer is None:
            return
        db_session = SessionLocal()
        try:
            turns = buffer["turns"]
            conversation_structured = [
                {
                    "turn_number": i + 1,
                    "role": turn["role"],
                    "content": turn["content"],
                    "timestamp": turn["timestamp"].isoformat()
                }
                for i, turn in enumerate(turns)
            ]
            greeting, first_user_response = self._first_turns(turns)
            patient_info = buffer.get("patient_info")
            patient_data = {}
            if patient_info:
                updates = patient_info.model_dump() if hasattr(patient_info, 'model_dump') else patient_info
                patient_data = {k: v for k, v in updates.items() if v is not None}
            existing = db_session.query(Conversation).filter_by(session_id=session_id).first()
            if existing:
                # Update-in-place for sessions that were already saved once.
                existing.ended_at = datetime.utcnow()
                existing.status = reason
                existing.conversation_json = conversation_structured
                existing.total_turns = len(turns)
                existing.greeting = greeting
                existing.first_user_response = first_user_response
                for field, value in patient_data.items():
                    if hasattr(existing, field):
                        setattr(existing, field, value)
            else:
                db_session.add(Conversation(
                    session_id=session_id,
                    phone_number=buffer.get("caller_id"),
                    started_at=buffer["started_at"],
                    ended_at=datetime.utcnow(),
                    status=reason,
                    conversation_json=conversation_structured,
                    total_turns=len(turns),
                    greeting=greeting,
                    first_user_response=first_user_response,
                    **patient_data
                ))
            db_session.commit()
            print(f"β [{session_id}] Saved: {len(turns)} turns, status: {reason}")
        except Exception as e:
            print(f"β [{session_id}] Database error: {e}")
            db_session.rollback()
        finally:
            db_session.close()

    @staticmethod
    def _first_turns(turns):
        """Return (first agent utterance, first user utterance) from a turn list."""
        greeting = None
        first_user_response = None
        for turn in turns:
            if turn["role"] == "agent" and greeting is None:
                greeting = turn["content"]
            elif turn["role"] == "user" and first_user_response is None:
                first_user_response = turn["content"]
        return greeting, first_user_response

    def get_patient_info(self, session_id: str) -> dict:
        """Get patient info, preferring the live buffer over the database."""
        if session_id in self._call_buffers:
            patient_info = self._call_buffers[session_id].get("patient_info")
            if patient_info:
                return patient_info.model_dump() if hasattr(patient_info, 'model_dump') else patient_info
        db_session = SessionLocal()
        try:
            conversation = db_session.query(Conversation).filter_by(session_id=session_id).first()
            if conversation:
                return {
                    "customer_name": conversation.customer_name,
                    "ethnicity": conversation.ethnicity,
                    "height": conversation.height,
                    "weight": conversation.weight,
                    "immune_conditions": conversation.immune_conditions,
                    "neuro_conditions": conversation.neuro_conditions,
                    "cancer_history": conversation.cancer_history,
                    "last_visit_date": conversation.last_visit_date,
                    "can_make_decisions": conversation.can_make_decisions,
                    "interested": conversation.interested
                }
            return {}
        except Exception as e:
            print(f"Database error: {e}")
            return {}
        finally:
            db_session.close()
print("β Agent logic loaded")
| # ============================================================================ | |
| # AUDIO PROCESSING | |
| # ============================================================================ | |
def transcribe_audio(audio_bytes: bytes, source_format: str = "webm") -> str:
    """Transcribe audio to text with Whisper.

    source_format:
        "mulaw"  -- 8 kHz mu-law telephony bytes (decoded to 16-bit PCM first)
        "pcm16k" -- raw 16-bit mono PCM at 16 kHz
        other    -- container formats (webm/ogg/wav/...) decoded via pydub/ffmpeg

    Returns the stripped transcript text.
    Raises HTTPException(500) on any decoding or inference failure.
    """
    try:
        if source_format == "mulaw":
            # NOTE(review): audioop is deprecated since Python 3.11 and
            # removed in 3.13 -- pin the runtime or replace before upgrading.
            pcm_data = audioop.ulaw2lin(audio_bytes, 2)
            audio = AudioSegment(data=pcm_data, sample_width=2, frame_rate=8000, channels=1)
        elif source_format == "pcm16k":
            audio = AudioSegment(data=audio_bytes, sample_width=2, frame_rate=16000, channels=1)
        else:
            audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
        # Whisper expects 16 kHz mono.
        audio = audio.set_frame_rate(16000).set_channels(1)
        wav_buffer = io.BytesIO()
        audio.export(wav_buffer, format="wav")
        wav_buffer.seek(0)
        audio_array, sample_rate = sf.read(wav_buffer)
        audio_array = audio_array.astype(np.float32)
        # Peak-normalize to [-1, 1]; skip all-zero (silent) buffers.
        max_val = np.abs(audio_array).max()
        if max_val > 0:
            audio_array = audio_array / max_val
        # Defensive resample; normally a no-op because of set_frame_rate above.
        if sample_rate != 16000:
            num_samples = int(len(audio_array) * 16000 / sample_rate)
            audio_array = scipy.signal.resample(audio_array, num_samples)
        result = whisper_pipeline(
            audio_array,
            return_timestamps=True,
            generate_kwargs={"language": "english", "task": "transcribe"}
        )
        return result["text"].strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        # NOTE(review): raising HTTPException from a helper also used by the
        # websocket path leaks a FastAPI concern into non-HTTP code.
        raise HTTPException(status_code=500, detail=f"STT Error: {str(e)}")
def synthesize_speech(text: str, output_format: str = "wav") -> bytes:
    """Convert text to speech with Kokoro.

    output_format:
        "mulaw"  -- 8 kHz mu-law bytes for telephony
        "pcm16k" -- raw 16-bit mono PCM at 16 kHz
        other    -- WAV container at Kokoro's native 24 kHz

    On any TTS failure, returns silent WAV audio sized roughly to the text so
    the call audio path never hard-fails (deliberate best-effort fallback).
    """
    try:
        # NOTE(review): kokoro.KPipeline is normally invoked as a callable
        # yielding (graphemes, phonemes, audio) chunks; confirm that a
        # generate_speech(...) method exists on the installed version --
        # otherwise every call lands in the silent-audio fallback below.
        audio_array = tts_model.generate_speech(text, voice=KOKORO_VOICE, speed=1.0)
        sample_rate = 24000
        if output_format == "mulaw":
            # float [-1, 1] -> int16 PCM -> downsample to 8 kHz -> mu-law
            pcm_data = (audio_array * 32767).astype(np.int16).tobytes()
            audio_seg = AudioSegment(data=pcm_data, sample_width=2, frame_rate=sample_rate, channels=1)
            audio_seg = audio_seg.set_frame_rate(8000)
            mulaw_data = audioop.lin2ulaw(audio_seg.raw_data, 2)
            return mulaw_data
        elif output_format == "pcm16k":
            # 24 kHz -> 16 kHz (length scaled by 16/24 = 2/3)
            resampled = scipy.signal.resample(audio_array, len(audio_array) * 16 // 24)
            pcm_data = (resampled * 32767).astype(np.int16).tobytes()
            return pcm_data
        else:
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            buffer.seek(0)
            return buffer.read()
    except Exception as e:
        print(f"TTS Error: {e}")
        # Fallback: ~0.5 s of silence per word, minimum 2 s.
        sample_rate = 24000
        duration = max(2.0, len(text.split()) * 0.5)
        samples = int(duration * sample_rate)
        audio = np.zeros(samples, dtype=np.float32)
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='WAV')
        buffer.seek(0)
        return buffer.read()
print("β Audio processing loaded")
| # ============================================================================ | |
| # LOAD MODELS | |
| # ============================================================================ | |
print("π Loading models...")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Prefer Flash Attention 2 when the wheel is installed; otherwise fall back
# to PyTorch scaled-dot-product attention.
use_flash_attn = is_flash_attn_2_available()
model_kwargs = {"attn_implementation": "flash_attention_2"} if use_flash_attn else {"attn_implementation": "sdpa"}
print(f" Device: {device}")
print(f" Attention: {'Flash Attention 2' if use_flash_attn else 'SDPA'}")
# Speech-to-text: Whisper via the HF ASR pipeline (module-level global,
# loaded once at import).
whisper_pipeline = pipeline(
    "automatic-speech-recognition",
    model=WHISPER_MODEL,
    torch_dtype=torch_dtype,
    device=device,
    model_kwargs=model_kwargs,
)
print(f"β Whisper loaded")
# NOTE(review): mid-file import; consider moving to the top import block.
from kokoro import KPipeline
tts_model = KPipeline(lang_code=KOKORO_LANG)
print(f"β Kokoro TTS loaded")
agent = MedicareAgent()
print("β Medicare Agent loaded")
| # ============================================================================ | |
| # API ENDPOINTS | |
| # ============================================================================ | |
@app.get("/")
async def root():
    """Health check endpoint reporting service status.

    Fix: the route decorator was missing, so this coroutine was never
    registered with FastAPI even though the startup banner advertises a
    health check at "/".
    """
    return {
        "status": "running",
        "version": "1.0.0",
        "services": {
            "whisper": "loaded",
            "device": device,
            "kokoro": "loaded",
            "agent": "loaded",
            "database": "connected"
        }
    }
async def text_message(request: dict):
    """Text-based message API.

    Expects a JSON body with "session_id" and "message"; returns the agent
    reply plus the current patient-info snapshot.

    NOTE(review): no @app.post(...) decorator -- this handler is never
    registered as a route, and the intended path is not recoverable from this
    file. Also request.get("session_id") may be None and is passed on
    unvalidated.
    """
    session_id = request.get("session_id")
    message = request.get("message", "")
    response = agent.process_message(session_id, message)
    patient_info = agent.get_patient_info(session_id)
    return {
        "session_id": session_id,
        "agent_response": response,
        "patient_info": patient_info
    }
async def voice_message(
    audio: UploadFile = File(...),
    session_id: str = Form(...),
    source: str = Form("webm")
):
    """Voice-based message API: STT -> agent -> TTS.

    Returns the synthesized WAV reply; the transcript and reply text travel
    in the X-Transcript / X-Response headers.

    NOTE(review): no @app.post(...) decorator -- handler is never registered.
    NOTE(review): header values must be latin-1 encodable; transcripts with
    other characters will raise when the response is encoded -- verify.
    """
    audio_bytes = await audio.read()
    transcript = transcribe_audio(audio_bytes, source_format=source)
    response_text = agent.process_message(session_id, transcript)
    response_audio = synthesize_speech(response_text, output_format="wav")
    return StreamingResponse(
        io.BytesIO(response_audio),
        media_type="audio/wav",
        headers={
            "X-Transcript": transcript,
            "X-Response": response_text
        }
    )
async def end_call_endpoint(session_id: str):
    """End call and persist the conversation to the database.

    NOTE(review): no route decorator -- never registered with FastAPI.
    """
    agent.end_call(session_id, "completed")
    return {"status": "success", "session_id": session_id}
async def get_patient_info_endpoint(session_id: str):
    """Get the patient-information snapshot for a session.

    NOTE(review): no route decorator -- never registered with FastAPI.
    """
    patient_info = agent.get_patient_info(session_id)
    return {"session_id": session_id, "patient_info": patient_info}
| # @app.websocket("/ws/vicidial/{session_id}") | |
| # async def websocket_vicidial(websocket: WebSocket, session_id: str): | |
| # """WebSocket for VICIdial AudioSocket relay""" | |
| # await websocket.accept() | |
| # audio_buffer = bytearray() | |
| # try: | |
| # greeting = agent.process_message(session_id, "") | |
| # greeting_audio = synthesize_speech(greeting, output_format="pcm16k") | |
| # await websocket.send_json({ | |
| # "type": "audio_response", | |
| # "audio": base64.b64encode(greeting_audio).decode('utf-8'), | |
| # "text": greeting | |
| # }) | |
| # while True: | |
| # data = await websocket.receive_text() | |
| # msg = json.loads(data) | |
| # if msg['type'] == 'audio_data': | |
| # pcm_chunk = base64.b64decode(msg['audio']) | |
| # audio_buffer.extend(pcm_chunk) | |
| # if len(audio_buffer) >= 3200: | |
| # audio_to_process = bytes(audio_buffer[:3200]) | |
| # audio_buffer = audio_buffer[3200:] | |
| # transcript = transcribe_audio(audio_to_process, source_format="pcm16k") | |
| # if transcript.strip(): | |
| # await websocket.send_json({ | |
| # "type": "transcript", | |
| # "text": transcript | |
| # }) | |
| # response = agent.process_message(session_id, transcript) | |
| # response_audio = synthesize_speech(response, output_format="pcm16k") | |
| # await websocket.send_json({ | |
| # "type": "audio_response", | |
| # "audio": base64.b64encode(response_audio).decode('utf-8'), | |
| # "text": response | |
| # }) | |
| # elif msg['type'] == 'hangup': | |
| # agent.end_call(session_id, "completed") | |
| # break | |
| # except WebSocketDisconnect: | |
| # agent.end_call(session_id, "completed") | |
@app.websocket("/ws/vicidial/{session_id}")
async def websocket_vicidial(websocket: WebSocket, session_id: str):
    """WebSocket for VICIdial with sentence-level streaming TTS.

    Protocol (JSON text frames):
      in:  {"type": "audio_data", "audio": <base64 pcm16k>} or {"type": "hangup"}
      out: {"type": "transcript", "text": ...} and
           {"type": "audio_chunk", "audio": ..., "text": ..., "is_final": ...}

    Fixes:
      * the @app.websocket route decorator was missing, so this endpoint was
        never registered (path taken from the commented-out predecessor above
        and the startup banner);
      * the buffering comment was wrong: 3200 bytes of 16-bit PCM at 16 kHz
        is 1600 samples = 0.1 s, not "0.4 seconds @ 8kHz";
      * greeting and reply streaming shared duplicated code, now factored
        into one local helper.
    """
    await websocket.accept()
    audio_buffer = bytearray()

    async def stream_sentences(text: str) -> None:
        # Synthesize each sentence separately and push it as an audio_chunk
        # frame so playback can start before the whole reply is rendered.
        sentences = [s.strip() + '.' for s in text.split('.') if s.strip()]
        for i, sentence in enumerate(sentences):
            sentence_audio = synthesize_speech(sentence, output_format="pcm16k")
            await websocket.send_json({
                "type": "audio_chunk",
                "audio": base64.b64encode(sentence_audio).decode('utf-8'),
                "text": sentence,
                "is_final": (i == len(sentences) - 1)
            })

    try:
        # Scripted greeting (empty user message means "first turn").
        await stream_sentences(agent.process_message(session_id, ""))
        # Main conversation loop
        while True:
            data = await websocket.receive_text()
            msg = json.loads(data)
            if msg['type'] == 'audio_data':
                audio_buffer.extend(base64.b64decode(msg['audio']))
                # Process once ~0.1 s of 16 kHz 16-bit PCM has accumulated.
                if len(audio_buffer) >= 3200:
                    audio_to_process = bytes(audio_buffer[:3200])
                    audio_buffer = audio_buffer[3200:]
                    transcript = transcribe_audio(audio_to_process, source_format="pcm16k")
                    if transcript.strip():
                        # Surface the transcript to the client immediately.
                        await websocket.send_json({
                            "type": "transcript",
                            "text": transcript
                        })
                        response = agent.process_message(session_id, transcript)
                        await stream_sentences(response)
            elif msg['type'] == 'hangup':
                agent.end_call(session_id, "completed")
                break
    except WebSocketDisconnect:
        agent.end_call(session_id, "completed")
print("β API endpoints loaded")
| # ============================================================================ | |
| # SERVER STARTUP | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Entry point: run the API under uvicorn on PORT (default 8000).
    port = int(os.getenv("PORT", 8000))
    print(f"\nπ Starting Medicare AI Voice Agent on port {port}")
    print(f"π‘ Health check: http://localhost:{port}/")
    print(f"π VICIdial WebSocket: ws://localhost:{port}/ws/vicidial/{{session_id}}\n")
    # workers=1 -- presumably because the GPU models are loaded once at import
    # and extra workers would multiply that memory; TODO confirm.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        workers=1,
        log_level="info"
    )