# VoxDoc: app/models/conversation_session.py
# Uploaded by joelthomas77 - commit 60d4850 ("Upload app code", verified).
"""
Pydantic models for AI Voice Assistant conversation state and messaging.
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field
class ConversationState(str, Enum):
    """States of the dialogue state machine.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value (e.g. ``ConversationState.SUMMARY == "summary"``), and
    each value is the lowercase form of the member name.
    """

    # Opening of the dialogue.
    GREETING = "greeting"

    # Information-gathering states.
    CHIEF_COMPLAINT = "chief_complaint"
    SYMPTOM_DETAILS = "symptom_details"
    FOLLOW_UP = "follow_up"

    # Wrap-up, escalation and terminal states.
    SUMMARY = "summary"
    EMERGENCY_ESCALATION = "emergency_escalation"
    ENDED = "ended"
class ConversationMode(str, Enum):
    """Identifies who is using the voice assistant.

    As a ``str`` subclass, each member compares equal to its plain value.
    """

    PATIENT = "patient"
    CLINICIAN = "clinician"
    # Phase 3: passive ambient listening mode.
    AMBIENT = "ambient"
class ConversationTurn(BaseModel):
    """A single turn in the conversation.

    Attributes:
        role: Speaker of the turn, "assistant" or "user".
        content: Text of the turn.
        timestamp: Creation time as a naive ``datetime`` expressed in UTC.
        state: Dialogue state recorded with the turn, if any.
        entities_extracted: Optional entity payload attached to the turn.
    """

    role: str  # "assistant" or "user"
    content: str
    # Naive UTC timestamp, identical in value and tz-naivety to the former
    # ``datetime.utcnow`` default, but without that API (deprecated in 3.12).
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    state: Optional[ConversationState] = None
    entities_extracted: Optional[Dict[str, Any]] = None
class AssistantResponse(BaseModel):
    """Result produced by the dialogue manager for a single turn.

    Only ``text`` and ``state`` are required; every other field is optional
    and defaults to ``None`` or ``False``.
    """

    # Text of the assistant's reply.
    text: str

    # Dialogue state for this response and (presumably) the state it
    # transitioned from - TODO confirm against the dialogue manager.
    state: ConversationState
    previous_state: Optional[ConversationState] = None

    # Optional entity payload for this turn.
    entities_update: Optional[Dict[str, Any]] = None

    # Boolean flags, all defaulting to False.
    is_final: bool = False
    is_emergency: bool = False
    rag_grounded: bool = False

    # Optional documentation payload.
    documentation: Optional[Dict[str, Any]] = None
class ConversationSessionData(BaseModel):
    """In-memory conversation session state.

    Holds the session's mode and current dialogue state, the recorded turns,
    collected entities/symptoms, and a running transcript of user turns.
    """

    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    mode: ConversationMode = ConversationMode.PATIENT
    state: ConversationState = ConversationState.GREETING
    turns: List[ConversationTurn] = Field(default_factory=list)
    # Factory so each session gets its own fresh lists instead of sharing one dict.
    extracted_entities: Dict[str, Any] = Field(
        default_factory=lambda: {"conditions": [], "medications": []}
    )
    collected_symptoms: Dict[str, Any] = Field(default_factory=dict)
    rag_context: Optional[Dict[str, Any]] = None
    # Space-separated concatenation of every non-blank user turn, in order.
    accumulated_transcript: str = ""
    followup_round: int = 0
    language: str = "en"
    # Naive UTC creation time (same value the former ``datetime.utcnow``
    # default produced, without the API deprecated in Python 3.12).
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )

    def add_turn(self, role: str, content: str, **kwargs: Any) -> ConversationTurn:
        """Append a conversation turn and return it.

        Args:
            role: Speaker of the turn ("assistant" or "user").
            content: Text of the turn.
            **kwargs: Additional ``ConversationTurn`` fields (e.g.
                ``entities_extracted`` or ``timestamp``). ``state`` defaults to
                the session's current state but may be passed explicitly.

        Returns:
            The newly created ``ConversationTurn``.
        """
        # setdefault rather than a hard-coded ``state=`` keyword: passing
        # ``state`` in kwargs used to raise "got multiple values for keyword
        # argument 'state'".
        kwargs.setdefault("state", self.state)
        turn = ConversationTurn(role=role, content=content, **kwargs)
        self.turns.append(turn)

        # Use the validated turn text, and skip blank input so it cannot add
        # stray separator spaces to the transcript.
        text = turn.content
        if role == "user" and text.strip():
            if self.accumulated_transcript:
                self.accumulated_transcript += f" {text}"
            else:
                self.accumulated_transcript = text
        return turn

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return the turns, in the order recorded, as role/content dicts.

        This is the message-list shape used for the LLM chat template; any
        other turn fields (timestamp, state, entities) are omitted.
        """
        return [{"role": t.role, "content": t.content} for t in self.turns]
# WebSocket protocol messages
class WSClientAction(BaseModel):
    """Message sent by the client to the server over the WebSocket.

    Only ``action`` is required; the remaining fields default to ``None``.
    """

    # Expected values: "start", "stop", "text_input", "interrupt".
    action: str
    # Presumably a ConversationMode value such as "patient" - TODO confirm.
    mode: Optional[str] = None
    language: Optional[str] = None
    text: Optional[str] = None
class WSServerMessage(BaseModel):
    """Message sent by the server to the client over the WebSocket.

    Only ``type`` is required. The three boolean flags default to ``False``;
    every other field is optional and defaults to ``None``.
    """

    # Message kind: "connected", "assistant_text", "assistant_audio",
    # "user_transcript", "entities_update", "state_change", "summary", "error".
    type: str
    session_id: Optional[str] = None
    text: Optional[str] = None

    # Audio payload (base64-encoded WAV); format and sample_rate presumably
    # describe it - TODO confirm.
    audio: Optional[str] = None
    format: Optional[str] = None
    sample_rate: Optional[int] = None

    # Dialogue-state fields.
    state: Optional[str] = None
    from_state: Optional[str] = None
    to_state: Optional[str] = None

    # Structured payloads.
    entities: Optional[Dict[str, Any]] = None
    documentation: Optional[Dict[str, Any]] = None

    # Free-form message text (presumably used for "error" messages).
    message: Optional[str] = None

    # Boolean flags.
    is_final: bool = False
    is_emergency: bool = False
    rag_grounded: bool = False