# careflow/memory_agent.py
# Uploaded by omgy in commit 10fcca6 ("Upload 8 files", verified).
"""
Memory Agent (State Manager Agent) for CareFlow Nexus
Agent 1: Memorizes all hospital data and provides state queries
This agent is 50% rule-based (data queries, metrics) and 50% AI (analysis, bottleneck detection)
"""
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from base_agent import BaseAgent
from prompts.prompt_templates import StateManagerPrompts
from services.firebase_service import FirebaseService
from services.gemini_service import GeminiService
from utils.response_parser import ResponseParser
# Module-level logger named after this module. The MemoryAgent class below logs
# through self.logger instead (presumably provided by BaseAgent — not shown here).
logger = logging.getLogger(__name__)
class MemoryAgent(BaseAgent):
    """
    State Manager Agent - Memorizes and manages hospital state

    Responsibilities:
    - Load and cache all hospital data (beds, staff, patients, tasks)
    - Provide fast queries to other agents
    - Monitor system state in real-time
    - Detect bottlenecks and anomalies (AI-powered)
    - Generate state analysis reports (AI-powered)
    """

    # A staff member with this many (or more) current tasks is considered at
    # capacity: excluded from availability queries and reported as overloaded.
    MAX_STAFF_LOAD = 5

    # Bed filters compared against top-level bed fields.
    _BED_FIELD_FILTERS = ("ward", "floor")
    # Bed filters compared against the bed's nested "equipment" dict.
    _BED_EQUIPMENT_FILTERS = ("has_oxygen", "has_ventilator", "is_isolation")

    def __init__(
        self,
        firebase_service: FirebaseService,
        gemini_service: GeminiService,
        refresh_interval: int = 300,  # 5 minutes
    ):
        """
        Initialize Memory Agent

        Args:
            firebase_service: Firebase service instance
            gemini_service: Gemini AI service instance
            refresh_interval: How often to refresh state cache (seconds)
        """
        super().__init__(
            agent_id="memory_agent_001",
            agent_type="state_manager",
            firebase_service=firebase_service,
            gemini_service=gemini_service,
        )
        self.refresh_interval = refresh_interval
        # In-memory snapshot of hospital data; "last_refresh" is None until the
        # first successful refresh_state() call.
        self.state_cache = {
            "beds": [],
            "patients": [],
            "staff": [],
            "tasks": [],
            "last_refresh": None,
        }
        self.logger.info("Memory Agent initialized")

    async def initialize(self) -> bool:
        """
        Initialize agent by loading all hospital data

        Returns:
            True if the initial state load succeeded, False otherwise
        """
        try:
            self.logger.info("Initializing Memory Agent - loading hospital data...")
            await self.refresh_state()
            self.logger.info("Memory Agent initialization complete")
            return True
        except Exception as e:
            self.logger.error(f"Failed to initialize Memory Agent: {e}")
            return False

    async def process(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process incoming requests

        Args:
            request_data: Request with 'type' and optional parameters
                ('filters', 'patient_id', 'role', 'ward' depending on type)

        Returns:
            Response dictionary built by format_response(); unknown request
            types and a missing 'patient_id' yield an "invalid_request" error,
            unexpected exceptions yield a "processing_error" response.
        """
        try:
            request_type = request_data.get("type", "")

            # Auto-refresh if cache is stale
            await self._check_and_refresh()

            # Route request to appropriate handler
            if request_type == "get_available_beds":
                result = await self.get_available_beds(request_data.get("filters"))
                return self.format_response(True, result, "Available beds retrieved")

            elif request_type == "get_patient_requirements":
                patient_id = request_data.get("patient_id")
                if not patient_id:
                    # Reject early rather than querying Firebase with an empty id.
                    return self.format_response(
                        False,
                        None,
                        "Missing required field: patient_id",
                        "invalid_request",
                    )
                result = await self.get_patient_requirements(patient_id)
                return self.format_response(
                    True, result, "Patient requirements retrieved"
                )

            elif request_type == "get_staff_availability":
                role = request_data.get("role")
                ward = request_data.get("ward")
                result = await self.get_staff_availability(role, ward)
                return self.format_response(
                    True, result, "Staff availability retrieved"
                )

            elif request_type == "get_system_state":
                result = await self.get_system_state()
                return self.format_response(True, result, "System state retrieved")

            elif request_type == "analyze_state":
                result = await self.analyze_state_with_ai()
                return self.format_response(True, result, "State analysis complete")

            elif request_type == "detect_bottlenecks":
                result = await self.detect_bottlenecks()
                return self.format_response(
                    True, result, "Bottleneck detection complete"
                )

            elif request_type == "get_metrics":
                result = await self.get_metrics()
                return self.format_response(True, result, "Metrics retrieved")

            else:
                return self.format_response(
                    False,
                    None,
                    f"Unknown request type: {request_type}",
                    "invalid_request",
                )

        except Exception as e:
            self.logger.error(f"Error processing request: {e}")
            await self.log_error(str(e), request_data, "process_error")
            return self.format_response(False, None, str(e), "processing_error")

    # ==================== RULE-BASED METHODS (50%) ====================

    async def refresh_state(self) -> None:
        """
        Refresh entire hospital state from Firebase.

        The cache is only replaced after every collection has been fetched, so a
        failure part-way through keeps the previous snapshot (and its
        "last_refresh" timestamp) intact instead of mixing fresh and stale data.

        Raises:
            Any exception raised by the Firebase service is logged and re-raised.
        """
        try:
            self.logger.info("Refreshing hospital state cache...")

            # Loading in parallel would be faster; kept sequential for simplicity.
            # `or []` guards against a service returning None for an empty set,
            # since every consumer below iterates these collections.
            beds = await self.firebase.get_all_beds() or []
            patients = await self.firebase.get_all_patients() or []
            staff = await self.firebase.get_all_staff() or []
            tasks = (
                await self.firebase.get_tasks({"status": ["pending", "in_progress"]})
                or []
            )

            self.state_cache.update(
                beds=beds,
                patients=patients,
                staff=staff,
                tasks=tasks,
                last_refresh=datetime.now(),
            )

            self.logger.info(
                f"State refreshed - Beds: {len(beds)}, "
                f"Patients: {len(patients)}, "
                f"Staff: {len(staff)}, "
                f"Tasks: {len(tasks)}"
            )
        except Exception as e:
            self.logger.error(f"Error refreshing state: {e}")
            raise

    async def _check_and_refresh(self) -> None:
        """Refresh the cache if it was never loaded or is older than refresh_interval."""
        last_refresh = self.state_cache.get("last_refresh")

        if last_refresh is None:
            await self.refresh_state()
            return

        time_since_refresh = (datetime.now() - last_refresh).total_seconds()
        if time_since_refresh > self.refresh_interval:
            self.logger.info("Cache is stale, refreshing...")
            await self.refresh_state()

    async def get_available_beds(self, filters: Optional[Dict] = None) -> List[Dict]:
        """
        Get available beds with optional filters (RULE-BASED)

        Args:
            filters: Optional exact-match filters. "ward" and "floor" are compared
                against the bed record itself; "has_oxygen", "has_ventilator" and
                "is_isolation" against the bed's "equipment" dict. Unknown keys
                are ignored.

        Returns:
            List of cached beds whose status is "ready" and that match every
            supplied filter, in cache order.
        """
        beds = self.state_cache.get("beds", [])
        available = [b for b in beds if b.get("status") == "ready"]

        if not filters:
            return available

        filtered = available
        for key in self._BED_FIELD_FILTERS:
            if key in filters:
                wanted = filters[key]
                filtered = [b for b in filtered if b.get(key) == wanted]

        for key in self._BED_EQUIPMENT_FILTERS:
            if key in filters:
                wanted = filters[key]
                filtered = [
                    b for b in filtered if self._bed_equipment(b).get(key) == wanted
                ]

        self.logger.info(f"Found {len(filtered)} beds matching filters")
        return filtered

    async def get_patient_requirements(self, patient_id: str) -> Optional[Dict]:
        """
        Get patient requirements (RULE-BASED)

        Reads the patient directly from Firebase (not from the cache).

        Args:
            patient_id: Patient ID

        Returns:
            A new dict holding the patient's "requirements" plus "diagnosis",
            "severity" and "mobility_status" for context, or None if the patient
            does not exist.
        """
        patient = await self.firebase.get_patient(patient_id)

        if not patient:
            self.logger.warning(f"Patient {patient_id} not found")
            return None

        # Copy so the context keys added below do not leak into the patient
        # record returned by the service (which may be cached or shared), and
        # tolerate a null "requirements" field.
        requirements = dict(patient.get("requirements") or {})

        # Add diagnosis and severity for context
        requirements["diagnosis"] = patient.get("diagnosis", "")
        requirements["severity"] = patient.get("severity", "moderate")
        requirements["mobility_status"] = patient.get("mobility_status", "ambulatory")

        return requirements

    async def get_staff_availability(
        self, role: str, ward: Optional[str] = None
    ) -> List[Dict]:
        """
        Get available staff by role and optional ward (RULE-BASED)

        A staff member is available when their role matches, they are on shift,
        and their current_load is below MAX_STAFF_LOAD.

        Args:
            role: Staff role (nurse, cleaner, doctor)
            ward: Optional ward filter (matched against "assigned_ward")

        Returns:
            List of available staff, least busy first
        """
        staff = self.state_cache.get("staff", [])

        available = [
            s
            for s in staff
            if s.get("role") == role
            and s.get("is_on_shift", False)
            and s.get("current_load", 0) < self.MAX_STAFF_LOAD
        ]

        if ward:
            available = [s for s in available if s.get("assigned_ward") == ward]

        # Sort by workload (least busy first); sort is stable for ties.
        available.sort(key=lambda s: s.get("current_load", 0))

        self.logger.info(f"Found {len(available)} available {role}s")
        return available

    async def get_system_state(self) -> Dict[str, Any]:
        """
        Get complete system state snapshot (RULE-BASED)

        Records missing a "status"/"role" field are counted in the totals but
        not in any per-status/per-role bucket.

        Returns:
            Dict with "beds", "patients", "staff" and "tasks" count summaries
            plus an ISO-8601 "timestamp".
        """
        beds = self.state_cache.get("beds", [])
        patients = self.state_cache.get("patients", [])
        staff = self.state_cache.get("staff", [])
        tasks = self.state_cache.get("tasks", [])

        on_shift = [s for s in staff if s.get("is_on_shift")]

        return {
            "beds": {
                "total": len(beds),
                "available": self._count_where(beds, "status", "ready"),
                "occupied": self._count_where(beds, "status", "occupied"),
                "cleaning": self._count_where(beds, "status", "cleaning"),
                "maintenance": self._count_where(beds, "status", "maintenance"),
            },
            "patients": {
                "total": len(patients),
                "waiting": self._count_where(patients, "status", "waiting"),
                "admitted": self._count_where(patients, "status", "admitted"),
            },
            "staff": {
                "total": len(staff),
                "on_shift": len(on_shift),
                "nurses": self._count_where(on_shift, "role", "nurse"),
                "cleaners": self._count_where(on_shift, "role", "cleaner"),
            },
            "tasks": {
                "total": len(tasks),
                "pending": self._count_where(tasks, "status", "pending"),
                "in_progress": self._count_where(tasks, "status", "in_progress"),
            },
            "timestamp": datetime.now().isoformat(),
        }

    async def get_metrics(self) -> Dict[str, Any]:
        """
        Get operational metrics (RULE-BASED)

        Delegates directly to the Firebase service; does not use the cache.

        Returns:
            Metrics dictionary
        """
        return await self.firebase.get_metrics()

    # ==================== AI-POWERED METHODS (50%) ====================

    async def analyze_state_with_ai(self) -> Dict[str, Any]:
        """
        Use Gemini AI to analyze current hospital state (AI-POWERED)

        Returns:
            Parsed analysis (alerts, bottlenecks, forecast, recommendations), or
            the default analysis if the AI returns nothing or anything fails.
        """
        try:
            self.logger.info("Running AI state analysis...")

            # Prepare state summary
            state = await self.get_system_state()
            ward_summary = self._prepare_ward_summary()
            beds_state = state["beds"]

            # Build prompt
            prompt = StateManagerPrompts.STATE_ANALYSIS.format(
                total_beds=beds_state["total"],
                available_beds=beds_state["available"],
                occupied_beds=beds_state["occupied"],
                cleaning_beds=beds_state["cleaning"],
                maintenance_beds=beds_state["maintenance"],
                utilization_rate=round(
                    self._percentage(beds_state["occupied"], beds_state["total"]), 1
                ),
                total_patients=state["patients"]["total"],
                waiting_patients=state["patients"]["waiting"],
                admitted_patients=state["patients"]["admitted"],
                nurses_count=state["staff"]["nurses"],
                cleaners_count=state["staff"]["cleaners"],
                total_staff=state["staff"]["on_shift"],
                active_tasks=state["tasks"]["total"],
                pending_tasks=state["tasks"]["pending"],
                in_progress_tasks=state["tasks"]["in_progress"],
                overdue_tasks=0,  # TODO: Calculate overdue
                ward_summary=ward_summary,
            )

            # Call Gemini AI
            response = await self.gemini.generate_json_response(prompt, temperature=0.3)

            if not response:
                self.logger.warning("Empty response from AI, returning default")
                return self._default_analysis_response()

            parsed = ResponseParser.parse_state_analysis_response(response)

            # Log decision
            await self.log_decision(
                action="state_analysis",
                input_data={"state": state},
                output_data=parsed,
                reasoning="AI-powered state analysis completed",
            )

            return parsed

        except Exception as e:
            self.logger.error(f"Error in AI state analysis: {e}")
            return self._default_analysis_response()

    async def detect_bottlenecks(self) -> List[Dict[str, Any]]:
        """
        Detect operational bottlenecks (HYBRID: Rule-based + AI)

        The AI pass only runs when the rule-based pass found at least one
        bottleneck; its findings are appended after the rule-based ones.

        Returns:
            List of bottleneck dictionaries
        """
        bottlenecks = []

        # Rule-based detection
        rule_bottlenecks = await self._detect_bottlenecks_rule_based()
        bottlenecks.extend(rule_bottlenecks)

        # AI-enhanced detection
        if rule_bottlenecks:
            ai_analysis = await self._detect_bottlenecks_ai()
            if ai_analysis:
                bottlenecks.extend(ai_analysis)

        return bottlenecks

    async def _detect_bottlenecks_rule_based(self) -> List[Dict[str, Any]]:
        """
        Rule-based bottleneck detection over the cached state.

        Checks, in order: pending cleaning backlog (> 5 tasks), staff at or above
        MAX_STAFF_LOAD, and ready-bed availability (< 10% critical, < 20% high).
        """
        bottlenecks = []

        beds = self.state_cache.get("beds", [])
        staff = self.state_cache.get("staff", [])
        tasks = self.state_cache.get("tasks", [])

        # Check cleaning backlog
        cleaning_tasks = [
            t
            for t in tasks
            if t.get("task_type") == "cleaning" and t.get("status") == "pending"
        ]
        pending_cleaning = len(cleaning_tasks)

        if pending_cleaning > 5:
            if pending_cleaning > 10:
                severity = "critical"
            elif pending_cleaning > 7:
                severity = "high"
            else:
                severity = "medium"

            bottlenecks.append(
                {
                    "type": "cleaning_backlog",
                    "severity": severity,
                    "count": pending_cleaning,
                    "description": f"{pending_cleaning} cleaning tasks pending",
                    "recommendation": "Assign more cleaners or prioritize critical cleaning tasks",
                }
            )

        # Check staff overload
        overloaded_staff = [
            s for s in staff if s.get("current_load", 0) >= self.MAX_STAFF_LOAD
        ]
        if overloaded_staff:
            bottlenecks.append(
                {
                    "type": "staff_overload",
                    "severity": "high" if len(overloaded_staff) > 3 else "medium",
                    "count": len(overloaded_staff),
                    "description": f"{len(overloaded_staff)} staff members at maximum workload",
                    "recommendation": "Redistribute tasks or call additional staff",
                }
            )

        # Check bed capacity. With no beds at all the rate is 0 and the
        # critical alert fires, matching the previous behavior.
        available_beds = [b for b in beds if b.get("status") == "ready"]
        availability_rate = self._percentage(len(available_beds), len(beds))

        if availability_rate < 10:
            bottlenecks.append(
                {
                    "type": "critical_capacity",
                    "severity": "critical",
                    "count": len(available_beds),
                    "description": f"Only {len(available_beds)} beds available ({availability_rate:.1f}%)",
                    "recommendation": "Expedite discharges and cleaning tasks urgently",
                }
            )
        elif availability_rate < 20:
            bottlenecks.append(
                {
                    "type": "low_capacity",
                    "severity": "high",
                    "count": len(available_beds),
                    "description": f"Low bed availability: {len(available_beds)} beds ({availability_rate:.1f}%)",
                    "recommendation": "Monitor closely and prepare for capacity issues",
                }
            )

        return bottlenecks

    async def _detect_bottlenecks_ai(self) -> List[Dict[str, Any]]:
        """
        AI-powered bottleneck detection for complex patterns.

        Returns:
            The "bottlenecks" list from the AI response, or [] when the response
            is empty, lacks that key, holds a non-list value, or the call fails.
        """
        try:
            state = await self.get_system_state()
            response = await self.gemini.detect_bottlenecks(state)

            if not response or "bottlenecks" not in response:
                return []

            found = response["bottlenecks"]
            if not isinstance(found, list):
                # Callers extend() a list with this result; don't trust the shape.
                self.logger.warning(
                    f"Ignoring malformed AI bottlenecks payload of type {type(found).__name__}"
                )
                return []
            return found

        except Exception as e:
            self.logger.error(f"Error in AI bottleneck detection: {e}")
            return []

    def _prepare_ward_summary(self) -> str:
        """
        Prepare ward-level summary for prompts.

        Returns:
            One line per ward (in first-seen order) with ready/total beds and
            occupancy percentage, or a placeholder line when no beds are cached.
        """
        wards: Dict[Any, Dict[str, int]] = {}

        for bed in self.state_cache.get("beds", []):
            stats = wards.setdefault(
                bed.get("ward", "Unknown"), {"total": 0, "available": 0, "occupied": 0}
            )
            stats["total"] += 1

            status = bed.get("status")
            if status == "ready":
                stats["available"] += 1
            elif status == "occupied":
                stats["occupied"] += 1

        if not wards:
            return " No ward data available"

        return "\n".join(
            f" - {ward}: {stats['available']}/{stats['total']} available "
            f"({self._percentage(stats['occupied'], stats['total']):.0f}% occupied)"
            for ward, stats in wards.items()
        )

    @staticmethod
    def _bed_equipment(bed: Dict) -> Dict:
        """Return a bed's "equipment" dict, treating a missing or null value as empty."""
        return bed.get("equipment") or {}

    @staticmethod
    def _count_where(items: List[Dict], key: str, value: Any) -> int:
        """Count records in ``items`` whose ``key`` field equals ``value``."""
        return sum(1 for item in items if item.get(key) == value)

    @staticmethod
    def _percentage(part: float, whole: float) -> float:
        """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is not positive."""
        return part / whole * 100 if whole > 0 else 0.0

    def _default_analysis_response(self) -> Dict[str, Any]:
        """Default response when AI fails"""
        return {
            "critical_alerts": [],
            "bottlenecks": [],
            "capacity_forecast": {
                "next_4_hours": "Unable to generate forecast",
                "bed_availability_trend": "stable",
                "staffing_adequacy": "unknown",
            },
            "recommendations": ["Refresh data and try again"],
        }

    def get_capabilities(self) -> List[str]:
        """Get agent capabilities (the request types accepted by process())"""
        return [
            "get_available_beds",
            "get_patient_requirements",
            "get_staff_availability",
            "get_system_state",
            "analyze_state",
            "detect_bottlenecks",
            "get_metrics",
        ]