# AIDA / app/ai/memory/redis_context_memory.py
# Last commit by destinyebuka: "updated" (bdbd9f4)
# app/ai/memory/redis_context_memory.py - Redis-Based Context Memory (FIXED)
"""
Conversation memory management with Redis:
- Persists across server restarts
- Fast in-memory access
- Auto-expires old conversations (7 days)
- Stores: message history + conversation context
"""
import json
from typing import Dict, List, Optional
from datetime import datetime
from structlog import get_logger
# Module-wide structlog logger; call sites in this file attach context as
# keyword arguments (e.g. user_id=..., session_id=...).
logger = get_logger(__name__)
# Redis key namespaces; full keys are "<prefix>:<user_id>:<session_id>".
HISTORY_PREFIX: str = "aida:history"
CONTEXT_PREFIX: str = "aida:context"

# Expiry (in seconds) applied on every write: 7 days.
TTL: int = 7 * 24 * 60 * 60

# Cached Redis client, filled in lazily by _get_redis_client() so that
# importing this module does not trigger a circular import of app.ai.config.
_redis_client = None
def _get_redis_client():
    """
    Return the shared Redis client, importing it on first use.

    The import of ``app.ai.config`` is deferred to call time to avoid a
    circular import. If that import fails, a warning is logged and ``None`` is
    returned. Because the cache stays ``None`` in that case, the import is
    attempted again on the next call.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    try:
        from app.ai.config import redis_client as imported_client
    except ImportError as e:
        logger.warning(f"⚠️ Failed to import redis_client: {e}")
        _redis_client = None
    else:
        _redis_client = imported_client
    return _redis_client
# ========== REDIS CONVERSATION MEMORY ==========
class RedisConversationMemory:
    """
    Redis-based conversation memory for a single (user_id, session_id) pair.

    Stored in Redis as JSON strings, each (re)written with the module TTL:
    - Message history: aida:history:{user_id}:{session_id}  (JSON list)
    - Context: aida:context:{user_id}:{session_id}  (JSON object)

    All public methods are best-effort: Redis errors are logged, not raised.
    Readers fall back to an empty history or the default context.

    The writers (add_message, update_context) do not use those
    error-swallowing readers. They read the current value directly, so a
    failed read aborts the write instead of overwriting stored data with a
    near-empty value. Stored JSON that cannot be parsed is discarded
    (logged) and replaced.

    NOTE(review): writers are read-modify-write without a transaction, so two
    concurrent writers on the same session can lose one another's update.
    """

    def __init__(self, user_id: str, session_id: str):
        self.user_id = user_id
        self.session_id = session_id
        self.history_key = f"{HISTORY_PREFIX}:{user_id}:{session_id}"
        self.context_key = f"{CONTEXT_PREFIX}:{user_id}:{session_id}"
        logger.info("πŸ’Ύ RedisConversationMemory created", user_id=user_id, session_id=session_id)

    # ========== DECODING HELPERS ==========
    @staticmethod
    def _default_context() -> Dict:
        """Return a fresh default context (a new dict on every call, safe to mutate)."""
        return {
            "status": "idle",
            "language": "en",
            "user_role": None,
            "draft": None,
            "state": {},
        }

    @staticmethod
    def _decode_history(raw) -> List[Dict]:
        """
        Parse a stored history payload (str or bytes).

        Returns [] when the payload is missing, is not valid JSON, or is not
        a JSON list.
        """
        if raw is None:
            return []
        try:
            history = json.loads(raw)
        except (TypeError, ValueError) as e:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
            logger.warning("⚠️ Discarding unreadable history payload", exc_info=e)
            return []
        return history if isinstance(history, list) else []

    @classmethod
    def _decode_context(cls, raw) -> Dict:
        """
        Parse a stored context payload (str or bytes).

        Returns the default context when the payload is missing, is not
        valid JSON, or is not a JSON object.
        """
        if raw is None:
            return cls._default_context()
        try:
            context = json.loads(raw)
        except (TypeError, ValueError) as e:
            logger.warning("⚠️ Discarding unreadable context payload", exc_info=e)
            return cls._default_context()
        return context if isinstance(context, dict) else cls._default_context()

    # ========== ADD MESSAGE ==========
    async def add_message(
        self,
        role: str,
        content: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """
        Append a message to the session history in Redis and refresh its TTL.

        Args:
            role: Speaker role. get_formatted_history renders "user" as
                "User" and every other role as "Aida".
            content: Message text.
            metadata: Optional extra data stored with the message ({} if None).
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                logger.warning("⚠️ Redis not available")
                return
            # Read directly, not via _get_history_list(). That helper returns
            # [] on a Redis error, which would make us overwrite the whole
            # stored history with just this one message.
            history = self._decode_history(await rc.get(self.history_key))
            history.append({
                "role": role,
                "content": content,
                "timestamp": datetime.utcnow().isoformat(),
                "metadata": metadata or {},
            })
            await rc.setex(
                self.history_key,
                TTL,
                json.dumps(history, ensure_ascii=False),
            )
            logger.info(
                f"πŸ“ Added {role} message",
                user_id=self.user_id,
                session_id=self.session_id,
                total_messages=len(history)
            )
        except Exception as e:
            logger.error("❌ Failed to add message", exc_info=e)

    # ========== GET MESSAGE HISTORY (Internal) ==========
    async def _get_history_list(self) -> List[Dict]:
        """Return the raw history list from Redis; [] if unavailable, missing, or on error."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return []
            return self._decode_history(await rc.get(self.history_key))
        except Exception as e:
            logger.error("❌ Failed to get history", exc_info=e)
            return []

    # ========== GET MESSAGE HISTORY ==========
    async def get_messages(
        self,
        limit: Optional[int] = None,
    ) -> List[Dict]:
        """
        Return the stored message history, oldest first.

        Args:
            limit: If a positive int, return only the last ``limit`` messages
                (e.g. 10). None, 0 or a negative value returns the full history.
        """
        history = await self._get_history_list()
        if limit is not None and limit > 0:
            return history[-limit:]
        return history

    # ========== GET FORMATTED HISTORY ==========
    async def get_formatted_history(self) -> str:
        """
        Return the history as newline-separated "Speaker: content" lines for LLM context.

        Messages with role "user" are labelled "User"; all others "Aida".
        Entries that are not dicts are skipped.
        """
        lines = []
        for msg in await self.get_messages():
            if not isinstance(msg, dict):
                continue
            speaker = "User" if msg.get("role") == "user" else "Aida"
            lines.append(f"{speaker}: {msg.get('content', '')}")
        return "\n".join(lines)

    # ========== GET CONTEXT ==========
    async def get_context(self) -> Dict:
        """
        Return the current conversation context from Redis.

        Falls back to the default context (status "idle", language "en", no
        role/draft, empty state) when Redis is unavailable, the key is
        missing or unreadable, or an error occurs.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return self._default_context()
            return self._decode_context(await rc.get(self.context_key))
        except Exception as e:
            logger.error("❌ Failed to get context", exc_info=e)
            return self._default_context()

    # ========== UPDATE CONTEXT ==========
    async def update_context(self, updates: Dict) -> None:
        """
        Merge ``updates`` into the stored context (shallow dict.update) and refresh its TTL.

        Values that are not JSON-serializable are stored via ``str()``.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                logger.warning("⚠️ Redis not available")
                return
            # Read directly, not via get_context(). That method swallows Redis
            # errors, and writing its fallback back would wipe the stored
            # context down to just ``updates``.
            context = self._decode_context(await rc.get(self.context_key))
            context.update(updates)
            await rc.setex(
                self.context_key,
                TTL,
                json.dumps(context, ensure_ascii=False, default=str),
            )
            logger.info(
                "πŸ“‹ Updated context",
                user_id=self.user_id,
                session_id=self.session_id,
                keys=list(updates.keys())
            )
        except Exception as e:
            logger.error("❌ Failed to update context", exc_info=e)

    # ========== GET SUMMARY ==========
    async def get_summary(self) -> Dict:
        """Return message count, status, language and draft presence for this session ({} on error)."""
        try:
            messages = await self.get_messages()
            context = await self.get_context()
            return {
                "user_id": self.user_id,
                "session_id": self.session_id,
                "total_messages": len(messages),
                "status": context.get("status", "idle"),
                "language": context.get("language", "en"),
                "has_draft": context.get("draft") is not None,
            }
        except Exception as e:
            logger.error("❌ Failed to get summary", exc_info=e)
            return {}

    # ========== CLEAR MEMORY ==========
    async def clear(self) -> None:
        """Delete both the history and context keys of this session (start a new chat)."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return
            # A single DEL for both keys: one round trip instead of two.
            await rc.delete(self.history_key, self.context_key)
            logger.info("πŸ—‘οΈ Conversation memory cleared", user_id=self.user_id, session_id=self.session_id)
        except Exception as e:
            logger.error("❌ Failed to clear memory", exc_info=e)
