"""
LangChain-based conversation memory management (v0.2+ compatible)
"""
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_classic.memory import ConversationBufferWindowMemory  # Keep classic for now
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from typing import List, Dict, Any, Optional
import json
import os
import pickle
from datetime import datetime
class ConversationMemory:
    """
    Manages conversation memory using LangChain with persistent storage.

    Compatible with LangChain v0.2+ / Pydantic v2: the underlying
    ConversationBufferWindowMemory is constructed with a real
    ChatMessageHistory instance (passing a plain dict fails Pydantic v2
    validation).
    """

    def __init__(self, session_id: str = "default", memory_window: int = 10):
        """
        Args:
            session_id: Identifier used to name the on-disk memory file.
            memory_window: Window size ``k`` — number of exchanges the
                buffer memory keeps when queried.
        """
        self.session_id = session_id
        self.memory_window = memory_window

        # Pydantic v2 requires an actual BaseChatMessageHistory INSTANCE here.
        chat_history: BaseChatMessageHistory = ChatMessageHistory()

        # Initialize LangChain memory with a proper chat_history instance.
        self.memory = ConversationBufferWindowMemory(
            chat_memory=chat_history,
            k=memory_window,
            return_messages=True,
            memory_key="chat_history",
            output_key="output",
        )

        # Session-level bookkeeping kept alongside the LangChain buffer.
        # NOTE(review): Python sets do not preserve insertion order, so any
        # "most recent" value derived from these sets is best-effort only.
        self.conversation_metadata: Dict[str, Any] = {
            'session_id': session_id,
            'domains_discussed': set(),
            'query_types_used': set(),
            'previously_used_papers': set(),
            'interaction_count': 0
        }

        # Restore any previously persisted state for this session.
        self._load_memory()

    def add_interaction(self, user_message: str, ai_response: str,
                        metadata: Optional[Dict[str, Any]] = None):
        """
        Add a new user/assistant exchange to memory and persist it.

        Args:
            user_message: The user's message text.
            ai_response: The assistant's reply text.
            metadata: Optional extras; recognized keys are 'domain',
                'query_type', 'papers_used' and 'paper_ids'.
        """
        # Add to LangChain memory.
        self.memory.save_context(
            {"input": user_message},
            {"output": ai_response}
        )

        # Update session metadata.
        self.conversation_metadata['interaction_count'] += 1
        if metadata:
            if 'domain' in metadata:
                self.conversation_metadata['domains_discussed'].add(metadata['domain'])
            if 'query_type' in metadata:
                self.conversation_metadata['query_types_used'].add(metadata['query_type'])
            if 'papers_used' in metadata:
                # Track recently used papers to avoid repetition.
                paper_ids = metadata.get('paper_ids', [])
                self.conversation_metadata['previously_used_papers'].update(paper_ids)
                # Cap the tracked papers at 20. NOTE(review): since the
                # container is an unordered set, which 20 survive the cut
                # is arbitrary, not "the most recent".
                recent_papers = list(self.conversation_metadata['previously_used_papers'])[-20:]
                self.conversation_metadata['previously_used_papers'] = set(recent_papers)

        # Persist after every interaction.
        self._save_memory()

    def get_conversation_history(self, limit: Optional[int] = None) -> List[Dict[str, str]]:
        """
        Return the conversation as a list of {'user', 'assistant', 'turn'} dicts.

        Messages are paired positionally (even index = user, odd = assistant);
        a trailing unpaired message is ignored.

        Args:
            limit: If given, return only the last ``limit`` exchanges.
        """
        chat_history = self.memory.chat_memory.messages
        history = []
        for i in range(0, len(chat_history), 2):
            if i + 1 < len(chat_history):
                history.append({
                    'user': chat_history[i].content,
                    'assistant': chat_history[i + 1].content,
                    'turn': i // 2 + 1
                })
        if limit:
            history = history[-limit:]
        return history

    def get_conversation_context(self) -> Dict[str, Any]:
        """
        Return the current conversation context for query enhancement.

        Includes the last 3 exchanges plus session metadata; when history
        exists, also 'last_user_message', 'last_assistant_response' and
        'last_topic'.
        """
        history = self.get_conversation_history(limit=3)  # Last 3 exchanges

        context = {
            'session_id': self.session_id,
            'interaction_count': self.conversation_metadata['interaction_count'],
            'domains_discussed': list(self.conversation_metadata['domains_discussed']),
            'query_types_used': list(self.conversation_metadata['query_types_used']),
            'previously_used_papers': list(self.conversation_metadata['previously_used_papers']),
            'recent_history': history
        }

        # Extract last topic for context.
        if history:
            last_interaction = history[-1]
            context['last_user_message'] = last_interaction['user']
            context['last_assistant_response'] = last_interaction['assistant']
            context['last_topic'] = self._extract_topic(last_interaction['user'])

        # NOTE(review): these come from unordered sets, so "[-1]" yields an
        # arbitrary element rather than the chronologically last one.
        if self.conversation_metadata['query_types_used']:
            context['last_query_type'] = list(self.conversation_metadata['query_types_used'])[-1]
        if self.conversation_metadata['domains_discussed']:
            context['last_domain'] = list(self.conversation_metadata['domains_discussed'])[-1]

        return context

    def get_conversation_summary(self) -> Dict[str, Any]:
        """Return a summary of the conversation so far."""
        history = self.get_conversation_history()
        return {
            'session_id': self.session_id,
            'total_interactions': len(history),
            'domains_covered': list(self.conversation_metadata['domains_discussed']),
            'query_types_used': list(self.conversation_metadata['query_types_used']),
            'papers_referenced': len(self.conversation_metadata['previously_used_papers']),
            # Only ellipsize messages that were actually truncated.
            'recent_activity': [
                msg['user'][:50] + '...' if len(msg['user']) > 50 else msg['user']
                for msg in history[-3:]
            ]
        }

    def clear_memory(self):
        """Clear all conversation memory and reset session metadata."""
        self.memory.clear()
        self.conversation_metadata = {
            'session_id': self.session_id,
            'domains_discussed': set(),
            'query_types_used': set(),
            'previously_used_papers': set(),
            'interaction_count': 0
        }
        self._save_memory()

    def _extract_topic(self, message: str) -> str:
        """
        Extract a rough topic phrase (up to 3 meaningful words) from a message.

        Simple heuristic: drop stop words and words of 3 characters or fewer.
        """
        words = message.lower().split()
        stop_words = {'what', 'how', 'why', 'when', 'where', 'which', 'can', 'you',
                      'me', 'the', 'a', 'an', 'and', 'or', 'but'}
        meaningful_words = [word for word in words if word not in stop_words and len(word) > 3]
        return ' '.join(meaningful_words[:3]) if meaningful_words else 'general discussion'

    def _get_memory_file_path(self) -> str:
        """Return (and ensure the directory for) the session's pickle path."""
        memory_dir = "./memory_data"
        os.makedirs(memory_dir, exist_ok=True)
        return os.path.join(memory_dir, f"memory_{self.session_id}.pkl")

    def _save_memory(self):
        """
        Persist chat messages and metadata to disk.

        Messages are converted to plain dicts via ``messages_to_dict`` before
        pickling, so the file does not depend on LangChain-internal class
        layout. (The previous ``self.memory.dict()`` payload was not
        round-trippable back into ConversationBufferWindowMemory.)
        """
        from langchain_core.messages import messages_to_dict

        try:
            memory_data = {
                'messages': messages_to_dict(self.memory.chat_memory.messages),
                'conversation_metadata': self.conversation_metadata
            }
            with open(self._get_memory_file_path(), 'wb') as f:
                pickle.dump(memory_data, f)
            print(f"πΎ Memory saved for session: {self.session_id}")
        except Exception as e:
            print(f"β Error saving memory: {e}")

    def _load_memory(self):
        """
        Load persisted state for this session, if any.

        BUG FIX: the previous implementation rebuilt the memory around a
        fresh (empty) ChatMessageHistory, silently dropping every persisted
        message. Messages are now restored explicitly into the existing
        chat history. Files written by older versions (keyed under
        'langchain_memory') cannot have their messages recovered; their
        metadata still loads.

        NOTE(review): pickle is only acceptable here because the file is
        produced locally by this class — never load untrusted pickles.
        """
        from langchain_core.messages import messages_from_dict

        try:
            memory_file = self._get_memory_file_path()
            if os.path.exists(memory_file):
                with open(memory_file, 'rb') as f:
                    memory_data = pickle.load(f)
                for message in messages_from_dict(memory_data.get('messages', [])):
                    self.memory.chat_memory.add_message(message)
                self.conversation_metadata = memory_data['conversation_metadata']
                print(f"π Memory loaded for session: {self.session_id}")
        except Exception as e:
            print(f"β Error loading memory: {e}")
            # Continue with fresh memory
