|
|
"""
|
|
|
Memory Agent (State Manager Agent) for CareFlow Nexus
|
|
|
Agent 1: Memorizes all hospital data and provides state queries
|
|
|
|
|
|
This agent is 50% rule-based (data queries, metrics) and 50% AI (analysis, bottleneck detection)
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
from base_agent import BaseAgent
|
|
|
from prompts.prompt_templates import StateManagerPrompts
|
|
|
from services.firebase_service import FirebaseService
|
|
|
from services.gemini_service import GeminiService
|
|
|
from utils.response_parser import ResponseParser
|
|
|
|
|
|
# Module-level logger; note the agent methods log via self.logger (provided by
# BaseAgent), so this is primarily for module-scope diagnostics.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class MemoryAgent(BaseAgent):
    """
    State Manager Agent - Memorizes and manages hospital state

    Responsibilities:
    - Load and cache all hospital data (beds, staff, patients, tasks)
    - Provide fast queries to other agents
    - Monitor system state in real-time
    - Detect bottlenecks and anomalies (AI-powered)
    - Generate state analysis reports (AI-powered)
    """

    # Staff with current_load at or above this value are considered fully
    # loaded: they are excluded from availability queries and counted as
    # overloaded by bottleneck detection. (Previously duplicated as a magic 5.)
    MAX_STAFF_LOAD: int = 5

    def __init__(
        self,
        firebase_service: FirebaseService,
        gemini_service: GeminiService,
        refresh_interval: int = 300,
    ):
        """
        Initialize Memory Agent

        Args:
            firebase_service: Firebase service instance
            gemini_service: Gemini AI service instance
            refresh_interval: How often to refresh state cache (seconds)
        """
        super().__init__(
            agent_id="memory_agent_001",
            agent_type="state_manager",
            firebase_service=firebase_service,
            gemini_service=gemini_service,
        )

        self.refresh_interval = refresh_interval
        # In-memory snapshot of hospital data, refreshed lazily by
        # _check_and_refresh() whenever it is older than refresh_interval.
        self.state_cache = {
            "beds": [],
            "patients": [],
            "staff": [],
            "tasks": [],
            "last_refresh": None,
        }

        self.logger.info("Memory Agent initialized")

    async def initialize(self) -> bool:
        """
        Initialize agent by loading all hospital data

        Returns:
            True if successful
        """
        try:
            self.logger.info("Initializing Memory Agent - loading hospital data...")
            await self.refresh_state()
            self.logger.info("Memory Agent initialization complete")
            return True
        except Exception as e:
            self.logger.error(f"Failed to initialize Memory Agent: {e}")
            return False

    async def process(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process incoming requests

        Args:
            request_data: Request with 'type' and optional parameters

        Returns:
            Response dictionary
        """
        try:
            request_type = request_data.get("type", "")

            # Make sure the cache is fresh before answering any query.
            await self._check_and_refresh()

            if request_type == "get_available_beds":
                result = await self.get_available_beds(request_data.get("filters"))
                return self.format_response(True, result, "Available beds retrieved")

            elif request_type == "get_patient_requirements":
                patient_id = request_data.get("patient_id")
                # Guard: a missing patient_id would otherwise be forwarded to
                # Firebase as None and fail much less clearly downstream.
                if not patient_id:
                    return self.format_response(
                        False,
                        None,
                        "Missing 'patient_id' in request",
                        "invalid_request",
                    )
                result = await self.get_patient_requirements(patient_id)
                return self.format_response(
                    True, result, "Patient requirements retrieved"
                )

            elif request_type == "get_staff_availability":
                role = request_data.get("role")
                ward = request_data.get("ward")
                result = await self.get_staff_availability(role, ward)
                return self.format_response(
                    True, result, "Staff availability retrieved"
                )

            elif request_type == "get_system_state":
                result = await self.get_system_state()
                return self.format_response(True, result, "System state retrieved")

            elif request_type == "analyze_state":
                result = await self.analyze_state_with_ai()
                return self.format_response(True, result, "State analysis complete")

            elif request_type == "detect_bottlenecks":
                result = await self.detect_bottlenecks()
                return self.format_response(
                    True, result, "Bottleneck detection complete"
                )

            elif request_type == "get_metrics":
                result = await self.get_metrics()
                return self.format_response(True, result, "Metrics retrieved")

            else:
                return self.format_response(
                    False,
                    None,
                    f"Unknown request type: {request_type}",
                    "invalid_request",
                )

        except Exception as e:
            self.logger.error(f"Error processing request: {e}")
            await self.log_error(str(e), request_data, "process_error")
            return self.format_response(False, None, str(e), "processing_error")

    async def refresh_state(self) -> None:
        """Refresh entire hospital state from Firebase"""
        try:
            self.logger.info("Refreshing hospital state cache...")

            self.state_cache["beds"] = await self.firebase.get_all_beds()
            self.state_cache["patients"] = await self.firebase.get_all_patients()
            self.state_cache["staff"] = await self.firebase.get_all_staff()
            # Only active tasks are cached; completed/cancelled tasks are not
            # relevant to live state queries.
            self.state_cache["tasks"] = await self.firebase.get_tasks(
                {"status": ["pending", "in_progress"]}
            )
            self.state_cache["last_refresh"] = datetime.now()

            self.logger.info(
                f"State refreshed - Beds: {len(self.state_cache['beds'])}, "
                f"Patients: {len(self.state_cache['patients'])}, "
                f"Staff: {len(self.state_cache['staff'])}, "
                f"Tasks: {len(self.state_cache['tasks'])}"
            )

        except Exception as e:
            self.logger.error(f"Error refreshing state: {e}")
            raise

    async def _check_and_refresh(self) -> None:
        """Check if cache is stale and refresh if needed"""
        last_refresh = self.state_cache.get("last_refresh")

        if last_refresh is None:
            await self.refresh_state()
            return

        time_since_refresh = (datetime.now() - last_refresh).total_seconds()

        if time_since_refresh > self.refresh_interval:
            self.logger.info("Cache is stale, refreshing...")
            await self.refresh_state()

    async def get_available_beds(self, filters: Optional[Dict] = None) -> List[Dict]:
        """
        Get available beds with optional filters (RULE-BASED)

        Args:
            filters: Optional filters like ward, has_oxygen, etc.

        Returns:
            List of available bed dictionaries
        """
        beds = self.state_cache.get("beds", [])
        available = [b for b in beds if b.get("status") == "ready"]

        if not filters:
            return available

        filtered = available

        if "ward" in filters:
            filtered = [b for b in filtered if b.get("ward") == filters["ward"]]

        # Equipment filters all compare the same nested field; apply them in
        # one loop instead of three copy-pasted list comprehensions.
        for equipment_key in ("has_oxygen", "has_ventilator", "is_isolation"):
            if equipment_key in filters:
                filtered = [
                    b
                    for b in filtered
                    if b.get("equipment", {}).get(equipment_key)
                    == filters[equipment_key]
                ]

        if "floor" in filters:
            filtered = [b for b in filtered if b.get("floor") == filters["floor"]]

        self.logger.info(f"Found {len(filtered)} beds matching filters")
        return filtered

    async def get_patient_requirements(self, patient_id: str) -> Optional[Dict]:
        """
        Get patient requirements (RULE-BASED)

        Args:
            patient_id: Patient ID

        Returns:
            Patient requirements dictionary (a fresh copy, augmented with
            diagnosis/severity/mobility_status) or None if patient not found
        """
        patient = await self.firebase.get_patient(patient_id)

        if not patient:
            self.logger.warning(f"Patient {patient_id} not found")
            return None

        # Copy before augmenting: the original code mutated the dict nested in
        # the Firebase-returned patient document, silently corrupting any
        # cached/shared patient data with the extra keys added below.
        requirements = dict(patient.get("requirements", {}))

        requirements["diagnosis"] = patient.get("diagnosis", "")
        requirements["severity"] = patient.get("severity", "moderate")
        requirements["mobility_status"] = patient.get("mobility_status", "ambulatory")

        return requirements

    async def get_staff_availability(
        self, role: str, ward: Optional[str] = None
    ) -> List[Dict]:
        """
        Get available staff by role and optional ward (RULE-BASED)

        Args:
            role: Staff role (nurse, cleaner, doctor)
            ward: Optional ward filter

        Returns:
            List of available staff, sorted by ascending workload
        """
        staff = self.state_cache.get("staff", [])

        available = [
            s
            for s in staff
            if s.get("role") == role
            and s.get("is_on_shift", False)
            and s.get("current_load", 0) < self.MAX_STAFF_LOAD
        ]

        if ward:
            available = [s for s in available if s.get("assigned_ward") == ward]

        # Least-loaded staff first so callers can pick from the front.
        available.sort(key=lambda x: x.get("current_load", 0))

        self.logger.info(f"Found {len(available)} available {role}s")
        return available

    async def get_system_state(self) -> Dict[str, Any]:
        """
        Get complete system state snapshot (RULE-BASED)

        Returns:
            System state dictionary with bed/patient/staff/task counts
        """
        beds = self.state_cache.get("beds", [])
        patients = self.state_cache.get("patients", [])
        staff = self.state_cache.get("staff", [])
        tasks = self.state_cache.get("tasks", [])

        # .get() throughout: a document missing "status"/"role" must not
        # crash the snapshot (the rest of this class is equally defensive).
        return {
            "beds": {
                "total": len(beds),
                "available": len([b for b in beds if b.get("status") == "ready"]),
                "occupied": len([b for b in beds if b.get("status") == "occupied"]),
                "cleaning": len([b for b in beds if b.get("status") == "cleaning"]),
                "maintenance": len(
                    [b for b in beds if b.get("status") == "maintenance"]
                ),
            },
            "patients": {
                "total": len(patients),
                "waiting": len([p for p in patients if p.get("status") == "waiting"]),
                "admitted": len(
                    [p for p in patients if p.get("status") == "admitted"]
                ),
            },
            "staff": {
                "total": len(staff),
                "on_shift": len([s for s in staff if s.get("is_on_shift")]),
                "nurses": len(
                    [
                        s
                        for s in staff
                        if s.get("role") == "nurse" and s.get("is_on_shift")
                    ]
                ),
                "cleaners": len(
                    [
                        s
                        for s in staff
                        if s.get("role") == "cleaner" and s.get("is_on_shift")
                    ]
                ),
            },
            "tasks": {
                "total": len(tasks),
                "pending": len([t for t in tasks if t.get("status") == "pending"]),
                "in_progress": len(
                    [t for t in tasks if t.get("status") == "in_progress"]
                ),
            },
            "timestamp": datetime.now().isoformat(),
        }

    async def get_metrics(self) -> Dict[str, Any]:
        """
        Get operational metrics (RULE-BASED)

        Returns:
            Metrics dictionary
        """
        return await self.firebase.get_metrics()

    async def analyze_state_with_ai(self) -> Dict[str, Any]:
        """
        Use Gemini AI to analyze current hospital state (AI-POWERED)

        Returns:
            Analysis with alerts, bottlenecks, forecast, recommendations;
            a safe default structure if the AI call fails or returns nothing
        """
        try:
            self.logger.info("Running AI state analysis...")

            state = await self.get_system_state()
            ward_summary = self._prepare_ward_summary()

            prompt = StateManagerPrompts.STATE_ANALYSIS.format(
                total_beds=state["beds"]["total"],
                available_beds=state["beds"]["available"],
                occupied_beds=state["beds"]["occupied"],
                cleaning_beds=state["beds"]["cleaning"],
                maintenance_beds=state["beds"]["maintenance"],
                utilization_rate=round(
                    (state["beds"]["occupied"] / state["beds"]["total"] * 100)
                    if state["beds"]["total"] > 0
                    else 0,
                    1,
                ),
                total_patients=state["patients"]["total"],
                waiting_patients=state["patients"]["waiting"],
                admitted_patients=state["patients"]["admitted"],
                nurses_count=state["staff"]["nurses"],
                cleaners_count=state["staff"]["cleaners"],
                total_staff=state["staff"]["on_shift"],
                active_tasks=state["tasks"]["total"],
                pending_tasks=state["tasks"]["pending"],
                in_progress_tasks=state["tasks"]["in_progress"],
                # Overdue tracking not implemented yet; prompt expects the field.
                overdue_tasks=0,
                ward_summary=ward_summary,
            )

            # Low temperature: analysis should be deterministic-ish, not creative.
            response = await self.gemini.generate_json_response(prompt, temperature=0.3)

            if response:
                parsed = ResponseParser.parse_state_analysis_response(response)

                await self.log_decision(
                    action="state_analysis",
                    input_data={"state": state},
                    output_data=parsed,
                    reasoning="AI-powered state analysis completed",
                )

                return parsed
            else:
                self.logger.warning("Empty response from AI, returning default")
                return self._default_analysis_response()

        except Exception as e:
            self.logger.error(f"Error in AI state analysis: {e}")
            return self._default_analysis_response()

    async def detect_bottlenecks(self) -> List[Dict[str, Any]]:
        """
        Detect operational bottlenecks (HYBRID: Rule-based + AI)

        Returns:
            List of bottleneck dictionaries
        """
        bottlenecks = []

        rule_bottlenecks = await self._detect_bottlenecks_rule_based()
        bottlenecks.extend(rule_bottlenecks)

        # AI pass only runs when rules already flagged something: it refines
        # known problems rather than scanning a healthy system on every call.
        if rule_bottlenecks:
            ai_analysis = await self._detect_bottlenecks_ai()
            if ai_analysis:
                bottlenecks.extend(ai_analysis)

        return bottlenecks

    async def _detect_bottlenecks_rule_based(self) -> List[Dict[str, Any]]:
        """Rule-based bottleneck detection (cleaning backlog, staff overload,
        low bed capacity)."""
        bottlenecks = []

        beds = self.state_cache.get("beds", [])
        staff = self.state_cache.get("staff", [])
        tasks = self.state_cache.get("tasks", [])

        # --- Cleaning backlog: >5 pending cleaning tasks -------------------
        cleaning_tasks = [
            t
            for t in tasks
            if t.get("task_type") == "cleaning" and t.get("status") == "pending"
        ]
        if len(cleaning_tasks) > 5:
            severity = (
                "critical"
                if len(cleaning_tasks) > 10
                else "high"
                if len(cleaning_tasks) > 7
                else "medium"
            )
            bottlenecks.append(
                {
                    "type": "cleaning_backlog",
                    "severity": severity,
                    "count": len(cleaning_tasks),
                    "description": f"{len(cleaning_tasks)} cleaning tasks pending",
                    "recommendation": "Assign more cleaners or prioritize critical cleaning tasks",
                }
            )

        # --- Staff overload: anyone at or above the load ceiling -----------
        overloaded_staff = [
            s for s in staff if s.get("current_load", 0) >= self.MAX_STAFF_LOAD
        ]
        if overloaded_staff:
            bottlenecks.append(
                {
                    "type": "staff_overload",
                    "severity": "high" if len(overloaded_staff) > 3 else "medium",
                    "count": len(overloaded_staff),
                    "description": f"{len(overloaded_staff)} staff members at maximum workload",
                    "recommendation": "Redistribute tasks or call additional staff",
                }
            )

        # --- Bed capacity: <10% available is critical, <20% is a warning ---
        available_beds = [b for b in beds if b.get("status") == "ready"]
        total_beds = len(beds)
        availability_rate = (
            (len(available_beds) / total_beds * 100) if total_beds > 0 else 0
        )

        if availability_rate < 10:
            bottlenecks.append(
                {
                    "type": "critical_capacity",
                    "severity": "critical",
                    "count": len(available_beds),
                    "description": f"Only {len(available_beds)} beds available ({availability_rate:.1f}%)",
                    "recommendation": "Expedite discharges and cleaning tasks urgently",
                }
            )
        elif availability_rate < 20:
            bottlenecks.append(
                {
                    "type": "low_capacity",
                    "severity": "high",
                    "count": len(available_beds),
                    "description": f"Low bed availability: {len(available_beds)} beds ({availability_rate:.1f}%)",
                    "recommendation": "Monitor closely and prepare for capacity issues",
                }
            )

        return bottlenecks

    async def _detect_bottlenecks_ai(self) -> List[Dict[str, Any]]:
        """AI-powered bottleneck detection for complex patterns"""
        try:
            state = await self.get_system_state()
            response = await self.gemini.detect_bottlenecks(state)

            if response and "bottlenecks" in response:
                return response["bottlenecks"]
            return []

        except Exception as e:
            self.logger.error(f"Error in AI bottleneck detection: {e}")
            return []

    def _prepare_ward_summary(self) -> str:
        """Prepare ward-level summary for prompts"""
        beds = self.state_cache.get("beds", [])

        # Aggregate bed counts per ward.
        wards = {}
        for bed in beds:
            ward = bed.get("ward", "Unknown")
            if ward not in wards:
                wards[ward] = {"total": 0, "available": 0, "occupied": 0}

            wards[ward]["total"] += 1
            if bed.get("status") == "ready":
                wards[ward]["available"] += 1
            elif bed.get("status") == "occupied":
                wards[ward]["occupied"] += 1

        summary_lines = []
        for ward, stats in wards.items():
            occupancy = (
                (stats["occupied"] / stats["total"] * 100) if stats["total"] > 0 else 0
            )
            summary_lines.append(
                f"  - {ward}: {stats['available']}/{stats['total']} available ({occupancy:.0f}% occupied)"
            )

        return "\n".join(summary_lines) if summary_lines else "  No ward data available"

    def _default_analysis_response(self) -> Dict[str, Any]:
        """Default response when AI fails"""
        return {
            "critical_alerts": [],
            "bottlenecks": [],
            "capacity_forecast": {
                "next_4_hours": "Unable to generate forecast",
                "bed_availability_trend": "stable",
                "staffing_adequacy": "unknown",
            },
            "recommendations": ["Refresh data and try again"],
        }

    def get_capabilities(self) -> List[str]:
        """Get agent capabilities"""
        return [
            "get_available_beds",
            "get_patient_requirements",
            "get_staff_availability",
            "get_system_state",
            "analyze_state",
            "detect_bottlenecks",
            "get_metrics",
        ]
|
|
|
|