Spaces:
Sleeping
Sleeping
| """ | |
| Pydantic models for AI Voice Assistant conversation state and messaging. | |
| """ | |
| import uuid | |
| from datetime import datetime | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
| class ConversationState(str, Enum): | |
| """Dialogue state machine states.""" | |
| GREETING = "greeting" | |
| CHIEF_COMPLAINT = "chief_complaint" | |
| SYMPTOM_DETAILS = "symptom_details" | |
| FOLLOW_UP = "follow_up" | |
| SUMMARY = "summary" | |
| EMERGENCY_ESCALATION = "emergency_escalation" | |
| ENDED = "ended" | |
| class ConversationMode(str, Enum): | |
| """Who is using the voice assistant.""" | |
| PATIENT = "patient" | |
| CLINICIAN = "clinician" | |
| AMBIENT = "ambient" # Phase 3: passive ambient listening mode | |
| class ConversationTurn(BaseModel): | |
| """A single turn in the conversation.""" | |
| role: str # "assistant" or "user" | |
| content: str | |
| timestamp: datetime = Field(default_factory=datetime.utcnow) | |
| state: Optional[ConversationState] = None | |
| entities_extracted: Optional[Dict[str, Any]] = None | |
| class AssistantResponse(BaseModel): | |
| """Response from the dialogue manager for a single turn.""" | |
| text: str | |
| state: ConversationState | |
| previous_state: Optional[ConversationState] = None | |
| entities_update: Optional[Dict[str, Any]] = None | |
| is_final: bool = False | |
| is_emergency: bool = False | |
| rag_grounded: bool = False | |
| documentation: Optional[Dict[str, Any]] = None | |
| class ConversationSessionData(BaseModel): | |
| """In-memory conversation session state.""" | |
| session_id: str = Field(default_factory=lambda: str(uuid.uuid4())) | |
| mode: ConversationMode = ConversationMode.PATIENT | |
| state: ConversationState = ConversationState.GREETING | |
| turns: List[ConversationTurn] = Field(default_factory=list) | |
| extracted_entities: Dict[str, Any] = Field(default_factory=lambda: { | |
| "conditions": [], | |
| "medications": [], | |
| }) | |
| collected_symptoms: Dict[str, Any] = Field(default_factory=dict) | |
| rag_context: Optional[Dict[str, Any]] = None | |
| accumulated_transcript: str = "" | |
| followup_round: int = 0 | |
| language: str = "en" | |
| created_at: datetime = Field(default_factory=datetime.utcnow) | |
| def add_turn(self, role: str, content: str, **kwargs): | |
| """Add a conversation turn.""" | |
| turn = ConversationTurn( | |
| role=role, | |
| content=content, | |
| state=self.state, | |
| **kwargs, | |
| ) | |
| self.turns.append(turn) | |
| if role == "user": | |
| self.accumulated_transcript += f" {content}" if self.accumulated_transcript else content | |
| def get_conversation_history(self) -> List[Dict[str, str]]: | |
| """Get conversation history formatted for LLM chat template.""" | |
| return [ | |
| {"role": t.role, "content": t.content} | |
| for t in self.turns | |
| ] | |
| # WebSocket protocol messages | |
| class WSClientAction(BaseModel): | |
| """Message from client to server.""" | |
| action: str # "start", "stop", "text_input", "interrupt" | |
| mode: Optional[str] = None | |
| language: Optional[str] = None | |
| text: Optional[str] = None | |
| class WSServerMessage(BaseModel): | |
| """Message from server to client.""" | |
| type: str # "connected", "assistant_text", "assistant_audio", "user_transcript", | |
| # "entities_update", "state_change", "summary", "error" | |
| session_id: Optional[str] = None | |
| text: Optional[str] = None | |
| audio: Optional[str] = None # base64-encoded WAV | |
| format: Optional[str] = None | |
| sample_rate: Optional[int] = None | |
| state: Optional[str] = None | |
| from_state: Optional[str] = None | |
| to_state: Optional[str] = None | |
| entities: Optional[Dict[str, Any]] = None | |
| documentation: Optional[Dict[str, Any]] = None | |
| message: Optional[str] = None | |
| is_final: bool = False | |
| is_emergency: bool = False | |
| rag_grounded: bool = False | |