Spaces:
Running
Running
| # app/ai/memory/redis_context_memory.py - Redis-Based Context Memory (FIXED) | |
| """ | |
| Conversation memory management with Redis: | |
| - Persists across server restarts | |
| - Fast in-memory access | |
| - Auto-expires conversations 7 days after their last write (TTL is reset on every save) | |
| - Stores: message history + conversation context | |
| """ | |
| import json | |
| from typing import Dict, List, Optional | |
| from datetime import datetime | |
| from structlog import get_logger | |
| logger = get_logger(__name__) | |
# Redis key namespaces. Full keys have the shape "<prefix>:<user_id>:<session_id>".
HISTORY_PREFIX = "aida:history"
CONTEXT_PREFIX = "aida:context"
# Expiry in seconds (7 days), applied with SETEX each time a key is written.
TTL = 7 * 24 * 60 * 60
# Module-level cache for the shared Redis client. It is filled lazily on first
# use because importing app.ai.config at module load time would be circular.
_redis_client = None


def _get_redis_client():
    """Return the shared Redis client, importing it from app.ai.config on first use.

    Returns None if the import fails. The cache then stays None, so the
    import is attempted again (and the warning logged again) on the next call.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    try:
        from app.ai.config import redis_client as imported_client
    except ImportError as e:
        logger.warning(f"β οΈ Failed to import redis_client: {e}")
        _redis_client = None
    else:
        _redis_client = imported_client
    return _redis_client
| # ========== REDIS CONVERSATION MEMORY ========== | |
class RedisConversationMemory:
    """
    Redis-based conversation memory for one (user_id, session_id) pair.

    Stores two JSON-encoded string values in Redis. Each is written with SETEX,
    so every write resets that key's TTL:
    - Message history: aida:history:{user_id}:{session_id} (JSON list of message dicts)
    - Context: aida:context:{user_id}:{session_id} (JSON object)

    Redis/JSON failures are logged rather than raised. Reads then fall back to
    an empty history or the default context. Writes deliberately do NOT use
    those fallbacks: if the current value cannot be read, the write is skipped,
    so a transient read failure cannot overwrite the stored data.
    """

    def __init__(self, user_id: str, session_id: str):
        self.user_id = user_id
        self.session_id = session_id
        self.history_key = f"{HISTORY_PREFIX}:{user_id}:{session_id}"
        self.context_key = f"{CONTEXT_PREFIX}:{user_id}:{session_id}"
        logger.info("πΎ RedisConversationMemory created", user_id=user_id, session_id=session_id)

    # ========== DEFAULTS / STRICT READS (Internal) ==========
    @staticmethod
    def _default_context() -> Dict:
        """Return a new default context dict (fresh nested objects on every call)."""
        return {
            "status": "idle",
            "language": "en",
            "user_role": None,
            "draft": None,
            "state": {},
        }

    async def _load_history(self, rc) -> List[Dict]:
        """Read the stored history list through client ``rc``.

        Returns [] when the key does not exist. Redis and JSON errors are NOT
        caught, so writers can abort instead of clobbering existing data.
        """
        raw = await rc.get(self.history_key)
        if raw is None:
            return []
        return json.loads(raw)

    async def _load_context(self, rc) -> Dict:
        """Read the stored context through client ``rc``, layered over the defaults.

        Keys absent from the stored object keep their default values. Redis
        and JSON errors are NOT caught (see ``_load_history``).
        """
        context = self._default_context()
        raw = await rc.get(self.context_key)
        if raw is not None:
            context.update(json.loads(raw))
        return context

    # ========== ADD MESSAGE ==========
    async def add_message(
        self,
        role: str,
        content: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Append a message to this session's history in Redis.

        The full history is read, extended and written back with SETEX, which
        also resets the key's TTL. This read-modify-write is not atomic, so
        concurrent writers to the same session can lose messages. If the
        existing history cannot be read, nothing is written. Errors are
        logged, not raised.

        Args:
            role: Message author; "user" is rendered as "User" by
                ``get_formatted_history``, anything else as "Aida".
            content: Message text.
            metadata: Optional extra data stored with the message (defaults to {}).
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                logger.warning("β οΈ Redis not available")
                return
            # Strict read: on failure we abort rather than replace the stored
            # history with a list holding only the new message.
            history = await self._load_history(rc)
            history.append({
                "role": role,
                "content": content,
                # naive UTC timestamp (utcnow() carries no offset)
                "timestamp": datetime.utcnow().isoformat(),
                "metadata": metadata or {},
            })
            await rc.setex(
                self.history_key,
                TTL,
                json.dumps(history, ensure_ascii=False),
            )
            logger.info(
                f"π Added {role} message",
                user_id=self.user_id,
                session_id=self.session_id,
                total_messages=len(history)
            )
        except Exception as e:
            logger.error("β Failed to add message", exc_info=e)

    # ========== GET MESSAGE HISTORY (Internal) ==========
    async def _get_history_list(self) -> List[Dict]:
        """Return the stored history list, or [] if Redis is unavailable or the read fails."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return []
            return await self._load_history(rc)
        except Exception as e:
            logger.error("β Failed to get history", exc_info=e)
            return []

    # ========== GET MESSAGE HISTORY ==========
    async def get_messages(
        self,
        limit: Optional[int] = None,
    ) -> List[Dict]:
        """
        Get message history, oldest first.

        Args:
            limit: If given, return at most the ``limit`` most recent messages
                (e.g. 10 -> last 10). 0 or a negative value returns []. None
                returns the full history.
        """
        history = await self._get_history_list()
        if limit is None:
            return history
        # Explicit check: a truthiness test would treat limit=0 as "no limit".
        if limit <= 0:
            return []
        return history[-limit:]

    # ========== GET FORMATTED HISTORY ==========
    async def get_formatted_history(self) -> str:
        """
        Get conversation history as newline-separated "Speaker: text" lines
        for LLM context. Role "user" becomes "User"; every other role "Aida".
        """
        messages = await self.get_messages()
        return "\n".join(
            f"{'User' if msg['role'] == 'user' else 'Aida'}: {msg['content']}"
            for msg in messages
        )

    # ========== GET CONTEXT ==========
    async def get_context(self) -> Dict:
        """Return the session context, layered over the defaults.

        Falls back to the default context (never an empty dict) when Redis is
        unavailable, nothing is stored, or the read fails. Callers can
        therefore rely on the default keys being present.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return self._default_context()
            return await self._load_context(rc)
        except Exception as e:
            logger.error("β Failed to get context", exc_info=e)
            return self._default_context()

    # ========== UPDATE CONTEXT ==========
    async def update_context(self, updates: Dict) -> None:
        """
        Shallow-merge ``updates`` into the stored context and save it with a fresh TTL.

        Values that are not JSON-serializable are stored via ``str()``
        (``default=str``). If the current context cannot be read, nothing is
        written. Errors are logged, not raised.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                logger.warning("β οΈ Redis not available")
                return
            # Strict read: on failure we abort rather than save defaults+updates
            # over the real stored context.
            context = await self._load_context(rc)
            context.update(updates)
            await rc.setex(
                self.context_key,
                TTL,
                json.dumps(context, ensure_ascii=False, default=str),
            )
            logger.info(
                "π Updated context",
                user_id=self.user_id,
                session_id=self.session_id,
                keys=list(updates.keys())
            )
        except Exception as e:
            logger.error("β Failed to update context", exc_info=e)

    # ========== GET SUMMARY ==========
    async def get_summary(self) -> Dict:
        """Get a small summary dict of the conversation ({} if building it fails)."""
        try:
            messages = await self.get_messages()
            context = await self.get_context()
            return {
                "user_id": self.user_id,
                "session_id": self.session_id,
                "total_messages": len(messages),
                "status": context.get("status", "idle"),
                "language": context.get("language", "en"),
                "has_draft": context.get("draft") is not None,
            }
        except Exception as e:
            logger.error("β Failed to get summary", exc_info=e)
            return {}

    # ========== CLEAR MEMORY ==========
    async def clear(self) -> None:
        """Delete this session's history and context keys (start a new chat)."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return
            # One multi-key DEL instead of two round trips.
            await rc.delete(self.history_key, self.context_key)
            logger.info("ποΈ Conversation memory cleared", user_id=self.user_id, session_id=self.session_id)
        except Exception as e:
            logger.error("β Failed to clear memory", exc_info=e)
| # ========== REDIS MEMORY MANAGER ========== | |
class RedisMemoryManager:
    """
    Global manager for Redis-based conversations.

    Holds no state of its own; Redis is the single source of truth, and
    RedisConversationMemory handles are built on demand. Methods that talk to
    Redis log failures instead of raising.
    """

    # ========== KEY HELPERS (Internal) ==========
    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so ``text`` matches literally.

        Without this, a user_id such as "*" would make the per-user KEYS
        patterns below match (and, in clear_user_sessions, delete) every
        user's keys.
        """
        return "".join("\\" + ch if ch in "\\*?[]" else ch for ch in text)

    @staticmethod
    def _user_id_from_key(key) -> Optional[str]:
        """Extract the user_id from an "aida:<kind>:<user_id>:<session_id>" key.

        Accepts str or bytes (clients that do not decode responses return
        bytes). Returns None when the key has fewer than three ":"-separated
        parts.
        """
        if isinstance(key, bytes):
            key = key.decode("utf-8", errors="replace")
        # Both prefixes occupy two segments ("aida", "history"/"context"),
        # so the user_id is the third segment.
        parts = key.split(":", 3)
        return parts[2] if len(parts) >= 3 else None

    # ========== CREATE/GET SESSION ==========
    async def get_or_create_session(
        self,
        user_id: str,
        session_id: str,
    ) -> "RedisConversationMemory":
        """
        Get or create conversation memory for a session.

        With Redis this just builds a handle; no Redis I/O happens here, and
        any existing data is read later by the handle's methods.
        """
        memory = RedisConversationMemory(user_id, session_id)
        logger.info("β Session memory ready", user_id=user_id, session_id=session_id)
        return memory

    # ========== CLOSE SESSION ==========
    async def close_session(self, user_id: str, session_id: str) -> None:
        """
        Explicitly clear a session's history and context.

        Optional: untouched sessions also expire on their own after the TTL (7 days).
        """
        try:
            memory = RedisConversationMemory(user_id, session_id)
            await memory.clear()
            logger.info("β Session closed", user_id=user_id, session_id=session_id)
        except Exception as e:
            logger.error("β Failed to close session", exc_info=e)

    # ========== GET USER SESSION HISTORY ==========
    async def get_user_history(self, user_id: str) -> List[Dict]:
        """
        Get the message histories of all of a user's sessions, concatenated.

        Sessions appear in the order KEYS returns them (no ordering guarantee);
        within a session, messages keep their stored order. A session whose
        value is not valid JSON, or is not a list, is skipped with a warning
        instead of discarding every other session. Returns [] if Redis is
        unavailable or the lookup fails.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return []
            pattern = f"{HISTORY_PREFIX}:{self._escape_glob(user_id)}:*"
            # NOTE(review): KEYS walks the entire keyspace and blocks Redis while
            # it runs; consider SCAN if the configured client supports it.
            keys = await rc.keys(pattern)
            all_messages = []
            for key in keys:
                raw = await rc.get(key)
                if not raw:
                    continue
                try:
                    messages = json.loads(raw)
                except ValueError:  # includes JSONDecodeError / UnicodeDecodeError
                    logger.warning("Skipping unreadable session history", user_id=user_id, key=key)
                    continue
                if not isinstance(messages, list):
                    logger.warning("Skipping non-list session history", user_id=user_id, key=key)
                    continue
                all_messages.extend(messages)
            logger.info("π Retrieved user history", user_id=user_id, total_messages=len(all_messages))
            return all_messages
        except Exception as e:
            logger.error("β Failed to get user history", exc_info=e)
            return []

    # ========== CLEAR ALL USER SESSIONS ==========
    async def clear_user_sessions(self, user_id: str) -> None:
        """
        Delete every history and context key of a user (e.g. on logout).
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return
            user_glob = self._escape_glob(user_id)
            # NOTE(review): KEYS is O(N) over the whole keyspace; see get_user_history.
            history_keys = await rc.keys(f"{HISTORY_PREFIX}:{user_glob}:*")
            context_keys = await rc.keys(f"{CONTEXT_PREFIX}:{user_glob}:*")
            all_keys = [*history_keys, *context_keys]
            if all_keys:
                await rc.delete(*all_keys)
            logger.info("ποΈ All user sessions cleared", user_id=user_id, sessions=len(history_keys))
        except Exception as e:
            logger.error("β Failed to clear user sessions", exc_info=e)

    # ========== GET REDIS STATS ==========
    async def get_stats(self) -> Dict:
        """Count history/context keys and distinct users across all sessions."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return {"redis_status": "not connected"}
            history_keys = await rc.keys(f"{HISTORY_PREFIX}:*")
            context_keys = await rc.keys(f"{CONTEXT_PREFIX}:*")
            unique_users = {
                uid
                for uid in map(self._user_id_from_key, [*history_keys, *context_keys])
                if uid is not None
            }
            return {
                "total_history_keys": len(history_keys),
                "total_context_keys": len(context_keys),
                "unique_users": len(unique_users),
                "redis_status": "connected",
            }
        except Exception as e:
            logger.error("β Failed to get stats", exc_info=e)
            return {"redis_status": "error"}
| # ========== SINGLETON INSTANCE ========== | |
# Process-wide manager instance, created on first request. RedisMemoryManager
# keeps no per-instance state, so sharing one instance is purely a convenience.
_memory_manager = None


def get_memory_manager() -> RedisMemoryManager:
    """Return the shared RedisMemoryManager, constructing it on the first call."""
    global _memory_manager
    manager = _memory_manager
    if manager is None:
        manager = _memory_manager = RedisMemoryManager()
    return manager
| # ========== HELPER: Get Current Memory ========== | |
async def get_current_memory(
    user_id: str,
    session_id: str,
) -> RedisConversationMemory:
    """
    Return the conversation memory handle for ``user_id`` / ``session_id``.

    Intended entry point for routes and services. Obtaining the handle does
    no Redis I/O; stored data is read lazily by the handle's methods.
    """
    return await get_memory_manager().get_or_create_session(user_id, session_id)