""" LangChain-based conversation memory management (v0.2+ compatible) """ from langchain_core.chat_history import BaseChatMessageHistory from langchain_community.chat_message_histories import ChatMessageHistory from langchain_classic.memory import ConversationBufferWindowMemory # Keep classic for now from langchain_core.messages import HumanMessage, AIMessage, BaseMessage from typing import List, Dict, Any, Optional import json import os import pickle from datetime import datetime class ConversationMemory: """ Manages conversation memory using LangChain with persistent storage Fixed for LangChain v0.2+ Pydantic v2 validation """ def __init__(self, session_id: str = "default", memory_window: int = 10): self.session_id = session_id self.memory_window = memory_window # ✅ FIX: Create ChatMessageHistory INSTANCE (required by Pydantic v2) chat_history: BaseChatMessageHistory = ChatMessageHistory() # Initialize LangChain memory with proper chat_history self.memory = ConversationBufferWindowMemory( chat_memory=chat_history, # Pass INSTANCE, not dict k=memory_window, return_messages=True, memory_key="chat_history", output_key="output" ) # Additional metadata storage self.conversation_metadata = { 'session_id': session_id, 'domains_discussed': set(), 'query_types_used': set(), 'previously_used_papers': set(), 'interaction_count': 0 } # Load existing memory if available self._load_memory() def add_interaction(self, user_message: str, ai_response: str, metadata: Dict[str, Any] = None): """Add a new interaction to memory""" # Add to LangChain memory self.memory.save_context( {"input": user_message}, {"output": ai_response} ) # Update metadata self.conversation_metadata['interaction_count'] += 1 if metadata: if 'domain' in metadata: self.conversation_metadata['domains_discussed'].add(metadata['domain']) if 'query_type' in metadata: self.conversation_metadata['query_types_used'].add(metadata['query_type']) if 'papers_used' in metadata: # Track recently used papers to avoid repetition paper_ids = metadata.get('paper_ids', []) self.conversation_metadata['previously_used_papers'].update(paper_ids) # Keep only recent papers (last 20) recent_papers = list(self.conversation_metadata['previously_used_papers'])[-20:] self.conversation_metadata['previously_used_papers'] = set(recent_papers) # Save memory to persistent storage self._save_memory() def get_conversation_history(self, limit: Optional[int] = None) -> List[Dict[str, str]]: """Get conversation history""" chat_history = self.memory.chat_memory.messages history = [] for i in range(0, len(chat_history), 2): if i + 1 < len(chat_history): history.append({ 'user': chat_history[i].content, 'assistant': chat_history[i + 1].content, 'turn': i // 2 + 1 }) if limit: history = history[-limit:] return history def get_conversation_context(self) -> Dict[str, Any]: """Get current conversation context for query enhancement""" history = self.get_conversation_history(limit=3) # Last 3 exchanges context = { 'session_id': self.session_id, 'interaction_count': self.conversation_metadata['interaction_count'], 'domains_discussed': list(self.conversation_metadata['domains_discussed']), 'query_types_used': list(self.conversation_metadata['query_types_used']), 'previously_used_papers': list(self.conversation_metadata['previously_used_papers']), 'recent_history': history } # Extract last topic for context if history: last_interaction = history[-1] context['last_user_message'] = last_interaction['user'] context['last_assistant_response'] = last_interaction['assistant'] context['last_topic'] = self._extract_topic(last_interaction['user']) # Get last query type from metadata if self.conversation_metadata['query_types_used']: context['last_query_type'] = list(self.conversation_metadata['query_types_used'])[-1] # Add last_domain from domains_discussed if self.conversation_metadata['domains_discussed']: context['last_domain'] = list(self.conversation_metadata['domains_discussed'])[-1] return context def get_conversation_summary(self) -> Dict[str, Any]: """Get summary of the conversation""" history = self.get_conversation_history() return { 'session_id': self.session_id, 'total_interactions': len(history), 'domains_covered': list(self.conversation_metadata['domains_discussed']), 'query_types_used': list(self.conversation_metadata['query_types_used']), 'papers_referenced': len(self.conversation_metadata['previously_used_papers']), 'recent_activity': [msg['user'][:50] + '...' for msg in history[-3:]] } def clear_memory(self): """Clear all conversation memory""" self.memory.clear() self.conversation_metadata = { 'session_id': self.session_id, 'domains_discussed': set(), 'query_types_used': set(), 'previously_used_papers': set(), 'interaction_count': 0 } self._save_memory() def _extract_topic(self, message: str) -> str: """Extract main topic from a message""" # Simple topic extraction - can be enhanced words = message.lower().split() # Filter out common words and keep meaningful ones stop_words = {'what', 'how', 'why', 'when', 'where', 'which', 'can', 'you', 'me', 'the', 'a', 'an', 'and', 'or', 'but'} meaningful_words = [word for word in words if word not in stop_words and len(word) > 3] return ' '.join(meaningful_words[:3]) if meaningful_words else 'general discussion' def _get_memory_file_path(self) -> str: """Get file path for persistent memory storage""" memory_dir = "./memory_data" os.makedirs(memory_dir, exist_ok=True) return f"{memory_dir}/memory_{self.session_id}.pkl" def _save_memory(self): """Save memory to persistent storage""" try: # ✅ FIX: Use .dict() for serialization compatibility memory_data = { 'langchain_memory': self.memory.dict(), # Fixed serialization 'conversation_metadata': self.conversation_metadata } with open(self._get_memory_file_path(), 'wb') as f: pickle.dump(memory_data, f) print(f"💾 Memory saved for session: {self.session_id}") except Exception as e: print(f"❌ Error saving memory: {e}") def _load_memory(self): """Load memory from persistent storage""" try: memory_file = self._get_memory_file_path() if os.path.exists(memory_file): with open(memory_file, 'rb') as f: memory_data = pickle.load(f) # ✅ FIX: Recreate chat_history before initializing memory chat_history = ChatMessageHistory() memory_config = memory_data['langchain_memory'] memory_config['chat_memory'] = chat_history # Ensure proper instance self.memory = ConversationBufferWindowMemory(**memory_config) self.conversation_metadata = memory_data['conversation_metadata'] print(f"📂 Memory loaded for session: {self.session_id}") except Exception as e: print(f"❌ Error loading memory: {e}") # Continue with fresh memory # For Vercel serverless compatibility class VercelMemoryManager: """ Memory manager optimized for Vercel serverless environment Uses JSON files instead of pickle for compatibility """ def __init__(self, session_id: str = "default"): self.session_id = session_id self.memory_file = f"/tmp/memory_{session_id}.json" self.conversation_history = [] self.load_memory() def add_interaction(self, user_message: str, ai_response: str, metadata: Dict[str, Any] = None): """Add interaction to memory""" interaction = { 'user': user_message, 'assistant': ai_response, 'metadata': metadata or {}, 'timestamp': self._get_timestamp() } self.conversation_history.append(interaction) # Keep only last 20 interactions in serverless environment if len(self.conversation_history) > 20: self.conversation_history = self.conversation_history[-20:] self.save_memory() def get_conversation_context(self) -> Dict[str, Any]: """Get conversation context""" recent_history = self.conversation_history[-3:] if self.conversation_history else [] domains = set() query_types = set() for interaction in self.conversation_history: if 'metadata' in interaction: meta = interaction['metadata'] if 'domain' in meta: domains.add(meta['domain']) if 'query_type' in meta: query_types.add(meta['query_type']) return { 'session_id': self.session_id, 'interaction_count': len(self.conversation_history), 'domains_discussed': list(domains), 'query_types_used': list(query_types), 'recent_history': recent_history } def save_memory(self): """Save memory to JSON file""" try: with open(self.memory_file, 'w') as f: json.dump(self.conversation_history, f) except Exception as e: print(f"❌ Error saving memory: {e}") def load_memory(self): """Load memory from JSON file""" try: if os.path.exists(self.memory_file): with open(self.memory_file, 'r') as f: self.conversation_history = json.load(f) except Exception as e: print(f"❌ Error loading memory: {e}") self.conversation_history = [] def _get_timestamp(self) -> str: """Get current timestamp""" from datetime import datetime return datetime.now().isoformat()