Spaces:
Sleeping
Sleeping
| """ | |
| Cross-Session Memory Service — Redis cache + PostgreSQL persistence. | |
| Provides session recovery (if a WebSocket disconnects mid-encounter) | |
| and cross-encounter patient context (accumulated medical history). | |
| Falls back to in-process dict cache if Redis/PostgreSQL are unavailable. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional | |
| from app.config import settings | |
logger = logging.getLogger(__name__)

# Key namespaces: every cache key is "<prefix><id>".
_SESSION_PREFIX = "session:"
_PATIENT_PREFIX = "patient:"

# Expiry times in seconds, passed to Redis SETEX.
_SESSION_TTL = 4 * 60 * 60  # 4 hours
_PATIENT_TTL = 90 * 24 * 60 * 60  # 90 days
@dataclass
class SessionCache:
    """Cached session state for recovery.

    Only ``session_id`` is required; every other field has a default so a
    fresh session can be created from its identifier alone. Mutable fields
    use ``default_factory`` so instances never share a dict or list.
    """

    session_id: str
    conversation_mode: str = "patient"
    state: str = "greeting"
    turn_count: int = 0
    transcript: str = ""
    chief_complaint: str = ""
    entities: Dict[str, Any] = field(default_factory=dict)
    vitals: Dict[str, Any] = field(default_factory=dict)
    followup_qa: List[Dict[str, str]] = field(default_factory=list)
    detected_language: str = "en"
    detected_specialty: str = "general"
    soap_snapshot: Optional[Dict[str, str]] = None
    # Unix timestamp (time.time()) taken when the instance is created.
    last_updated: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow ``{field name: value}`` mapping.

        Keys appear in field-declaration order. Values are the same objects
        held by the instance (no deep copy), so nested dicts/lists are shared
        with the returned mapping.
        """
        # Derive the keys from the declared fields so a newly added field can
        # never be forgotten here.
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SessionCache":
        """Build an instance from *d*, ignoring keys that are not fields.

        Raises:
            TypeError: if *d* lacks the required ``session_id`` key.
        """
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in d.items() if k in known})
@dataclass
class PatientMemory:
    """Long-term patient context across encounters.

    Only ``patient_id`` is required; every other field has a default.
    Mutable fields use ``default_factory`` so instances never share a list
    or dict.
    """

    patient_id: str
    encounter_count: int = 0
    known_conditions: List[Dict[str, str]] = field(default_factory=list)
    known_medications: List[Dict[str, str]] = field(default_factory=list)
    allergies: List[str] = field(default_factory=list)
    preferred_language: str = "en"
    communication_preferences: Dict[str, Any] = field(default_factory=dict)
    last_encounter_summary: str = ""
    last_encounter_date: Optional[str] = None
    # Unix timestamp (time.time()) taken when the instance is created.
    last_updated: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow ``{field name: value}`` mapping.

        Keys appear in field-declaration order. Values are the same objects
        held by the instance (no deep copy).
        """
        # Derive the keys from the declared fields so a newly added field can
        # never be forgotten here.
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PatientMemory":
        """Build an instance from *d*, ignoring keys that are not fields.

        Raises:
            TypeError: if *d* lacks the required ``patient_id`` key.
        """
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in d.items() if k in known})
class MemoryService:
    """Session and patient memory store: Redis with an in-process fallback.

    Values are JSON strings stored under ``_SESSION_PREFIX`` /
    ``_PATIENT_PREFIX`` keys. When Redis is unavailable (library missing,
    connection failure, or a failing command) the same JSON strings are kept
    in plain dicts on this instance instead. The fallback dicts have no
    expiry and are not shared between processes.

    NOTE(review): the module docstring mentions PostgreSQL persistence, but
    no PostgreSQL access is visible in this class - confirm where it lives.

    Usage:
        memory = get_memory_service()
        await memory.connect()
        # Save/restore session
        await memory.save_session(session_cache)
        restored = await memory.get_session("session-123")
        # Patient memory
        await memory.update_patient_after_encounter("patient-456", new_conditions=[...])
        patient = await memory.get_patient_memory("patient-456")
    """

    def __init__(self) -> None:
        self._redis: Any = None
        self._use_redis = False
        self._connected = False
        # In-process fallback caches (key -> JSON string)
        self._session_cache: Dict[str, str] = {}
        self._patient_cache: Dict[str, str] = {}

    # -----------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------
    @staticmethod
    def _redact_url(url: Any) -> str:
        """Return *url* as text with any ``user:password@`` part masked.

        Connection URLs may embed credentials, which must not reach the logs.
        """
        scheme, sep, rest = str(url).partition("://")
        netloc, slash, tail = rest.partition("/")
        _, at, host = netloc.rpartition("@")
        masked = "***@" if at else ""
        return f"{scheme}{sep}{masked}{host}{slash}{tail}"

    @staticmethod
    def _strip_prefix(key: str, prefix: str) -> str:
        """Remove *prefix* from the start of *key* only (never mid-string)."""
        return key[len(prefix):] if key.startswith(prefix) else key

    @staticmethod
    def _merge_by_text(known: List[Dict[str, str]], incoming: List[Dict[str, str]]) -> None:
        """Append to *known* each *incoming* item whose ``text`` is new.

        Comparison is case-insensitive. The seen-set is updated while
        merging, so duplicates inside *incoming* itself are also dropped.
        """
        seen = {item.get("text", "").lower() for item in known}
        for item in incoming:
            text = item.get("text", "").lower()
            if text not in seen:
                seen.add(text)
                known.append(item)

    async def _store(
        self, key: str, ttl: int, payload: str, fallback: Dict[str, str], what: str
    ) -> None:
        """Write *payload* to Redis with *ttl*; on failure keep it in *fallback*."""
        if self._use_redis and self._redis:
            try:
                await self._redis.setex(key, ttl, payload)
            except Exception as e:
                logger.warning(f"Redis {what} save failed: {e}")
            else:
                # Redis now holds the newest value; drop any older local copy
                # so a later Redis error cannot resurrect stale data.
                fallback.pop(key, None)
                return
        fallback[key] = payload

    async def _load(self, key: str, fallback: Dict[str, str], what: str, build: Any) -> Any:
        """Read *key* and decode it with *build*; ``None`` when absent.

        A Redis miss returns ``None`` without consulting *fallback*; the
        fallback is only read when Redis is not in use or the Redis read
        (including decoding) raised.
        """
        if self._use_redis and self._redis:
            try:
                raw = await self._redis.get(key)
                return build(json.loads(raw)) if raw else None
            except Exception as e:
                logger.warning(f"Redis {what} get failed: {e}")
        raw = fallback.get(key)
        return build(json.loads(raw)) if raw else None

    # -----------------------------------------------------------------
    # Connection lifecycle
    # -----------------------------------------------------------------
    async def connect(self) -> None:
        """Connect to Redis. Falls back to in-process cache.

        Never raises for a missing library or an unreachable server; in both
        cases the service is marked connected and uses the fallback dicts.
        """
        try:
            import redis.asyncio as aioredis

            redis_url = getattr(settings, "redis_url", "redis://localhost:6379/0")
            self._redis = aioredis.from_url(
                redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()
            self._use_redis = True
            logger.info(f"Connected to Redis at {self._redact_url(redis_url)}")
        except ImportError:
            self._redis = None
            self._use_redis = False
            logger.info(
                "redis-py not installed. Using in-process memory cache. "
                "Install with: pip install redis[hiredis]>=5.0.0"
            )
        except Exception as e:
            # Do not keep a client that failed its ping, and make sure a
            # failed reconnect does not leave Redis marked as usable.
            self._redis = None
            self._use_redis = False
            logger.warning(f"Redis connection failed ({e}). Using in-process cache.")
        self._connected = True

    async def disconnect(self) -> None:
        """Disconnect from Redis (best effort; close errors are only logged)."""
        if self._redis and self._use_redis:
            try:
                await self._redis.aclose()
            except Exception as e:
                logger.debug(f"Redis close failed: {e}")
        self._connected = False

    # -----------------------------------------------------------------
    # Session Cache
    # -----------------------------------------------------------------
    async def save_session(self, session: SessionCache) -> None:
        """Save session state for recovery.

        Stamps ``session.last_updated`` with the current time before writing.
        """
        session.last_updated = time.time()
        await self._store(
            f"{_SESSION_PREFIX}{session.session_id}",
            _SESSION_TTL,
            json.dumps(session.to_dict()),
            self._session_cache,
            "session",
        )

    async def get_session(self, session_id: str) -> Optional[SessionCache]:
        """Restore a cached session, or return ``None`` if none is stored."""
        return await self._load(
            f"{_SESSION_PREFIX}{session_id}",
            self._session_cache,
            "session",
            SessionCache.from_dict,
        )

    async def delete_session(self, session_id: str) -> None:
        """Remove a session from Redis and from the in-process fallback."""
        key = f"{_SESSION_PREFIX}{session_id}"
        # Always drop the local copy, even when Redis is healthy: an entry
        # written during an earlier outage must not outlive the delete.
        self._session_cache.pop(key, None)
        if self._use_redis and self._redis:
            try:
                await self._redis.delete(key)
            except Exception as e:
                logger.warning(f"Redis session delete failed: {e}")

    # -----------------------------------------------------------------
    # Patient Memory
    # -----------------------------------------------------------------
    async def get_patient_memory(self, patient_id: str) -> Optional[PatientMemory]:
        """Get long-term patient memory, or ``None`` if none is stored."""
        return await self._load(
            f"{_PATIENT_PREFIX}{patient_id}",
            self._patient_cache,
            "patient memory",
            PatientMemory.from_dict,
        )

    async def save_patient_memory(self, memory: PatientMemory) -> None:
        """Save patient memory.

        Stamps ``memory.last_updated`` with the current time before writing.
        """
        memory.last_updated = time.time()
        await self._store(
            f"{_PATIENT_PREFIX}{memory.patient_id}",
            _PATIENT_TTL,
            json.dumps(memory.to_dict()),
            self._patient_cache,
            "patient memory",
        )

    async def update_patient_after_encounter(
        self,
        patient_id: str,
        new_conditions: Optional[List[Dict[str, str]]] = None,
        new_medications: Optional[List[Dict[str, str]]] = None,
        encounter_summary: str = "",
        encounter_date: Optional[str] = None,
    ) -> PatientMemory:
        """Update patient memory after an encounter concludes.

        Loads the stored memory (or starts a new one), increments the
        encounter count, merges new conditions and medications while skipping
        entries whose ``text`` already exists (case-insensitive), overwrites
        the summary/date when non-empty values are given, then saves.

        Returns:
            The updated ``PatientMemory`` that was saved.
        """
        existing = await self.get_patient_memory(patient_id)
        if existing is None:
            existing = PatientMemory(patient_id=patient_id)
        existing.encounter_count += 1

        if new_conditions:
            self._merge_by_text(existing.known_conditions, new_conditions)
        if new_medications:
            self._merge_by_text(existing.known_medications, new_medications)

        if encounter_summary:
            existing.last_encounter_summary = encounter_summary
        if encounter_date:
            existing.last_encounter_date = encounter_date

        await self.save_patient_memory(existing)
        return existing

    # -----------------------------------------------------------------
    # Bulk operations
    # -----------------------------------------------------------------
    async def get_active_sessions(self) -> List[str]:
        """List all active session IDs.

        Scans Redis for session keys when Redis is in use; if it is not, or
        the scan raises, the keys of the in-process session cache are used.
        """
        if self._use_redis and self._redis:
            try:
                return [
                    self._strip_prefix(key, _SESSION_PREFIX)
                    async for key in self._redis.scan_iter(f"{_SESSION_PREFIX}*")
                ]
            except Exception as e:
                logger.warning(f"Redis session scan failed: {e}")
        return [self._strip_prefix(k, _SESSION_PREFIX) for k in self._session_cache]
# Process-wide singleton, created lazily on first request.
_memory_service: Optional[MemoryService] = None


def get_memory_service() -> MemoryService:
    """Return the shared ``MemoryService``, constructing it on first use.

    The instance is only constructed here; this function does not call
    ``connect()`` on it.
    """
    global _memory_service
    service = _memory_service
    if service is None:
        service = _memory_service = MemoryService()
    return service