text-to-3d / context_memory.py
jainarham's picture
Update context_memory.py
77e93f3 verified
"""
Context Memory Module
Manages session-based context for conversation continuity
"""
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import threading
import logging
logger = logging.getLogger(__name__)


class ContextMemory:
    """In-memory, lock-protected store of per-session context dictionaries.

    Every session has a context dict and a "last touched" timestamp. A
    session is considered expired once its timestamp is older than the
    configured timeout. When a *new* session would push the store past
    ``max_sessions``, expired sessions are purged first and, if that is not
    enough, the least recently touched sessions are evicted.

    All public methods take the internal lock; the private ``_``-prefixed
    helpers assume the caller already holds it.
    """

    def __init__(self, max_sessions: int = 1000, session_timeout_hours: int = 24):
        """
        Initialize the context memory store.

        Args:
            max_sessions: Maximum number of sessions to keep before eviction
                kicks in.
            session_timeout_hours: Hours of inactivity after which a session
                expires.
        """
        # session_id -> context dict
        self._store: Dict[str, Dict[str, Any]] = {}
        # session_id -> time of the last save/get/update
        self._timestamps: Dict[str, datetime] = {}
        self._lock = threading.Lock()
        self._max_sessions = max_sessions
        self._session_timeout = timedelta(hours=session_timeout_hours)

    def save_context(self, session_id: str, context: Dict[str, Any]) -> bool:
        """
        Save (or overwrite) the context for a session.

        Args:
            session_id: Unique session identifier.
            context: Context data to store. The dict is stored by reference,
                not copied.

        Returns:
            True if saved successfully.
        """
        with self._lock:
            # Overwriting an existing session does not grow the store, so
            # only a brand-new id may trigger cleanup/eviction. Otherwise a
            # plain overwrite at capacity would evict live sessions.
            is_new = session_id not in self._store
            if is_new and len(self._store) >= self._max_sessions:
                self._cleanup_old_sessions()
            self._store[session_id] = context
            self._timestamps[session_id] = datetime.now()
            logger.debug("Saved context for session %s", session_id)
            return True

    def get_context(self, session_id: str, default: Any = None) -> Optional[Dict[str, Any]]:
        """
        Get the context for a session and refresh its timestamp.

        Args:
            session_id: Unique session identifier.
            default: Value returned when the session is missing or expired.

        Returns:
            The stored context dict (the same object, not a copy), or
            ``default``. An expired session is removed as a side effect.
        """
        with self._lock:
            if session_id not in self._store:
                return default
            if self._is_expired(session_id):
                self._remove_session(session_id)
                return default
            # Accessing a session counts as activity.
            self._timestamps[session_id] = datetime.now()
            return self._store[session_id]

    def update_context(self, session_id: str, updates: Dict[str, Any]) -> bool:
        """
        Merge new data into an existing session's context.

        Args:
            session_id: Unique session identifier.
            updates: Data to merge into the existing context.

        Returns:
            True if the session existed, was still valid and was updated;
            False if it is missing or expired. An expired session is removed
            rather than revived, matching ``get_context``.
        """
        with self._lock:
            if session_id not in self._store:
                return False
            if self._is_expired(session_id):
                self._remove_session(session_id)
                return False
            self._store[session_id].update(updates)
            self._timestamps[session_id] = datetime.now()
            return True

    def clear_context(self, session_id: str) -> bool:
        """
        Clear the context for a session.

        Args:
            session_id: Unique session identifier.

        Returns:
            True if a session was removed, False if it did not exist.
        """
        with self._lock:
            return self._remove_session(session_id)

    def _is_expired(self, session_id: str) -> bool:
        """Check if a session has expired (a missing timestamp counts as expired)."""
        last_touched = self._timestamps.get(session_id)
        if last_touched is None:
            return True
        return datetime.now() - last_touched > self._session_timeout

    def _remove_session(self, session_id: str) -> bool:
        """Remove a session from storage; return whether it was present."""
        if session_id not in self._store:
            return False
        del self._store[session_id]
        # pop() rather than del: tolerate a missing timestamp the same way
        # _is_expired does instead of raising KeyError.
        self._timestamps.pop(session_id, None)
        logger.debug("Removed session %s", session_id)
        return True

    def _cleanup_old_sessions(self):
        """Remove expired sessions, then evict the oldest ones if still full."""
        # Materialize the list first: the store must not change size while
        # it is being iterated.
        expired = [sid for sid in self._store if self._is_expired(sid)]
        for sid in expired:
            self._remove_session(sid)

        # If still at capacity, remove the oldest 10% (at least one session).
        if len(self._store) >= self._max_sessions:
            remove_count = max(1, self._max_sessions // 10)
            oldest_first = sorted(
                self._timestamps,
                key=self._timestamps.__getitem__,
            )
            for sid in oldest_first[:remove_count]:
                self._remove_session(sid)

        logger.info("Cleanup complete. Active sessions: %d", len(self._store))

    def get_stats(self) -> Dict[str, Any]:
        """Get memory store statistics (session count, capacity, timeout in hours)."""
        with self._lock:
            return {
                "active_sessions": len(self._store),
                "max_sessions": self._max_sessions,
                "session_timeout_hours": self._session_timeout.total_seconds() / 3600,
            }