VoxDoc / app /services /memory /memory_service.py
joelthomas77's picture
Upload app code
60d4850 verified
"""
Cross-Session Memory Service — Redis cache + PostgreSQL persistence.
Provides session recovery (if a WebSocket disconnects mid-encounter)
and cross-encounter patient context (accumulated medical history).
Falls back to in-process dict cache if Redis/PostgreSQL are unavailable.
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from app.config import settings
logger = logging.getLogger(__name__)

# Namespaces prepended to ids to build the storage keys.
_SESSION_PREFIX = "session:"
_PATIENT_PREFIX = "patient:"

# Expiry windows, in seconds, passed to Redis when a value is written.
_SESSION_TTL = 4 * 60 * 60  # 4 hours
_PATIENT_TTL = 90 * 24 * 60 * 60  # 90 days
@dataclass
class SessionCache:
    """Snapshot of one session's state, kept so the session can be restored.

    Only ``session_id`` is required; every other field has a default.
    ``last_updated`` defaults to the time the instance is created.
    """

    session_id: str
    conversation_mode: str = "patient"
    state: str = "greeting"
    turn_count: int = 0
    transcript: str = ""
    chief_complaint: str = ""
    # Mutable containers use default_factory so instances never share them.
    entities: Dict[str, Any] = field(default_factory=dict)
    vitals: Dict[str, Any] = field(default_factory=dict)
    followup_qa: List[Dict[str, str]] = field(default_factory=list)
    detected_language: str = "en"
    detected_specialty: str = "general"
    soap_snapshot: Optional[Dict[str, str]] = None
    last_updated: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict with one entry per field, in declaration order.

        Values are the instance's own objects (no copies are made).
        """
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SessionCache":
        """Build an instance from *d*, ignoring keys that are not fields.

        Raises:
            TypeError: if *d* has no ``session_id`` entry (the only field
                without a default).
        """
        known = cls.__dataclass_fields__
        kwargs = {}
        for key, value in d.items():
            if key in known:
                kwargs[key] = value
        return cls(**kwargs)
