File size: 11,254 Bytes
1367957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""

LangChain-based conversation memory management (v0.2+ compatible)

"""

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_classic.memory import ConversationBufferWindowMemory  # Keep classic for now
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from typing import List, Dict, Any, Optional
import json
import os
import pickle
from datetime import datetime


class ConversationMemory:
    """
    Manage conversation memory using LangChain, with persistent pickle storage.

    Chat turns are kept in a ``ConversationBufferWindowMemory`` whose window
    is ``memory_window`` exchanges. The full message list plus extra metadata
    (domains, query types, recently used paper ids, interaction count) is
    written to ``./memory_data/memory_<session_id>.pkl`` after every change
    and restored when an instance for the same session is created.

    Fixed for LangChain v0.2+ Pydantic v2 validation: the memory receives a
    ``ChatMessageHistory`` instance rather than a dict.
    """

    # Maximum number of paper ids kept in 'previously_used_papers'.
    RECENT_PAPER_LIMIT = 20

    # Words ignored by _extract_topic (words of 3 chars or fewer are dropped anyway).
    _STOP_WORDS = frozenset({
        'what', 'how', 'why', 'when', 'where', 'which', 'can', 'you', 'me',
        'the', 'a', 'an', 'and', 'or', 'but',
    })
    # Punctuation stripped from word edges so "learning?" is read as "learning".
    _EDGE_PUNCTUATION = '.,!?;:"\'()[]{}'

    def __init__(self, session_id: str = "default", memory_window: int = 10):
        """
        Create (or restore) the memory for ``session_id``.

        Args:
            session_id: Identifier used in the persistence file name.
            memory_window: Number of recent exchanges (``k``) exposed by the
                LangChain window memory.
        """
        self.session_id = session_id
        self.memory_window = memory_window

        # Pydantic v2 requires a ChatMessageHistory instance, not a dict.
        self.memory = self._build_memory(ChatMessageHistory())
        self.conversation_metadata = self._fresh_metadata()

        # Replace the empty state with the persisted one, if a file exists.
        self._load_memory()

    def _build_memory(self, chat_history: 'BaseChatMessageHistory') -> 'ConversationBufferWindowMemory':
        """Wrap ``chat_history`` in the window memory configured for this instance."""
        return ConversationBufferWindowMemory(
            chat_memory=chat_history,
            k=self.memory_window,
            return_messages=True,
            memory_key="chat_history",
            output_key="output",
        )

    def _fresh_metadata(self) -> Dict[str, Any]:
        """Return an empty metadata record for this session."""
        return {
            'session_id': self.session_id,
            'domains_discussed': set(),
            'query_types_used': set(),
            'previously_used_papers': set(),
            'interaction_count': 0,
            # Sets are unordered, so recency is tracked explicitly.
            'recent_papers_order': [],
            'last_domain': None,
            'last_query_type': None,
        }

    def add_interaction(self, user_message: str, ai_response: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Record one user/assistant exchange and persist the updated state.

        Args:
            user_message: The user's input text.
            ai_response: The assistant's reply text.
            metadata: Optional dict; recognised keys are ``'domain'``,
                ``'query_type'`` and ``'paper_ids'`` (an iterable of ids).
        """
        self.memory.save_context(
            {"input": user_message},
            {"output": ai_response}
        )
        self.conversation_metadata['interaction_count'] += 1
        if metadata:
            self._update_metadata(metadata)
        self._save_memory()

    def _update_metadata(self, metadata: Dict[str, Any]) -> None:
        """Fold one interaction's metadata into ``conversation_metadata``."""
        meta = self.conversation_metadata
        if 'domain' in metadata:
            meta['domains_discussed'].add(metadata['domain'])
            meta['last_domain'] = metadata['domain']
        if 'query_type' in metadata:
            meta['query_types_used'].add(metadata['query_type'])
            meta['last_query_type'] = metadata['query_type']
        # The ids are read from 'paper_ids'; previously they were only recorded
        # when an unrelated 'papers_used' key happened to be present as well.
        paper_ids = metadata.get('paper_ids')
        if paper_ids:
            self._remember_papers(paper_ids)

    def _remember_papers(self, paper_ids) -> None:
        """Mark ``paper_ids`` as most recently used, keeping the newest RECENT_PAPER_LIMIT."""
        meta = self.conversation_metadata
        order = list(meta.get('recent_papers_order') or meta['previously_used_papers'])
        for paper_id in paper_ids:
            if paper_id in order:
                order.remove(paper_id)  # re-append so it counts as the newest
            order.append(paper_id)
        order = order[-self.RECENT_PAPER_LIMIT:]
        meta['recent_papers_order'] = order
        meta['previously_used_papers'] = set(order)

    def get_conversation_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Return stored exchanges as ``{'user', 'assistant', 'turn'}`` dicts.

        Messages are paired in storage order (user, assistant); a trailing
        unpaired message is ignored. When ``limit`` is truthy only the last
        ``limit`` exchanges are returned.
        """
        messages = self.memory.chat_memory.messages
        history = [
            {
                'user': messages[i].content,
                'assistant': messages[i + 1].content,
                'turn': i // 2 + 1,
            }
            for i in range(0, len(messages) - 1, 2)
        ]
        if limit:
            history = history[-limit:]
        return history

    def get_conversation_context(self) -> Dict[str, Any]:
        """
        Build the context dict used for query enhancement.

        Always contains session/metadata summaries and the last 3 exchanges;
        when history exists it also contains the last messages, an extracted
        topic, and the most recent domain / query type (if any were recorded).
        """
        history = self.get_conversation_history(limit=3)
        meta = self.conversation_metadata

        context = {
            'session_id': self.session_id,
            'interaction_count': meta['interaction_count'],
            'domains_discussed': list(meta['domains_discussed']),
            'query_types_used': list(meta['query_types_used']),
            # Oldest-to-newest when order is known.
            'previously_used_papers': list(meta.get('recent_papers_order') or meta['previously_used_papers']),
            'recent_history': history
        }

        if history:
            last_interaction = history[-1]
            context['last_user_message'] = last_interaction['user']
            context['last_assistant_response'] = last_interaction['assistant']
            context['last_topic'] = self._extract_topic(last_interaction['user'])

            # Files saved before 'last_*' existed only have the unordered sets,
            # so fall back to an arbitrary member as the old code did.
            if meta['query_types_used']:
                last_query_type = meta.get('last_query_type')
                context['last_query_type'] = (
                    last_query_type if last_query_type is not None
                    else list(meta['query_types_used'])[-1]
                )
            if meta['domains_discussed']:
                last_domain = meta.get('last_domain')
                context['last_domain'] = (
                    last_domain if last_domain is not None
                    else list(meta['domains_discussed'])[-1]
                )

        return context

    def get_conversation_summary(self) -> Dict[str, Any]:
        """Return a compact summary; ``recent_activity`` previews the last 3 user messages (max 50 chars)."""
        history = self.get_conversation_history()
        meta = self.conversation_metadata

        return {
            'session_id': self.session_id,
            'total_interactions': len(history),
            'domains_covered': list(meta['domains_discussed']),
            'query_types_used': list(meta['query_types_used']),
            'papers_referenced': len(meta['previously_used_papers']),
            'recent_activity': [
                msg['user'][:50] + '...' if len(msg['user']) > 50 else msg['user']
                for msg in history[-3:]
            ]
        }

    def clear_memory(self):
        """Clear all messages and metadata for this session and persist the empty state."""
        self.memory.clear()
        self.conversation_metadata = self._fresh_metadata()
        self._save_memory()

    def _extract_topic(self, message: str) -> str:
        """
        Return up to three "meaningful" words from ``message`` as a topic.

        Words are lower-cased and stripped of edge punctuation; stop words and
        words of 3 characters or fewer are dropped. Returns
        ``'general discussion'`` when nothing remains.
        """
        words = (word.strip(self._EDGE_PUNCTUATION) for word in message.lower().split())
        meaningful_words = [word for word in words if len(word) > 3 and word not in self._STOP_WORDS]
        return ' '.join(meaningful_words[:3]) if meaningful_words else 'general discussion'

    def _get_memory_file_path(self) -> str:
        """Return (creating its directory) the pickle path for this session."""
        memory_dir = "./memory_data"
        os.makedirs(memory_dir, exist_ok=True)
        # Path separators in the id would otherwise escape the file name.
        safe_id = str(self.session_id).replace('/', '_').replace('\\', '_')
        return f"{memory_dir}/memory_{safe_id}.pkl"

    @staticmethod
    def _records_to_messages(records: List[Any]) -> List['BaseMessage']:
        """Convert saved ``{'type', 'content'}`` records (or message objects) back into messages."""
        factories = {'human': HumanMessage, 'ai': AIMessage}
        messages = []
        for record in records:
            if isinstance(record, BaseMessage):
                messages.append(record)
                continue
            if not isinstance(record, dict):
                continue
            factory = factories.get(record.get('type'))
            if factory is not None:
                messages.append(factory(content=record.get('content', '')))
        return messages

    def _save_memory(self):
        """Persist messages and metadata; errors are printed, not raised."""
        try:
            memory_data = {
                'format_version': 2,
                # Messages are stored explicitly: the previous memory.dict()
                # payload was discarded on load, losing the chat history.
                'messages': [
                    {'type': msg.type, 'content': msg.content}
                    for msg in self.memory.chat_memory.messages
                ],
                'conversation_metadata': self.conversation_metadata
            }

            path = self._get_memory_file_path()
            tmp_path = f"{path}.tmp"
            with open(tmp_path, 'wb') as f:
                pickle.dump(memory_data, f)
            # Swap in the complete file so a crash mid-write cannot corrupt it.
            os.replace(tmp_path, path)

            print(f"💾 Memory saved for session: {self.session_id}")
        except Exception as e:
            print(f"❌ Error saving memory: {e}")

    def _load_memory(self):
        """Restore messages and metadata from disk if a file exists; on error keep fresh memory."""
        try:
            memory_file = self._get_memory_file_path()
            if not os.path.exists(memory_file):
                return

            # NOTE(security): pickle.load can execute arbitrary code; this is
            # only safe while ./memory_data is writable solely by this app.
            with open(memory_file, 'rb') as f:
                memory_data = pickle.load(f)

            records = memory_data.get('messages')
            if records is None:
                # Legacy files stored memory.dict(); recover its messages if present.
                legacy = memory_data.get('langchain_memory') or {}
                chat = legacy.get('chat_memory') if isinstance(legacy, dict) else None
                records = chat.get('messages', []) if isinstance(chat, dict) else []

            chat_history = ChatMessageHistory()
            chat_history.add_messages(self._records_to_messages(records))
            # The constructor's memory_window wins over any stored window size.
            self.memory = self._build_memory(chat_history)

            loaded = memory_data.get('conversation_metadata') or {}
            metadata = {**self._fresh_metadata(), **loaded}
            if 'recent_papers_order' not in loaded:
                metadata['recent_papers_order'] = list(metadata['previously_used_papers'])
            self.conversation_metadata = metadata

            print(f"📂 Memory loaded for session: {self.session_id}")
        except Exception as e:
            print(f"❌ Error loading memory: {e}")
            # Continue with fresh memory


# For Vercel serverless compatibility
class VercelMemoryManager:
    """
    Memory manager optimized for Vercel serverless environment.

    Uses JSON files instead of pickle for compatibility: the interaction list
    is stored at ``/tmp/memory_<session_id>.json`` and rewritten after every
    interaction. Only the most recent ``max_history`` interactions are kept.
    """

    # Maximum number of paper ids reported in 'previously_used_papers'.
    RECENT_PAPER_LIMIT = 20

    def __init__(self, session_id: str = "default", max_history: int = 20):
        """
        Create (or restore) the memory for ``session_id``.

        Args:
            session_id: Identifier used in the JSON file name.
            max_history: Number of interactions retained; a non-positive
                value disables trimming.
        """
        self.session_id = session_id
        self.max_history = max_history
        # Path separators in the id would otherwise escape the file name.
        safe_id = str(session_id).replace('/', '_').replace('\\', '_')
        self.memory_file = f"/tmp/memory_{safe_id}.json"
        self.conversation_history: List[Dict[str, Any]] = []
        self.load_memory()

    def add_interaction(self, user_message: str, ai_response: str, metadata: Optional[Dict[str, Any]] = None):
        """Append one exchange (with metadata and ISO timestamp), trim to ``max_history``, and save."""
        interaction = {
            'user': user_message,
            'assistant': ai_response,
            'metadata': metadata or {},
            'timestamp': self._get_timestamp()
        }

        self.conversation_history.append(interaction)

        # Keep only the newest interactions in the serverless environment.
        if 0 < self.max_history < len(self.conversation_history):
            self.conversation_history = self.conversation_history[-self.max_history:]

        self.save_memory()

    def get_conversation_context(self) -> Dict[str, Any]:
        """
        Build the context dict used for query enhancement.

        Aggregates domains, query types and paper ids from all stored
        interactions. The result has the same keys as ``ConversationMemory``'s
        context where the data is available here (no ``last_topic``).
        """
        history = self.conversation_history

        domains = set()
        query_types = set()
        papers: Dict[Any, None] = {}  # insertion-ordered set of paper ids
        last_domain = None
        last_query_type = None

        for interaction in history:
            meta = interaction.get('metadata') or {}
            if 'domain' in meta:
                domains.add(meta['domain'])
                last_domain = meta['domain']
            if 'query_type' in meta:
                query_types.add(meta['query_type'])
                last_query_type = meta['query_type']
            for paper_id in meta.get('paper_ids') or []:
                papers.pop(paper_id, None)  # re-insert so the newest use is last
                papers[paper_id] = None

        context = {
            'session_id': self.session_id,
            'interaction_count': len(history),
            'domains_discussed': list(domains),
            'query_types_used': list(query_types),
            'previously_used_papers': list(papers)[-self.RECENT_PAPER_LIMIT:],
            'recent_history': history[-3:]
        }

        if history:
            last_interaction = history[-1]
            context['last_user_message'] = last_interaction.get('user')
            context['last_assistant_response'] = last_interaction.get('assistant')
        if last_domain is not None:
            context['last_domain'] = last_domain
        if last_query_type is not None:
            context['last_query_type'] = last_query_type

        return context

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Encode sets found in metadata as lists; reject any other non-JSON value."""
        if isinstance(value, (set, frozenset)):
            return list(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def save_memory(self):
        """Save memory to the JSON file; errors are printed and the previous file is left intact."""
        try:
            # Serialize first so an unencodable value fails before any file is touched.
            payload = json.dumps(self.conversation_history, default=self._json_default)
            tmp_file = f"{self.memory_file}.tmp"
            with open(tmp_file, 'w', encoding='utf-8') as f:
                f.write(payload)
            # Swap in the complete file (atomic on POSIX) instead of truncating in place.
            os.replace(tmp_file, self.memory_file)
        except Exception as e:
            print(f"❌ Error saving memory: {e}")

    def load_memory(self):
        """Load memory from the JSON file if it exists; on any error start with an empty history."""
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if not isinstance(data, list):
                    raise ValueError("memory file does not contain a list of interactions")
                self.conversation_history = data
        except Exception as e:
            print(f"❌ Error loading memory: {e}")
            self.conversation_history = []

    def _get_timestamp(self) -> str:
        """Return the current local time as an ISO-8601 string."""
        return datetime.now().isoformat()