File size: 17,514 Bytes
330b6e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
"""Session management service for chat agent."""

import json
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from uuid import uuid4

import redis
from sqlalchemy.exc import SQLAlchemyError

from ..models.chat_session import ChatSession
from ..models.base import db


logger = logging.getLogger(__name__)


class SessionManagerError(Exception):
    """Base class for the errors this module raises.

    ``SessionNotFoundError`` and ``SessionExpiredError`` derive from it, so
    catching this type also catches those.
    """


class SessionNotFoundError(SessionManagerError):
    """Raised when no matching session exists, or it is no longer active."""


class SessionExpiredError(SessionManagerError):
    """Raised when a session has exceeded its inactivity timeout."""


class SessionManager:
    """Manages user chat sessions with Redis caching and PostgreSQL persistence.

    The database is the source of truth. Redis holds a JSON snapshot of each
    session under ``session:<id>`` and a per-user set of session ids under
    ``user_sessions:<user_id>``. Every Redis operation is best-effort: failures
    are logged and swallowed, and a falsy ``redis_client`` disables caching.
    """

    def __init__(self, redis_client: redis.Redis, session_timeout: int = 3600):
        """Initialize the session manager.

        Args:
            redis_client: Redis client used for caching; a falsy value
                disables caching entirely.
            session_timeout: Inactivity timeout in seconds (default: 1 hour).
        """
        self.redis_client = redis_client
        self.session_timeout = session_timeout
        self.cache_prefix = "session:"
        self.user_sessions_prefix = "user_sessions:"

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def create_session(self, user_id: str, language: str = 'python',
                       session_metadata: Optional[Dict[str, Any]] = None) -> ChatSession:
        """Create a new chat session, persist it, and cache it.

        Args:
            user_id: User identifier.
            language: Programming language for the session (default: python).
            session_metadata: Additional session metadata; ``None`` is stored
                as an empty dict.

        Returns:
            ChatSession: The created session.

        Raises:
            SessionManagerError: If the session could not be persisted.
        """
        try:
            session = ChatSession.create_session(
                user_id=user_id,
                language=language,
                session_metadata=session_metadata or {},
            )
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error creating session: %s", e)
            raise SessionManagerError(f"Failed to create session: {e}") from e

        # The row now exists. Redis bookkeeping is best-effort (the helpers log
        # and swallow their own failures), so it cannot turn success into an error.
        self._cache_session(session)
        self._add_to_user_sessions(user_id, session.id)

        logger.info("Created new session %s for user %s", session.id, user_id)
        return session

    def get_session(self, session_id: str) -> ChatSession:
        """Get an active session by ID, checking the cache before the database.

        On a cache hit the returned ``ChatSession`` is rebuilt from the cached
        JSON and is not attached to the database session. Do not mutate it
        directly; use this manager's update methods, which load from the
        database.

        Args:
            session_id: Session identifier.

        Returns:
            ChatSession: The session object.

        Raises:
            SessionNotFoundError: If no active session with this ID exists.
            SessionExpiredError: If the session has timed out.
            SessionManagerError: On database errors.
        """
        cached_session = self._get_cached_session(session_id)
        if cached_session is not None:
            if not cached_session.is_active:
                # Stale snapshot of a deactivated session: drop it and let the
                # database decide.
                self._remove_from_cache(session_id)
            elif self._is_session_expired(cached_session):
                self._expire_session(session_id)
                raise SessionExpiredError(f"Session {session_id} has expired")
            else:
                return cached_session

        session = self._load_active_session(session_id)
        self._cache_session(session)
        return session

    def update_session_activity(self, session_id: str) -> None:
        """Update the session's activity timestamp in the database and cache.

        Args:
            session_id: Session identifier.

        Raises:
            SessionNotFoundError: If no active session with this ID exists.
            SessionExpiredError: If the session has timed out.
            SessionManagerError: If the update fails.
        """
        try:
            # Load from the database: a cached copy is detached, so changes
            # made to it would never be persisted.
            session = self._load_active_session(session_id)
            session.update_activity()
            self._cache_session(session)
            logger.debug("Updated activity for session %s", session_id)
        except SessionManagerError:
            raise
        except Exception as e:
            self._rollback()  # discard any partial change from the failed update
            logger.error("Error updating session activity: %s", e)
            raise SessionManagerError(f"Failed to update session activity: {e}") from e

    def get_user_sessions(self, user_id: str, active_only: bool = True) -> List[ChatSession]:
        """Get a user's sessions, most recently active first.

        With ``active_only``, sessions found to be expired are deactivated and
        evicted from Redis as a side effect, then excluded from the result.

        Args:
            user_id: User identifier.
            active_only: Whether to return only active, unexpired sessions.

        Returns:
            List[ChatSession]: The user's sessions.

        Raises:
            SessionManagerError: On database errors.
        """
        try:
            query = db.session.query(ChatSession).filter(ChatSession.user_id == user_id)
            if active_only:
                query = query.filter(ChatSession.is_active == True)  # noqa: E712 (SQL expression)

            sessions = query.order_by(ChatSession.last_active.desc()).all()
            if not active_only:
                return sessions

            active_sessions = []
            for session in sessions:
                if session.is_expired(self.session_timeout):
                    session.deactivate()
                    self._remove_from_cache(session.id)
                    self._remove_from_user_sessions(user_id, session.id)
                else:
                    active_sessions.append(session)
            return active_sessions

        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error getting user sessions: %s", e)
            raise SessionManagerError(f"Failed to get user sessions: {e}") from e

    def cleanup_inactive_sessions(self) -> int:
        """Clean up expired sessions in the database and stale Redis snapshots.

        Returns:
            int: Number of database sessions cleaned up, as reported by
            ``ChatSession.cleanup_expired_sessions``.

        Raises:
            SessionManagerError: On database errors.
        """
        try:
            cleaned_count = ChatSession.cleanup_expired_sessions(self.session_timeout)
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error during cleanup: %s", e)
            raise SessionManagerError(f"Failed to cleanup sessions: {e}") from e

        self._cleanup_expired_cache_sessions()

        logger.info("Cleaned up %s expired sessions", cleaned_count)
        return cleaned_count

    def delete_session(self, session_id: str) -> None:
        """Delete a session from the database and remove its Redis entries.

        Args:
            session_id: Session identifier.

        Raises:
            SessionNotFoundError: If the session doesn't exist.
            SessionManagerError: On database errors.
        """
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id
            ).first()

            if session is None:
                # Drop any leftover cached snapshot so get_session stops serving it.
                self._remove_from_cache(session_id)
                raise SessionNotFoundError(f"Session {session_id} not found")

            user_id = session.user_id

            # NOTE(review): related records are assumed to be removed by the
            # model's cascade configuration, which is not visible here.
            db.session.delete(session)
            db.session.commit()

        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error deleting session: %s", e)
            raise SessionManagerError(f"Failed to delete session: {e}") from e

        self._remove_from_cache(session_id)
        self._remove_from_user_sessions(user_id, session_id)

        logger.info("Deleted session %s", session_id)

    def set_session_language(self, session_id: str, language: str) -> None:
        """Set the programming language for a session.

        Args:
            session_id: Session identifier.
            language: Programming language.

        Raises:
            SessionNotFoundError: If no active session with this ID exists.
            SessionExpiredError: If the session has timed out.
            SessionManagerError: If the update fails.
        """
        try:
            # Load from the database: a cached copy is detached, so changes
            # made to it would never be persisted.
            session = self._load_active_session(session_id)
            session.set_language(language)
            self._cache_session(session)
            logger.info("Set language to %s for session %s", language, session_id)
        except SessionManagerError:
            raise
        except Exception as e:
            self._rollback()  # discard any partial change from the failed update
            logger.error("Error setting session language: %s", e)
            raise SessionManagerError(f"Failed to set session language: {e}") from e

    def increment_message_count(self, session_id: str) -> None:
        """Increment the message count for a session.

        Args:
            session_id: Session identifier.

        Raises:
            SessionNotFoundError: If no active session with this ID exists.
            SessionExpiredError: If the session has timed out.
            SessionManagerError: If the update fails.
        """
        try:
            # Load from the database: a cached copy is detached, so changes
            # made to it would never be persisted.
            session = self._load_active_session(session_id)
            session.increment_message_count()
            self._cache_session(session)
        except SessionManagerError:
            raise
        except Exception as e:
            self._rollback()  # discard any partial change from the failed update
            logger.error("Error incrementing message count: %s", e)
            raise SessionManagerError(f"Failed to increment message count: {e}") from e

    # ------------------------------------------------------------------ #
    # Database helpers
    # ------------------------------------------------------------------ #

    def _load_active_session(self, session_id: str) -> ChatSession:
        """Load an active, unexpired session from the database, bypassing the cache.

        An expired session is deactivated and evicted from Redis before
        ``SessionExpiredError`` is raised.

        Raises:
            SessionNotFoundError: If no active session with this ID exists.
            SessionExpiredError: If the session has timed out.
            SessionManagerError: On database errors.
        """
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id,
                ChatSession.is_active == True,  # noqa: E712 (SQL expression)
            ).first()

            if session is None:
                raise SessionNotFoundError(f"Session {session_id} not found")

            if session.is_expired(self.session_timeout):
                session.deactivate()
                self._remove_from_cache(session_id)
                self._remove_from_user_sessions(session.user_id, session_id)
                raise SessionExpiredError(f"Session {session_id} has expired")

            return session

        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error getting session %s: %s", session_id, e)
            raise SessionManagerError(f"Failed to get session: {e}") from e

    @staticmethod
    def _rollback() -> None:
        """Roll back the current DB transaction after a failed operation.

        Without this, a failed flush/commit leaves the session unusable for
        subsequent queries.
        """
        try:
            db.session.rollback()
        except SQLAlchemyError as e:
            logger.warning("Failed to roll back database session: %s", e)

    def _is_session_expired(self, session: ChatSession) -> bool:
        """Check if a session is expired under this manager's timeout."""
        return session.is_expired(self.session_timeout)

    def _expire_session(self, session_id: str) -> None:
        """Deactivate an expired session and drop its Redis entries (best-effort)."""
        # Evict first so the stale snapshot is dropped even if the DB update fails.
        self._remove_from_cache(session_id)
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id
            ).first()

            if session is not None:
                session.deactivate()
                self._remove_from_user_sessions(session.user_id, session_id)

        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Error expiring session %s: %s", session_id, e)

    # ------------------------------------------------------------------ #
    # Redis helpers (all best-effort)
    # ------------------------------------------------------------------ #

    def _cache_session(self, session: ChatSession) -> None:
        """Store a JSON snapshot of ``session`` in Redis."""
        if not self.redis_client:
            return  # Caching disabled

        try:
            cache_key = f"{self.cache_prefix}{session.id}"
            session_data = {
                'id': session.id,
                'user_id': session.user_id,
                'language': session.language,
                'created_at': session.created_at.isoformat(),
                'last_active': session.last_active.isoformat(),
                'message_count': session.message_count,
                'is_active': session.is_active,
                'session_metadata': session.session_metadata,
            }

            # TTL is the session timeout plus a 5-minute buffer.
            self.redis_client.setex(
                cache_key,
                self.session_timeout + 300,
                json.dumps(session_data),
            )

        except (redis.RedisError, TypeError, ValueError) as e:
            # TypeError/ValueError: metadata that json.dumps cannot serialize.
            logger.warning("Failed to cache session %s: %s", session.id, e)

    @staticmethod
    def _decode_cached_session(cached_data: Any) -> ChatSession:
        """Rebuild a detached ``ChatSession`` from a cached JSON snapshot.

        Raises:
            ValueError, KeyError, TypeError: If the snapshot is malformed.
        """
        session_data = json.loads(cached_data)
        session = ChatSession(
            user_id=session_data['user_id'],
            language=session_data['language'],
            session_metadata=session_data['session_metadata'],
        )
        session.id = session_data['id']
        session.created_at = datetime.fromisoformat(session_data['created_at'])
        session.last_active = datetime.fromisoformat(session_data['last_active'])
        session.message_count = session_data['message_count']
        session.is_active = session_data['is_active']
        return session

    def _get_cached_session(self, session_id: str) -> Optional[ChatSession]:
        """Return the cached session, or ``None`` on a miss or unreadable entry."""
        if not self.redis_client:
            return None  # Caching disabled

        cache_key = f"{self.cache_prefix}{session_id}"
        try:
            cached_data = self.redis_client.get(cache_key)
            if not cached_data:
                return None
            return self._decode_cached_session(cached_data)
        except (redis.RedisError, ValueError, KeyError, TypeError) as e:
            # ValueError covers json.JSONDecodeError and bad ISO timestamps.
            logger.warning("Failed to get cached session %s: %s", session_id, e)
            return None

    def _remove_from_cache(self, session_id: str) -> None:
        """Remove a session snapshot from Redis."""
        if not self.redis_client:
            return  # Caching disabled

        try:
            self.redis_client.delete(f"{self.cache_prefix}{session_id}")
        except redis.RedisError as e:
            logger.warning("Failed to remove session %s from cache: %s", session_id, e)

    def _add_to_user_sessions(self, user_id: str, session_id: str) -> None:
        """Add a session id to the user's session set in Redis."""
        if not self.redis_client:
            return  # Caching disabled

        try:
            user_sessions_key = f"{self.user_sessions_prefix}{user_id}"
            self.redis_client.sadd(user_sessions_key, session_id)
            # The set expires after twice the session timeout.
            self.redis_client.expire(user_sessions_key, self.session_timeout * 2)
        except redis.RedisError as e:
            logger.warning("Failed to add session to user sessions list: %s", e)

    def _remove_from_user_sessions(self, user_id: str, session_id: str) -> None:
        """Remove a session id from the user's session set in Redis."""
        if not self.redis_client:
            return  # Caching disabled

        try:
            user_sessions_key = f"{self.user_sessions_prefix}{user_id}"
            self.redis_client.srem(user_sessions_key, session_id)
        except redis.RedisError as e:
            logger.warning("Failed to remove session from user sessions list: %s", e)

    def _cleanup_expired_cache_sessions(self) -> None:
        """Delete cached snapshots that are expired, inactive, or unreadable."""
        if not self.redis_client:
            return  # Caching disabled: nothing to clean

        try:
            # A set, because SCAN may return the same key more than once.
            expired_keys = set()
            # SCAN iterates incrementally instead of blocking the server like KEYS.
            for key in self.redis_client.scan_iter(match=f"{self.cache_prefix}*"):
                cached_data = self.redis_client.get(key)
                if not cached_data:
                    continue  # Key vanished between SCAN and GET

                try:
                    session = self._decode_cached_session(cached_data)
                except (ValueError, KeyError, TypeError):
                    expired_keys.add(key)  # Invalid data, mark for deletion
                    continue

                # Same expiry rule that get_session applies to cache hits.
                if not session.is_active or self._is_session_expired(session):
                    expired_keys.add(key)

            if expired_keys:
                self.redis_client.delete(*expired_keys)
                logger.info("Cleaned up %d expired cache entries", len(expired_keys))

        except redis.RedisError as e:
            logger.warning("Failed to cleanup expired cache sessions: %s", e)


def create_session_manager(redis_client: redis.Redis, session_timeout: int = 3600) -> SessionManager:
    """Build a ``SessionManager`` from a Redis client and a timeout.

    Args:
        redis_client: Redis client handed to the manager for caching.
        session_timeout: Inactivity timeout in seconds (default: 3600).

    Returns:
        SessionManager: A new manager wired to ``redis_client``.
    """
    manager = SessionManager(redis_client=redis_client, session_timeout=session_timeout)
    return manager