@dataclass
class PatientMemory:
    """Per-patient record that is carried from one encounter to the next.

    Only ``patient_id`` is required; every other field has a default.
    ``last_updated`` defaults to the time the instance is created.
    """

    patient_id: str
    encounter_count: int = 0
    # Mutable containers use default_factory so instances never share them.
    known_conditions: List[Dict[str, str]] = field(default_factory=list)
    known_medications: List[Dict[str, str]] = field(default_factory=list)
    allergies: List[str] = field(default_factory=list)
    preferred_language: str = "en"
    communication_preferences: Dict[str, Any] = field(default_factory=dict)
    last_encounter_summary: str = ""
    last_encounter_date: Optional[str] = None
    last_updated: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict with one entry per field, in declaration order.

        Values are the instance's own objects (no copies are made).
        """
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PatientMemory":
        """Build an instance from *d*, ignoring keys that are not fields.

        Raises:
            TypeError: if *d* has no ``patient_id`` entry (the only field
                without a default).
        """
        known = cls.__dataclass_fields__
        kwargs = {}
        for key, value in d.items():
            if key in known:
                kwargs[key] = value
        return cls(**kwargs)
class MemoryService:
    """Session and patient memory backed by Redis, with an in-process fallback.

    Values are stored as JSON strings. While Redis is connected they are
    written there with a TTL; when redis-py is missing, the connection fails,
    or an individual Redis call raises, they are kept in per-instance dicts
    instead (no expiry, gone when the process exits).

    NOTE: the module docstring mentions PostgreSQL persistence; this class
    contains no PostgreSQL access.

    Usage:
        memory = get_memory_service()
        await memory.connect()

        # Save/restore session
        await memory.save_session(session_cache)
        restored = await memory.get_session("session-123")

        # Patient memory
        await memory.update_patient_after_encounter("patient-456", new_conditions=[...])
        patient = await memory.get_patient_memory("patient-456")
    """

    def __init__(self):
        self._redis = None
        self._use_redis = False
        self._connected = False
        # In-process fallback caches (storage key -> JSON string)
        self._session_cache: Dict[str, str] = {}
        self._patient_cache: Dict[str, str] = {}

    # -----------------------------------------------------------------
    # Connection management
    # -----------------------------------------------------------------
    async def connect(self) -> None:
        """Connect to Redis. Falls back to in-process cache.

        Never raises for a missing redis-py package or a failed connection;
        in both cases the service is marked connected and uses the fallback.
        """
        # A repeated connect() replaces the client, so release the old one.
        await self._close_client()
        try:
            import redis.asyncio as aioredis

            redis_url = getattr(settings, "redis_url", "redis://localhost:6379/0")
            self._redis = aioredis.from_url(
                redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()
            self._use_redis = True
            # The URL may embed a password; never write that to the log.
            logger.info("Connected to Redis at %s", self._redact_url(redis_url))
        except ImportError:
            logger.info(
                "redis-py not installed. Using in-process memory cache. "
                "Install with: pip install redis[hiredis]>=5.0.0"
            )
        except Exception as e:
            logger.warning("Redis connection failed (%s). Using in-process cache.", e)
            # Don't keep a half-open client (and its connection pool) around.
            await self._close_client()
        self._connected = True

    async def disconnect(self) -> None:
        """Disconnect from Redis and switch back to the in-process cache."""
        await self._close_client()
        self._connected = False

    async def _close_client(self) -> None:
        """Close and forget the Redis client, if any; errors are only logged."""
        client, self._redis = self._redis, None
        self._use_redis = False
        if client is None:
            return
        # ``aclose`` is the current redis-py name; older releases only
        # provide the awaitable ``close``.
        closer = getattr(client, "aclose", None) or getattr(client, "close", None)
        if closer is None:
            return
        try:
            await closer()
        except Exception as e:
            logger.debug("Error while closing Redis client: %s", e)

    def _redis_ready(self) -> bool:
        """Return True when a Redis client exists and passed its ping."""
        return self._use_redis and self._redis is not None

    @staticmethod
    def _redact_url(url: Any) -> str:
        """Return *url* as text with any ``user:password@`` part masked."""
        from urllib.parse import urlsplit, urlunsplit

        text = str(url)
        try:
            parts = urlsplit(text)
        except ValueError:
            return "<unparseable redis url>"
        if "@" not in parts.netloc:
            return text
        host = parts.netloc.rsplit("@", 1)[1]
        return urlunsplit(parts._replace(netloc=f"***@{host}"))

    # -----------------------------------------------------------------
    # Shared storage helpers
    # -----------------------------------------------------------------
    async def _store(
        self,
        key: str,
        ttl: int,
        payload: str,
        fallback: Dict[str, str],
        label: str,
    ) -> None:
        """Write *payload* to Redis with *ttl*, or to *fallback* on failure."""
        if self._redis_ready():
            try:
                await self._redis.setex(key, ttl, payload)
            except Exception as e:
                logger.warning("Redis %s save failed: %s", label, e)
            else:
                # Redis now holds the newest copy; an older fallback copy
                # would only be able to resurface as stale data.
                fallback.pop(key, None)
                return
        fallback[key] = payload

    async def _load(
        self,
        key: str,
        fallback: Dict[str, str],
        label: str,
        factory: Any,
    ) -> Any:
        """Read *key* from Redis, then from *fallback*; decode with *factory*.

        A Redis miss or error falls through to *fallback*, so a value that
        was saved there while Redis was failing can still be read once Redis
        is reachable again. Returns None when neither store has the key.
        """
        if self._redis_ready():
            try:
                raw = await self._redis.get(key)
                if raw:
                    return factory(json.loads(raw))
            except Exception as e:
                logger.warning("Redis %s get failed: %s", label, e)
        raw = fallback.get(key)
        if raw:
            return factory(json.loads(raw))
        return None

    # -----------------------------------------------------------------
    # Session Cache
    # -----------------------------------------------------------------
    async def save_session(self, session: "SessionCache") -> None:
        """Save session state for recovery.

        Stamps ``session.last_updated`` with the current time before writing.
        """
        session.last_updated = time.time()
        await self._store(
            f"{_SESSION_PREFIX}{session.session_id}",
            _SESSION_TTL,
            json.dumps(session.to_dict()),
            self._session_cache,
            "session",
        )

    async def get_session(self, session_id: str) -> "Optional[SessionCache]":
        """Restore a cached session, or return None if none is stored."""
        return await self._load(
            f"{_SESSION_PREFIX}{session_id}",
            self._session_cache,
            "session",
            SessionCache.from_dict,
        )

    async def delete_session(self, session_id: str) -> None:
        """Remove a session from Redis and from the in-process cache."""
        key = f"{_SESSION_PREFIX}{session_id}"
        if self._redis_ready():
            try:
                await self._redis.delete(key)
            except Exception as e:
                logger.warning("Redis session delete failed: %s", e)
        # Always clear the fallback too, so a deleted session cannot reappear.
        self._session_cache.pop(key, None)

    # -----------------------------------------------------------------
    # Patient Memory
    # -----------------------------------------------------------------
    async def get_patient_memory(self, patient_id: str) -> "Optional[PatientMemory]":
        """Get long-term patient memory, or None if none is stored."""
        return await self._load(
            f"{_PATIENT_PREFIX}{patient_id}",
            self._patient_cache,
            "patient memory",
            PatientMemory.from_dict,
        )

    async def save_patient_memory(self, memory: "PatientMemory") -> None:
        """Save patient memory.

        Stamps ``memory.last_updated`` with the current time before writing.
        """
        memory.last_updated = time.time()
        await self._store(
            f"{_PATIENT_PREFIX}{memory.patient_id}",
            _PATIENT_TTL,
            json.dumps(memory.to_dict()),
            self._patient_cache,
            "patient memory",
        )

    @staticmethod
    def _merge_by_text(
        known: List[Dict[str, str]],
        incoming: List[Dict[str, str]],
    ) -> None:
        """Append to *known* each *incoming* item whose text is not yet present.

        Items are compared by their ``"text"`` value, case-insensitively. The
        seen-set grows as items are appended, so duplicates inside *incoming*
        itself are also collapsed. Items with missing/empty text are only
        checked against what *known* held on entry, never against each other.
        """
        seen = {str(item.get("text") or "").lower() for item in known}
        for item in incoming:
            text_key = str(item.get("text") or "").lower()
            if text_key in seen:
                continue
            known.append(item)
            if text_key:
                seen.add(text_key)

    async def update_patient_after_encounter(
        self,
        patient_id: str,
        new_conditions: Optional[List[Dict[str, str]]] = None,
        new_medications: Optional[List[Dict[str, str]]] = None,
        encounter_summary: str = "",
        encounter_date: Optional[str] = None,
    ) -> "PatientMemory":
        """Update patient memory after an encounter concludes.

        Loads the stored memory (or starts a new one), increments
        ``encounter_count``, merges new conditions/medications while avoiding
        duplicates by text, overwrites the summary/date when a non-empty value
        is given, saves the result and returns it.
        """
        existing = await self.get_patient_memory(patient_id)
        if not existing:
            existing = PatientMemory(patient_id=patient_id)
        existing.encounter_count += 1

        if new_conditions:
            self._merge_by_text(existing.known_conditions, new_conditions)
        if new_medications:
            self._merge_by_text(existing.known_medications, new_medications)

        if encounter_summary:
            existing.last_encounter_summary = encounter_summary
        if encounter_date:
            existing.last_encounter_date = encounter_date

        await self.save_patient_memory(existing)
        return existing

    # -----------------------------------------------------------------
    # Bulk operations
    # -----------------------------------------------------------------
    async def get_active_sessions(self) -> List[str]:
        """List all active session IDs.

        Combines the IDs found in Redis (when available) with those held in
        the in-process cache; each ID appears once. If the Redis scan fails,
        only the in-process IDs are returned.
        """
        # Slice off the leading prefix only; str.replace would also remove
        # the prefix text if it happened to occur inside an ID.
        prefix_len = len(_SESSION_PREFIX)
        session_ids: List[str] = []
        if self._redis_ready():
            try:
                async for key in self._redis.scan_iter(f"{_SESSION_PREFIX}*"):
                    session_ids.append(key[prefix_len:])
            except Exception as e:
                logger.warning("Redis session scan failed: %s", e)
                session_ids = []  # discard a partial scan
        session_ids.extend(key[prefix_len:] for key in self._session_cache)
        # SCAN may yield a key more than once, and an ID can be in both
        # stores; dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(session_ids))
# Process-wide instance, created on first request.
_memory_service: Optional["MemoryService"] = None


def get_memory_service() -> "MemoryService":
    """Return the shared MemoryService, creating it on the first call.

    The returned service is only constructed here; it is not connected.
    """
    global _memory_service
    service = _memory_service
    if service is None:
        service = _memory_service = MemoryService()
    return service