# MedSearchPro / lib / memory_manager.py
# paulhemb's picture
# Initial Backend Deployment
# 1367957
"""
LangChain-based conversation memory management (v0.2+ compatible)
"""
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_classic.memory import ConversationBufferWindowMemory # Keep classic for now
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from typing import List, Dict, Any, Optional
import json
import os
import pickle
from datetime import datetime
class ConversationMemory:
    """
    Manages conversation memory using LangChain with persistent storage.

    Chat turns live in a ``ConversationBufferWindowMemory``; per-session
    bookkeeping lives in ``conversation_metadata``. Both are pickled to
    ``./memory_data/memory_<session_id>.pkl`` after every change and reloaded
    when a new instance is created for the same session.

    ``conversation_metadata`` always holds ``session_id``,
    ``domains_discussed`` (set), ``query_types_used`` (set),
    ``previously_used_papers`` (set) and ``interaction_count``. Because sets
    carry no order, three optional keys are added the first time they are
    needed: ``last_domain``, ``last_query_type`` and ``recent_paper_order``
    (paper IDs, oldest first).
    """

    # Upper bound on remembered paper IDs; the oldest are evicted first.
    MAX_TRACKED_PAPERS = 20

    # Words ignored by _extract_topic.
    _STOP_WORDS = frozenset({
        'what', 'how', 'why', 'when', 'where', 'which', 'can', 'you', 'me',
        'the', 'a', 'an', 'and', 'or', 'but',
    })

    def __init__(self, session_id: str = "default", memory_window: int = 10):
        """
        Create the memory for ``session_id`` and load any saved state.

        Args:
            session_id: Identifier used in the persistence file name.
            memory_window: Passed to ConversationBufferWindowMemory as ``k``.
        """
        self.session_id = session_id
        self.memory_window = memory_window

        # Pydantic v2 validation requires a ChatMessageHistory INSTANCE here,
        # not a dict.
        chat_history: BaseChatMessageHistory = ChatMessageHistory()
        self.memory = ConversationBufferWindowMemory(
            chat_memory=chat_history,
            k=memory_window,
            return_messages=True,
            memory_key="chat_history",
            output_key="output",
        )

        # Additional metadata storage
        self.conversation_metadata = self._new_metadata()

        # Load existing memory if available
        self._load_memory()

    def _new_metadata(self) -> Dict[str, Any]:
        """Return an empty metadata record for this session."""
        return {
            'session_id': self.session_id,
            'domains_discussed': set(),
            'query_types_used': set(),
            'previously_used_papers': set(),
            'interaction_count': 0,
        }

    def add_interaction(self, user_message: str, ai_response: str,
                        metadata: Optional[Dict[str, Any]] = None):
        """
        Add a new interaction to memory and persist the session.

        Args:
            user_message: Text of the user's turn.
            ai_response: Text of the assistant's reply.
            metadata: Optional details about the turn. The keys ``domain``,
                ``query_type``, ``papers_used`` and ``paper_ids`` are read.
        """
        self.memory.save_context(
            {"input": user_message},
            {"output": ai_response},
        )
        self.conversation_metadata['interaction_count'] += 1
        if metadata:
            self._record_metadata(metadata)
        self._save_memory()

    def _record_metadata(self, metadata: Dict[str, Any]) -> None:
        """Fold one interaction's metadata into the session bookkeeping."""
        meta = self.conversation_metadata
        if 'domain' in metadata:
            meta['domains_discussed'].add(metadata['domain'])
            # The set has no order, so the latest value is tracked separately.
            meta['last_domain'] = metadata['domain']
        if 'query_type' in metadata:
            meta['query_types_used'].add(metadata['query_type'])
            meta['last_query_type'] = metadata['query_type']
        # NOTE(review): tracking is gated on 'papers_used' but the IDs are
        # read from 'paper_ids' - confirm callers always send both keys.
        if 'papers_used' in metadata:
            self._remember_papers(metadata.get('paper_ids') or [])

    def _remember_papers(self, paper_ids) -> None:
        """Mark ``paper_ids`` as most recently used and evict the oldest."""
        meta = self.conversation_metadata
        order = meta.get('recent_paper_order')
        if order is None:
            # Metadata without usage order (fresh, or saved by an older
            # version): seed from the set, whose order is unknown.
            order = list(meta['previously_used_papers'])
        for paper_id in paper_ids:
            if paper_id in order:
                order.remove(paper_id)  # re-use moves the paper to the end
            order.append(paper_id)
        overflow = len(order) - self.MAX_TRACKED_PAPERS
        if overflow > 0:
            del order[:overflow]
        meta['recent_paper_order'] = order
        meta['previously_used_papers'] = set(order)

    def get_conversation_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get conversation history as user/assistant pairs.

        Messages are paired positionally (even index = user, odd index =
        assistant); a trailing unpaired message is ignored.

        Args:
            limit: If truthy, return only the last ``limit`` pairs.

        Returns:
            A list of dicts with keys ``user``, ``assistant`` and ``turn``
            (1-based).
        """
        messages = self.memory.chat_memory.messages
        history = [
            {
                'user': messages[i].content,
                'assistant': messages[i + 1].content,
                'turn': i // 2 + 1,
            }
            for i in range(0, len(messages) - 1, 2)
        ]
        if limit:
            history = history[-limit:]
        return history

    def get_conversation_context(self) -> Dict[str, Any]:
        """
        Get current conversation context for query enhancement.

        Returns:
            A dict with the session id, interaction count, discussed domains,
            used query types, remembered paper IDs (oldest first when usage
            order is known) and the last three exchanges. When available it
            also carries ``last_user_message``, ``last_assistant_response``,
            ``last_topic``, ``last_query_type`` and ``last_domain``.
        """
        meta = self.conversation_metadata
        history = self.get_conversation_history(limit=3)  # Last 3 exchanges

        paper_order = meta.get('recent_paper_order')
        if paper_order is None:
            paper_order = meta['previously_used_papers']

        context = {
            'session_id': self.session_id,
            'interaction_count': meta['interaction_count'],
            'domains_discussed': list(meta['domains_discussed']),
            'query_types_used': list(meta['query_types_used']),
            'previously_used_papers': list(paper_order),
            'recent_history': history,
        }

        # Extract last topic for context
        if history:
            last_interaction = history[-1]
            context['last_user_message'] = last_interaction['user']
            context['last_assistant_response'] = last_interaction['assistant']
            context['last_topic'] = self._extract_topic(last_interaction['user'])

        for last_key, set_key in (('last_query_type', 'query_types_used'),
                                  ('last_domain', 'domains_discussed')):
            if last_key in meta:
                context[last_key] = meta[last_key]
            elif meta[set_key]:
                # Session saved before the latest value was tracked: the set
                # has no order, so this pick is arbitrary.
                context[last_key] = list(meta[set_key])[-1]
        return context

    def get_conversation_summary(self) -> Dict[str, Any]:
        """Get summary of the conversation."""
        history = self.get_conversation_history()
        meta = self.conversation_metadata
        return {
            'session_id': self.session_id,
            'total_interactions': len(history),
            'domains_covered': list(meta['domains_discussed']),
            'query_types_used': list(meta['query_types_used']),
            'papers_referenced': len(meta['previously_used_papers']),
            'recent_activity': [msg['user'][:50] + '...' for msg in history[-3:]],
        }

    def clear_memory(self):
        """Clear all conversation memory and persist the empty state."""
        self.memory.clear()
        self.conversation_metadata = self._new_metadata()
        self._save_memory()

    def _extract_topic(self, message: str) -> str:
        """
        Extract main topic from a message.

        Returns the first three lower-cased words that are longer than three
        characters and not stop words, or ``'general discussion'`` if none
        qualify.
        """
        meaningful_words = [
            word for word in message.lower().split()
            if word not in self._STOP_WORDS and len(word) > 3
        ]
        return ' '.join(meaningful_words[:3]) if meaningful_words else 'general discussion'

    def _get_memory_file_path(self) -> str:
        """Get file path for persistent memory storage (creates the directory)."""
        memory_dir = "./memory_data"
        os.makedirs(memory_dir, exist_ok=True)
        return f"{memory_dir}/memory_{self.session_id}.pkl"

    def _save_memory(self):
        """Save memory to persistent storage; failures are reported, not raised."""
        try:
            memory_data = {
                'langchain_memory': self.memory.dict(),
                # Stored explicitly so _load_memory can rebuild the chat
                # history instead of starting from an empty one.
                'messages': [
                    {'type': message.type, 'content': message.content}
                    for message in self.memory.chat_memory.messages
                ],
                'conversation_metadata': self.conversation_metadata,
            }
            # Pickle before opening the file so that a serialization failure
            # cannot truncate the previously saved snapshot.
            payload = pickle.dumps(memory_data)
            with open(self._get_memory_file_path(), 'wb') as f:
                f.write(payload)
            print(f"💾 Memory saved for session: {self.session_id}")
        except Exception as e:
            print(f"❌ Error saving memory: {e}")

    def _load_memory(self):
        """
        Load memory from persistent storage.

        On any failure the error is printed and the instance keeps the fresh
        state created in ``__init__``; nothing is partially applied.
        """
        try:
            memory_file = self._get_memory_file_path()
            if not os.path.exists(memory_file):
                return

            # SECURITY: pickle.load can execute arbitrary code. Only safe
            # while ./memory_data is writable by this application alone.
            with open(memory_file, 'rb') as f:
                memory_data = pickle.load(f)

            memory_config = dict(memory_data['langchain_memory'])
            stored_messages = memory_data.get('messages')
            if stored_messages is None:
                # Older snapshots only carry the messages inside the dumped
                # memory configuration.
                legacy = memory_config.get('chat_memory')
                stored_messages = legacy.get('messages', []) if isinstance(legacy, dict) else []

            # The constructor needs a ChatMessageHistory instance, so rebuild
            # one and refill it with the saved turns.
            chat_history = ChatMessageHistory()
            for stored in stored_messages:
                chat_history.add_message(self._to_message(stored))
            memory_config['chat_memory'] = chat_history
            restored_memory = ConversationBufferWindowMemory(**memory_config)

            restored_metadata = self._new_metadata()
            restored_metadata.update(memory_data['conversation_metadata'])

            # Assign only after everything succeeded.
            self.memory = restored_memory
            self.conversation_metadata = restored_metadata
            print(f"📂 Memory loaded for session: {self.session_id}")
        except Exception as e:
            print(f"❌ Error loading memory: {e}")
            # Continue with fresh memory

    @staticmethod
    def _to_message(stored):
        """Turn a saved message record back into a LangChain message."""
        if isinstance(stored, BaseMessage):
            return stored
        content = stored.get('content', '')
        if stored.get('type') == 'human':
            return HumanMessage(content=content)
        # save_context only writes human/ai pairs, so anything else is
        # treated as an assistant message.
        return AIMessage(content=content)
# For Vercel serverless compatibility
class VercelMemoryManager:
    """
    Memory manager optimized for Vercel serverless environment.

    Uses a JSON file (``/tmp/memory_<session_id>.json``) instead of pickle
    for compatibility, and keeps only the most recent ``MAX_INTERACTIONS``
    exchanges.
    """

    # Keep only this many interactions in the serverless environment.
    MAX_INTERACTIONS = 20

    def __init__(self, session_id: str = "default"):
        """Create the manager for ``session_id`` and load any saved history."""
        self.session_id = session_id
        self.memory_file = f"/tmp/memory_{session_id}.json"
        self.conversation_history = []
        self.load_memory()

    def add_interaction(self, user_message: str, ai_response: str,
                        metadata: Optional[Dict[str, Any]] = None):
        """
        Add interaction to memory and persist the history.

        Args:
            user_message: Text of the user's turn.
            ai_response: Text of the assistant's reply.
            metadata: Optional details stored with the interaction; must be
                JSON-serializable for the interaction to be persisted.
        """
        self.conversation_history.append({
            'user': user_message,
            'assistant': ai_response,
            'metadata': metadata or {},
            'timestamp': self._get_timestamp(),
        })
        if len(self.conversation_history) > self.MAX_INTERACTIONS:
            self.conversation_history = self.conversation_history[-self.MAX_INTERACTIONS:]
        self.save_memory()

    def get_conversation_context(self) -> Dict[str, Any]:
        """
        Get conversation context.

        Returns:
            A dict with the session id, the number of stored interactions,
            the distinct ``domain`` and ``query_type`` values found in their
            metadata, and the last three interactions.
        """
        recent_history = self.conversation_history[-3:]
        domains = set()
        query_types = set()
        for interaction in self.conversation_history:
            # Entries come back from a JSON file, so tolerate a missing or
            # null metadata field instead of crashing on it.
            meta = interaction.get('metadata') if isinstance(interaction, dict) else None
            if not isinstance(meta, dict):
                continue
            if 'domain' in meta:
                domains.add(meta['domain'])
            if 'query_type' in meta:
                query_types.add(meta['query_type'])
        return {
            'session_id': self.session_id,
            'interaction_count': len(self.conversation_history),
            'domains_discussed': list(domains),
            'query_types_used': list(query_types),
            'recent_history': recent_history,
        }

    def save_memory(self):
        """Save memory to JSON file; failures are reported, not raised."""
        try:
            # Serialize first: if a value is not JSON-serializable the file
            # is never opened, so the previous snapshot stays intact.
            payload = json.dumps(self.conversation_history)
            with open(self.memory_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except Exception as e:
            print(f"❌ Error saving memory: {e}")

    def load_memory(self):
        """
        Load memory from JSON file.

        A missing file leaves the current history untouched; an unreadable
        file, or one that does not hold a list, resets it to empty.
        """
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if not isinstance(loaded, list):
                    raise ValueError("memory file does not contain a list of interactions")
                self.conversation_history = loaded
        except Exception as e:
            print(f"❌ Error loading memory: {e}")
            self.conversation_history = []

    def _get_timestamp(self) -> str:
        """Get current local timestamp in ISO 8601 format."""
        return datetime.now().isoformat()