# medgemma-triage-demo / src/workflows/agent_coordinator.py
# SebAustin
# V1
# 3265b47
"""
Agent Coordinator - Manages agent communication and state.
"""
from typing import Dict, Any, List, Optional
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime
from src.utils.logger import logger
class TriageState(Enum):
    """Stage a triage session is currently in.

    Each member's value is a lowercase string identifier for the stage.
    """

    # Agent pipeline stages, in declaration order.
    INTAKE = "intake"
    SYMPTOM_ASSESSMENT = "symptom_assessment"
    URGENCY_CLASSIFICATION = "urgency_classification"
    CARE_RECOMMENDATION = "care_recommendation"
    COMMUNICATION = "communication"

    # Outcome states: the run finished successfully, or it failed.
    COMPLETED = "completed"
    ERROR = "error"
@dataclass
class TriageSession:
    """Mutable record of one triage session: its workflow state plus the
    outputs produced by each agent stage.

    Agent-output fields start empty and are filled in as the session
    progresses through the pipeline.
    """

    session_id: str
    state: TriageState = TriageState.INTAKE
    # Naive local timestamp (datetime.now has no tzinfo).
    started_at: datetime = field(default_factory=datetime.now)

    # Patient data
    patient_input_history: List[str] = field(default_factory=list)
    conversation_history: List[Dict[str, str]] = field(default_factory=list)

    # Agent outputs -- intake
    case_summary: str = ""
    # Agent outputs -- symptom assessment
    symptom_analysis: str = ""
    primary_symptoms: List[str] = field(default_factory=list)
    differential_diagnosis: List[str] = field(default_factory=list)
    red_flags: List[str] = field(default_factory=list)
    # Agent outputs -- urgency classification
    urgency_level: str = ""
    urgency_reasoning: str = ""
    urgency_confidence: str = ""
    time_sensitive: bool = False
    # Agent outputs -- care recommendation
    care_setting: str = ""
    timeline: str = ""
    next_steps: List[str] = field(default_factory=list)
    self_care: List[str] = field(default_factory=list)
    preparation: List[str] = field(default_factory=list)
    # Agent outputs -- communication
    final_report: str = ""
    report_summary: str = ""
    formatted_report: str = ""

    # Metadata
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary of the session.

        This is a summary view, not a full dump of every field. It holds the
        identifiers, the timestamps, the headline agent outputs and the error
        message.

        Returns:
            A new dict. Datetimes are ISO-8601 strings (``completed_at`` is
            None until the session completes). ``state`` is the enum's string
            value. ``error_message`` is None unless the session failed. List
            values are copies, so mutating the result never alters the session.
        """
        return {
            "session_id": self.session_id,
            "state": self.state.value,
            "started_at": self.started_at.isoformat(),
            "completed_at": (
                self.completed_at.isoformat() if self.completed_at is not None else None
            ),
            "case_summary": self.case_summary,
            "urgency_level": self.urgency_level,
            "care_setting": self.care_setting,
            "timeline": self.timeline,
            # Copy: handing out the live list would let callers mutate session state.
            "red_flags": list(self.red_flags),
            "final_report": self.final_report,
            "report_summary": self.report_summary,
            # Without this, a session in the "error" state exports no failure reason.
            "error_message": self.error_message,
        }
class AgentCoordinator:
    """Coordinates agent execution and manages triage state.

    Sessions live in an in-memory dict keyed by session id. Each ``store_*``
    method copies one agent's results onto the matching ``TriageSession``.
    Calls that reference an unknown session id change nothing. They log a
    warning rather than silently discarding the data.
    """

    def __init__(self):
        # session_id -> TriageSession
        self.sessions: Dict[str, TriageSession] = {}
        logger.info("AgentCoordinator initialized")

    def _session_or_warn(self, session_id: str, action: str) -> Optional[TriageSession]:
        """Return the session for ``session_id``, or log a warning and return None.

        This is the shared guard for the mutating methods. Unknown ids stay a
        no-op, as before, but the dropped update is now visible in the logs.

        Args:
            session_id: Id of the session to look up.
            action: Short description of the attempted operation, used in the
                warning message.
        """
        session = self.sessions.get(session_id)
        if session is None:
            logger.warning(f"Cannot {action}: unknown session {session_id}")
        return session

    def create_session(self, session_id: str) -> TriageSession:
        """Create, register and return a fresh triage session.

        An existing session with the same id is replaced, as before. A warning
        is logged because the old session's data is lost.
        """
        if session_id in self.sessions:
            logger.warning(f"Replacing existing session: {session_id}")
        session = TriageSession(session_id=session_id)
        self.sessions[session_id] = session
        logger.info(f"Created session: {session_id}")
        return session

    def get_session(self, session_id: str) -> Optional[TriageSession]:
        """Return the session for ``session_id``, or None if it does not exist."""
        return self.sessions.get(session_id)

    def update_session_state(
        self,
        session_id: str,
        new_state: TriageState
    ) -> None:
        """Move a session to ``new_state`` and log the transition."""
        session = self._session_or_warn(session_id, f"set state {new_state.value}")
        if session is None:
            return
        old_state = session.state
        session.state = new_state
        logger.info(f"Session {session_id} state: {old_state.value} -> {new_state.value}")

    def store_intake_data(
        self,
        session_id: str,
        conversation_history: List[Dict[str, str]],
        case_summary: str
    ) -> None:
        """Store intake agent results (conversation transcript and case summary)."""
        session = self._session_or_warn(session_id, "store intake data")
        if session is None:
            return
        session.conversation_history = conversation_history
        session.case_summary = case_summary
        logger.debug(f"Stored intake data for session {session_id}")

    def store_symptom_data(
        self,
        session_id: str,
        symptom_analysis: str,
        primary_symptoms: List[str],
        differential_diagnosis: List[str],
        red_flags: List[str]
    ) -> None:
        """Store symptom assessment results on the session."""
        session = self._session_or_warn(session_id, "store symptom data")
        if session is None:
            return
        session.symptom_analysis = symptom_analysis
        session.primary_symptoms = primary_symptoms
        session.differential_diagnosis = differential_diagnosis
        session.red_flags = red_flags
        logger.debug(f"Stored symptom data for session {session_id}")

    def store_urgency_data(
        self,
        session_id: str,
        urgency_level: str,
        reasoning: str,
        confidence: str,
        time_sensitive: bool
    ) -> None:
        """Store urgency classification results.

        ``reasoning`` and ``confidence`` go to the session's
        ``urgency_reasoning`` and ``urgency_confidence`` fields.
        """
        session = self._session_or_warn(session_id, "store urgency data")
        if session is None:
            return
        session.urgency_level = urgency_level
        session.urgency_reasoning = reasoning
        session.urgency_confidence = confidence
        session.time_sensitive = time_sensitive
        logger.debug(f"Stored urgency data for session {session_id}: {urgency_level}")

    def store_care_data(
        self,
        session_id: str,
        care_setting: str,
        timeline: str,
        next_steps: List[str],
        self_care: List[str],
        preparation: List[str]
    ) -> None:
        """Store care recommendation results on the session."""
        session = self._session_or_warn(session_id, "store care data")
        if session is None:
            return
        session.care_setting = care_setting
        session.timeline = timeline
        session.next_steps = next_steps
        session.self_care = self_care
        session.preparation = preparation
        logger.debug(f"Stored care data for session {session_id}")

    def store_communication_data(
        self,
        session_id: str,
        report: str,
        summary: str,
        formatted_report: str
    ) -> None:
        """Store communication agent results.

        ``report`` and ``summary`` go to the session's ``final_report`` and
        ``report_summary`` fields.
        """
        session = self._session_or_warn(session_id, "store communication data")
        if session is None:
            return
        session.final_report = report
        session.report_summary = summary
        session.formatted_report = formatted_report
        logger.debug(f"Stored communication data for session {session_id}")

    def mark_completed(self, session_id: str) -> None:
        """Set the session to COMPLETED and stamp ``completed_at`` with the current time."""
        session = self._session_or_warn(session_id, "mark completed")
        if session is None:
            return
        session.state = TriageState.COMPLETED
        session.completed_at = datetime.now()
        logger.info(f"Session {session_id} completed")

    def mark_error(self, session_id: str, error_message: str) -> None:
        """Set the session to ERROR and record ``error_message`` on it.

        For an unknown session the message is still logged (as a warning), so
        the failure is not lost.
        """
        session = self._session_or_warn(session_id, f"record error ({error_message})")
        if session is None:
            return
        session.state = TriageState.ERROR
        session.error_message = error_message
        logger.error(f"Session {session_id} error: {error_message}")

    def get_session_data(self, session_id: str) -> Dict[str, Any]:
        """Return the session's ``to_dict()`` summary, or ``{}`` if the id is unknown.

        Note that ``to_dict()`` is a summary of selected fields, not every
        attribute of the session.
        """
        session = self.get_session(session_id)
        return session.to_dict() if session is not None else {}

    def cleanup_old_sessions(self, max_age_hours: float = 24) -> int:
        """Remove sessions older than ``max_age_hours`` and return how many were removed.

        Completed sessions are aged from ``completed_at``. All other sessions
        (unfinished or errored) are aged from ``started_at``. Fractional hours
        are accepted.
        """
        from datetime import timedelta

        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        # Collect ids first: deleting while iterating a dict raises RuntimeError.
        expired = [
            sid
            for sid, session in self.sessions.items()
            if (session.completed_at or session.started_at) < cutoff
        ]
        for sid in expired:
            del self.sessions[sid]

        if expired:
            logger.info(f"Cleaned up {len(expired)} old sessions")
        return len(expired)
# Names exported by ``from ... import *``.
__all__ = [
    "AgentCoordinator",
    "TriageSession",
    "TriageState",
]