# ========== REDIS MEMORY MANAGER ==========
class RedisMemoryManager:
    """
    Global manager for Redis-based conversations.

    Keeps no in-process state: Redis is the single source of truth, and each
    call talks to it directly. Methods are best-effort and log instead of
    raising on Redis errors.

    Key discovery uses SCAN (``scan_iter``) rather than KEYS, because KEYS
    walks the whole keyspace in one blocking server command. User ids are
    glob-escaped before being placed in a MATCH pattern. This prevents an id
    such as ``*`` from matching, and in clear_user_sessions deleting, other
    users' sessions.
    """

    # ========== KEY HELPERS ==========
    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches only itself."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    @staticmethod
    def _user_id_from_key(key, prefix: str) -> Optional[str]:
        """
        Extract user_id from a key shaped "{prefix}:{user_id}:{session_id}".

        Accepts str or bytes keys, since the client may return either.
        Returns None when the key does not have that shape. The prefixes
        themselves contain ":", so the user id is found by stripping the known
        prefix rather than by a fixed split index.
        NOTE(review): a user_id that itself contains ":" is truncated at its
        first ":" here.
        """
        if isinstance(key, bytes):
            key = key.decode("utf-8", errors="replace")
        head = f"{prefix}:"
        if not key.startswith(head):
            return None
        user_id, sep, _ = key[len(head):].partition(":")
        return user_id if sep else None

    @staticmethod
    async def _scan_keys(rc, pattern: str) -> List:
        """Collect every key matching *pattern* via incremental SCAN, without duplicates."""
        keys = [key async for key in rc.scan_iter(match=pattern)]
        # SCAN may return the same key more than once; keep first occurrences.
        return list(dict.fromkeys(keys))

    # ========== CREATE/GET SESSION ==========
    async def get_or_create_session(
        self,
        user_id: str,
        session_id: str,
    ) -> "RedisConversationMemory":
        """
        Return a memory object for the session.

        Nothing is written here. The object always addresses the same Redis
        keys, so it "creates" a new session or "retrieves" an existing one
        simply by being constructed.
        """
        memory = RedisConversationMemory(user_id, session_id)
        logger.info("βœ… Session memory ready", user_id=user_id, session_id=session_id)
        return memory

    # ========== CLOSE SESSION ==========
    async def close_session(self, user_id: str, session_id: str) -> None:
        """
        Explicitly clear a session's history and context.

        Optional: if this is never called, the keys expire on their own via
        the TTL set on each write.
        """
        try:
            memory = RedisConversationMemory(user_id, session_id)
            await memory.clear()
            logger.info("❌ Session closed", user_id=user_id, session_id=session_id)
        except Exception as e:
            logger.error("❌ Failed to close session", exc_info=e)

    # ========== GET USER SESSION HISTORY ==========
    async def get_user_history(self, user_id: str) -> List[Dict]:
        """
        Return the messages of every stored session of *user_id*, concatenated.

        Messages keep their order within a session; the order of sessions
        follows SCAN and is unspecified. A session whose payload is not a
        readable JSON list is skipped (unreadable ones are logged) instead of
        aborting the whole call. Returns [] if Redis is unavailable or errors.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return []
            pattern = f"{HISTORY_PREFIX}:{self._escape_glob(user_id)}:*"
            all_messages: List[Dict] = []
            for key in await self._scan_keys(rc, pattern):
                raw = await rc.get(key)
                if not raw:
                    continue
                try:
                    messages = json.loads(raw)
                except (TypeError, ValueError) as e:
                    logger.warning("⚠️ Skipping unreadable history payload", user_id=user_id, exc_info=e)
                    continue
                if isinstance(messages, list):
                    all_messages.extend(messages)
            logger.info("πŸ“Š Retrieved user history", user_id=user_id, total_messages=len(all_messages))
            return all_messages
        except Exception as e:
            logger.error("❌ Failed to get user history", exc_info=e)
            return []

    # ========== CLEAR ALL USER SESSIONS ==========
    async def clear_user_sessions(self, user_id: str) -> None:
        """Delete the history and context keys of every session of *user_id* (e.g. on logout)."""
        try:
            rc = _get_redis_client()
            if rc is None:
                return
            safe_id = self._escape_glob(user_id)
            history_keys = await self._scan_keys(rc, f"{HISTORY_PREFIX}:{safe_id}:*")
            context_keys = await self._scan_keys(rc, f"{CONTEXT_PREFIX}:{safe_id}:*")
            all_keys = history_keys + context_keys
            if all_keys:
                await rc.delete(*all_keys)
            logger.info("πŸ—‘οΈ All user sessions cleared", user_id=user_id, sessions=len(history_keys))
        except Exception as e:
            logger.error("❌ Failed to clear user sessions", exc_info=e)

    # ========== GET REDIS STATS ==========
    async def get_stats(self) -> Dict:
        """
        Return key counts and the number of distinct users with stored data.

        unique_users counts distinct user ids across both history and context keys.
        """
        try:
            rc = _get_redis_client()
            if rc is None:
                return {"redis_status": "not connected"}
            history_keys = await self._scan_keys(rc, f"{HISTORY_PREFIX}:*")
            context_keys = await self._scan_keys(rc, f"{CONTEXT_PREFIX}:*")
            unique_users = set()
            for prefix, keys in ((HISTORY_PREFIX, history_keys), (CONTEXT_PREFIX, context_keys)):
                for key in keys:
                    user_id = self._user_id_from_key(key, prefix)
                    if user_id is not None:
                        unique_users.add(user_id)
            return {
                "total_history_keys": len(history_keys),
                "total_context_keys": len(context_keys),
                "unique_users": len(unique_users),
                "redis_status": "connected",
            }
        except Exception as e:
            logger.error("❌ Failed to get stats", exc_info=e)
            return {"redis_status": "error"}
# ========== SINGLETON INSTANCE ==========
_memory_manager = None


def get_memory_manager() -> RedisMemoryManager:
    """
    Return the module-wide RedisMemoryManager, creating it on the first call.

    Every later call returns that same instance.
    """
    global _memory_manager
    if _memory_manager is not None:
        return _memory_manager
    _memory_manager = RedisMemoryManager()
    return _memory_manager
# ========== HELPER: Get Current Memory ==========
async def get_current_memory(
    user_id: str,
    session_id: str,
) -> RedisConversationMemory:
    """
    Return the conversation memory for (user_id, session_id).

    Convenience entry point for routes/services. It delegates to the shared
    manager's get_or_create_session().
    """
    return await get_memory_manager().get_or_create_session(user_id, session_id)