# For Vercel serverless compatibility
class VercelMemoryManager:
    """
    Memory manager optimized for the Vercel serverless environment.

    Persists the conversation as JSON in /tmp (instead of pickle) for
    compatibility, and keeps at most the last 20 interactions.
    """

    def __init__(self, session_id: str = "default"):
        self.session_id = session_id
        self.memory_file = f"/tmp/memory_{session_id}.json"
        self.conversation_history = []
        self.load_memory()

    def add_interaction(self, user_message: str, ai_response: str, metadata: Dict[str, Any] = None):
        """Record one user/assistant exchange and persist the history."""
        self.conversation_history.append({
            'user': user_message,
            'assistant': ai_response,
            'metadata': metadata or {},
            'timestamp': self._get_timestamp(),
        })
        # Serverless instances are short-lived: cap the history at 20 entries
        # by dropping everything older (no-op while the list is still short).
        del self.conversation_history[:-20]
        self.save_memory()

    def get_conversation_context(self) -> Dict[str, Any]:
        """Summarize session state (domains, query types, recent history)."""
        domains = set()
        query_types = set()
        for entry in self.conversation_history:
            meta = entry.get('metadata', {})
            if 'domain' in meta:
                domains.add(meta['domain'])
            if 'query_type' in meta:
                query_types.add(meta['query_type'])
        return {
            'session_id': self.session_id,
            'interaction_count': len(self.conversation_history),
            'domains_discussed': list(domains),
            'query_types_used': list(query_types),
            'recent_history': self.conversation_history[-3:],
        }

    def save_memory(self):
        """Write the history list to the session's JSON file."""
        try:
            with open(self.memory_file, 'w') as fh:
                json.dump(self.conversation_history, fh)
        except Exception as e:
            print(f"β Error saving memory: {e}")

    def load_memory(self):
        """Load the history list from the JSON file, if one exists."""
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r') as fh:
                    self.conversation_history = json.load(fh)
        except Exception as e:
            print(f"β Error loading memory: {e}")
            self.conversation_history = []

    def _get_timestamp(self) -> str:
        """ISO-8601 timestamp for the current local time."""
        from datetime import datetime
        return datetime.now().isoformat()