File size: 14,825 Bytes
639f3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
"""
Session Manager - Handles conversation sessions and message history
Supports both in-memory and Redis-based storage
"""

import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import json
import uuid

from ..core.config import settings
from ..core.logging import LoggerMixin
from ..models.schemas import ChatMessage, ConversationHistory, SessionInfo


class SessionManager(LoggerMixin):
    """
    Manages chat sessions and conversation history.

    Sessions are stored either in an in-process dict (``self.sessions``) or in
    Redis under keys of the form ``session:<session_id>``. Redis is used when
    ``settings.redis_url`` is set and the connection check in ``initialize()``
    succeeds; otherwise the manager falls back to memory storage.

    Most public methods catch exceptions, log them and return a falsy value
    (``False``, ``None`` or ``[]``) instead of raising.
    """

    # Prefix of every session key stored in Redis.
    _KEY_PREFIX = "session:"

    def __init__(self):
        # In-memory backend: session_id -> ConversationHistory.
        self.sessions: Dict[str, ConversationHistory] = {}
        self.redis_client = None
        # Redis is attempted only when a URL is configured; it is switched off
        # later if the client library is missing or the connection fails.
        self.use_redis = bool(settings.redis_url)
        self.session_timeout = settings.session_timeout * 60  # Convert to seconds
        self.max_sessions_per_user = settings.max_sessions_per_user
        self.max_messages_per_session = settings.max_messages_per_session

        # Background task that evicts expired in-memory sessions.
        self._cleanup_task: Optional[asyncio.Task] = None

    def _redis_key(self, session_id: str) -> str:
        """Return the Redis key under which *session_id* is stored."""
        return f"{self._KEY_PREFIX}{session_id}"

    async def initialize(self) -> bool:
        """
        Initialize the session manager.

        Connects to Redis when configured and starts the background cleanup
        task. If a cleanup task is already running, no second one is started.

        Returns:
            bool: True on success, False if initialization raised (logged).
        """
        try:
            if self.use_redis:
                await self._initialize_redis()

            # Start cleanup task, but never run two of them concurrently.
            if self._cleanup_task is None or self._cleanup_task.done():
                self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())

            self.log_info("Session manager initialized",
                          storage_type="redis" if self.use_redis else "memory",
                          session_timeout=self.session_timeout)
            return True

        except Exception as e:
            self.log_error("Failed to initialize session manager", error=str(e))
            return False

    async def _initialize_redis(self):
        """
        Create the Redis client and verify it with a PING.

        On failure the manager falls back to memory storage
        (``use_redis = False``). Any partially created client is closed and
        dropped so that it does not linger unused.
        """
        try:
            import redis.asyncio as redis
        except ImportError:
            self.log_warning("Redis not available, falling back to memory storage")
            self.use_redis = False
            return

        try:
            self.redis_client = redis.from_url(settings.redis_url)

            # Test connection
            await self.redis_client.ping()
            self.log_info("Redis connection established", url=settings.redis_url)

        except Exception as e:
            self.log_error("Redis connection failed", error=str(e))
            self.use_redis = False
            # Release the half-initialized client: with use_redis False it is
            # never used again.
            client, self.redis_client = self.redis_client, None
            if client is not None:
                try:
                    await client.close()
                except Exception as close_error:
                    self.log_warning("Failed to close Redis client", error=str(close_error))

    async def shutdown(self):
        """
        Shutdown the session manager.

        Cancels and awaits the cleanup task, then closes the Redis client if
        one exists. Errors are logged rather than raised.
        """
        try:
            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass

            if self.redis_client:
                await self.redis_client.close()

            self.log_info("Session manager shutdown complete")

        except Exception as e:
            self.log_error("Session manager shutdown failed", error=str(e))

    async def create_session(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """
        Create a new chat session.

        An existing session with the same id is left untouched and counts as
        success. ``user_id`` is used only for the per-user limit check; it is
        not stored with the session.

        Args:
            session_id: Unique session identifier
            user_id: Optional user identifier for session limits

        Returns:
            bool: True if the session exists or was created, False if the
            per-user limit was hit or an error occurred.
        """
        try:
            # Check if session already exists
            if await self.session_exists(session_id):
                self.log_info("Session already exists", session_id=session_id)
                return True

            # Check user session limits if user_id provided
            if user_id and self.max_sessions_per_user > 0:
                user_sessions = await self.get_user_sessions(user_id)
                if len(user_sessions) >= self.max_sessions_per_user:
                    self.log_warning("User session limit exceeded",
                                     user_id=user_id,
                                     limit=self.max_sessions_per_user)
                    return False

            # Create new session
            session = ConversationHistory(
                session_id=session_id,
                messages=[],
                created_at=datetime.utcnow(),
                updated_at=datetime.utcnow(),
                message_count=0
            )

            await self._store_session(session)

            self.log_info("Session created", session_id=session_id, user_id=user_id)
            return True

        except Exception as e:
            self.log_error("Failed to create session", error=str(e), session_id=session_id)
            return False

    async def add_message(self, session_id: str, message: ChatMessage) -> bool:
        """
        Add a message to a session, creating the session if it does not exist.

        When ``max_messages_per_session`` is positive, the oldest messages are
        dropped so the session never holds more than that many messages.

        NOTE(review): with Redis this is an unlocked read-modify-write, so
        concurrent writers to the same session can lose messages.

        Args:
            session_id: Session identifier
            message: Message to add

        Returns:
            bool: True if message added successfully
        """
        try:
            # Get or create session (created without user_id, so no per-user
            # limit is applied here).
            session = await self.get_session(session_id)
            if not session:
                await self.create_session(session_id)
                session = await self.get_session(session_id)

            if not session:
                self.log_error("Failed to create session", session_id=session_id)
                return False

            # Check message limit
            if (self.max_messages_per_session > 0 and
                    len(session.messages) >= self.max_messages_per_session):
                # Keep max-1 messages so the new one brings the count to max.
                messages_to_remove = len(session.messages) - self.max_messages_per_session + 1
                session.messages = session.messages[messages_to_remove:]
                self.log_info("Trimmed old messages",
                              session_id=session_id,
                              removed_count=messages_to_remove)

            # Add message
            session.messages.append(message)
            session.message_count = len(session.messages)
            session.updated_at = datetime.utcnow()

            # Store updated session (also refreshes the Redis TTL)
            await self._store_session(session)

            self.log_debug("Message added to session",
                           session_id=session_id,
                           message_role=message.role,
                           total_messages=session.message_count)
            return True

        except Exception as e:
            self.log_error("Failed to add message", error=str(e), session_id=session_id)
            return False

    async def get_session(self, session_id: str) -> Optional[ConversationHistory]:
        """
        Get a session by ID.

        With the memory backend this returns the stored object itself; with
        Redis it returns a freshly deserialized copy.

        Args:
            session_id: Session identifier

        Returns:
            ConversationHistory or None if not found (or on error)
        """
        try:
            if self.use_redis:
                return await self._get_session_from_redis(session_id)
            else:
                return self.sessions.get(session_id)

        except Exception as e:
            self.log_error("Failed to get session", error=str(e), session_id=session_id)
            return None

    async def session_exists(self, session_id: str) -> bool:
        """Return True if a session with this id can be loaded."""
        session = await self.get_session(session_id)
        return session is not None

    async def delete_session(self, session_id: str) -> bool:
        """
        Delete a session. Deleting a missing session also counts as success.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if session deleted successfully
        """
        try:
            if self.use_redis:
                await self.redis_client.delete(self._redis_key(session_id))
            else:
                self.sessions.pop(session_id, None)

            self.log_info("Session deleted", session_id=session_id)
            return True

        except Exception as e:
            self.log_error("Failed to delete session", error=str(e), session_id=session_id)
            return False

    async def get_session_messages(self, session_id: str, limit: Optional[int] = None) -> List[ChatMessage]:
        """
        Get messages from a session.

        Args:
            session_id: Session identifier
            limit: Optional limit; when positive, only the last ``limit``
                messages are returned

        Returns:
            A new list of ChatMessage objects (empty if the session does not
            exist). Modifying the list does not affect the stored session.
        """
        session = await self.get_session(session_id)
        if not session:
            return []

        if limit and limit > 0:
            return session.messages[-limit:]  # Get last N messages (a copy)

        # Copy so callers cannot mutate the in-memory session by accident.
        return list(session.messages)

    async def get_active_sessions(self) -> List[SessionInfo]:
        """
        Get summary information about all stored sessions.

        Returns:
            List of SessionInfo (empty on error).
        """
        try:
            sessions = []

            if self.use_redis:
                # Iterate with SCAN rather than KEYS: KEYS blocks the Redis
                # server while it walks the whole keyspace. SCAN may yield the
                # same key more than once, hence the de-duplication.
                seen = set()
                async for key in self.redis_client.scan_iter(match=f"{self._KEY_PREFIX}*"):
                    if isinstance(key, bytes):
                        key = key.decode()
                    # Strip only the leading prefix; ids may contain "session:".
                    session_id = key[len(self._KEY_PREFIX):]
                    if session_id in seen:
                        continue
                    seen.add(session_id)
                    session = await self.get_session(session_id)
                    # Keys can expire between SCAN and GET; skip those.
                    if session:
                        sessions.append(self._session_to_info(session))
            else:
                # Get from memory
                sessions.extend(self._session_to_info(session) for session in self.sessions.values())

            return sessions

        except Exception as e:
            self.log_error("Failed to get active sessions", error=str(e))
            return []

    async def get_user_sessions(self, user_id: str) -> List[SessionInfo]:
        """
        Get sessions for a specific user.

        Simplified implementation: a session belongs to a user when its id
        starts with ``"<user_id>-"``. No explicit user -> session mapping is
        stored.
        """
        all_sessions = await self.get_active_sessions()
        return [s for s in all_sessions if s.session_id.startswith(f"{user_id}-")]

    def _session_to_info(self, session: ConversationHistory) -> SessionInfo:
        """Convert ConversationHistory to SessionInfo (model = current setting)."""
        return SessionInfo(
            session_id=session.session_id,
            created_at=session.created_at,
            updated_at=session.updated_at,
            message_count=session.message_count,
            model_name=settings.model_name,  # Current model
            is_active=True
        )

    async def _store_session(self, session: ConversationHistory):
        """Store session in the active backend (Redis or memory)."""
        if self.use_redis:
            await self._store_session_in_redis(session)
        else:
            self.sessions[session.session_id] = session

    async def _store_session_in_redis(self, session: ConversationHistory):
        """
        Serialize the session to JSON and write it with SETEX.

        Each write resets the key's TTL to ``session_timeout`` seconds.
        """
        key = self._redis_key(session.session_id)
        data = {
            "session_id": session.session_id,
            "messages": [
                {
                    "role": msg.role,
                    "content": msg.content,
                    "timestamp": msg.timestamp.isoformat(),
                    "metadata": msg.metadata or {}
                }
                for msg in session.messages
            ],
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "message_count": session.message_count
        }

        await self.redis_client.setex(
            key,
            self.session_timeout,
            json.dumps(data, default=str)
        )

    async def _get_session_from_redis(self, session_id: str) -> Optional[ConversationHistory]:
        """
        Load and deserialize a session from Redis.

        Returns:
            ConversationHistory, or None if the key is missing or its payload
            cannot be parsed (parse errors are logged).
        """
        key = self._redis_key(session_id)
        data = await self.redis_client.get(key)

        if not data:
            return None

        try:
            session_data = json.loads(data)
            messages = [
                ChatMessage(
                    role=msg["role"],
                    content=msg["content"],
                    timestamp=datetime.fromisoformat(msg["timestamp"]),
                    metadata=msg.get("metadata")
                )
                for msg in session_data["messages"]
            ]

            return ConversationHistory(
                session_id=session_data["session_id"],
                messages=messages,
                created_at=datetime.fromisoformat(session_data["created_at"]),
                updated_at=datetime.fromisoformat(session_data["updated_at"]),
                message_count=session_data["message_count"]
            )

        except Exception as e:
            self.log_error("Failed to parse session from Redis", error=str(e), session_id=session_id)
            return None

    async def _cleanup_expired_sessions(self):
        """
        Background loop that evicts expired in-memory sessions every 5 minutes.

        A session expires when its ``updated_at`` is older than
        ``session_timeout`` seconds; reads do not refresh it. With Redis the
        loop does nothing, because expiry comes from the SETEX TTL. Errors are
        logged and the loop continues; cancellation ends it.
        """
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes

                if not self.use_redis:  # Redis handles expiration automatically
                    current_time = datetime.utcnow()
                    # Collect first, then delete, so the dict is not mutated
                    # while it is being iterated.
                    expired_sessions = [
                        session_id
                        for session_id, session in self.sessions.items()
                        if (current_time - session.updated_at).total_seconds() > self.session_timeout
                    ]

                    for session_id in expired_sessions:
                        del self.sessions[session_id]
                        self.log_debug("Expired session cleaned up", session_id=session_id)

                    if expired_sessions:
                        self.log_info("Cleaned up expired sessions", count=len(expired_sessions))

            except asyncio.CancelledError:
                break
            except Exception as e:
                self.log_error("Session cleanup failed", error=str(e))


# Global session manager instance used by the module-level helpers below.
# Constructing it only reads settings; no storage connection is made until
# initialize() is awaited.
session_manager = SessionManager()


async def get_session_manager() -> SessionManager:
    """
    Return the module-level ``session_manager`` instance.

    NOTE(review): declared ``async`` although it awaits nothing — presumably
    so it can serve as an async dependency provider; confirm with callers.
    """
    return session_manager


async def initialize_session_manager() -> bool:
    """
    Initialize the module-level ``session_manager``.

    Returns:
        bool: Result of ``SessionManager.initialize()`` — True on success,
        False if initialization raised (the error is logged, not re-raised).
    """
    return await session_manager.initialize()


async def shutdown_session_manager() -> None:
    """
    Shut down the module-level ``session_manager``.

    Delegates to ``SessionManager.shutdown()``, which cancels the cleanup task
    and closes the Redis client if present, logging rather than raising on
    failure.
    """
    await session_manager.shutdown()