File size: 3,014 Bytes
6ca2339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""In-memory session manager with TTL-based expiry + SQLite persistence."""

import time
from typing import Optional
from app.chat_db import save_message, get_chat_history


class Session:
    """A single conversation session.

    Holds the in-memory message window for one ``session_id`` plus a few
    pieces of per-conversation state. Every message added through
    ``add_user_message`` / ``add_assistant_message`` is also written to
    SQLite via ``save_message``.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Chronological messages: [{"role": "user"|"assistant", "content": str}]
        self.history: list[dict] = []
        self.current_topic: str = ""
        self.pending_clarification: Optional[str] = None
        self.last_confidence: float = 0.0
        self.last_action: str = ""  # "retrieve" | "reason" | "clarify"
        # Wall-clock timestamp (time.time()) of the most recent activity.
        self.last_active: float = time.time()

    def add_user_message(self, message: str):
        """Record a user message in memory and persist it to SQLite."""
        self._add_message("user", message)

    def add_assistant_message(self, message: str):
        """Record an assistant message in memory and persist it to SQLite."""
        self._add_message("assistant", message)

    def _add_message(self, role: str, message: str):
        """Shared path for both roles: append, refresh activity, trim, persist.

        The in-memory append happens before the SQLite write, so if
        ``save_message`` raises, the message is already in ``history``.
        """
        self.history.append({"role": role, "content": message})
        self.last_active = time.time()
        self._trim_history()
        save_message(self.session_id, role, message)

    def get_history_text(self, max_messages: int = 10) -> str:
        """Return the most recent messages formatted for LLM context.

        Args:
            max_messages: How many of the latest messages to include
                (default 10). Values <= 0 produce an empty string.

        Returns:
            One ``"User: ..."`` / ``"Assistant: ..."`` line per message,
            joined by newlines. Any role other than ``"user"`` is labelled
            ``Assistant``.
        """
        # Explicit guard: history[-0:] would return the *whole* list.
        if max_messages <= 0:
            return ""
        lines = []
        for msg in self.history[-max_messages:]:
            role = "User" if msg["role"] == "user" else "Assistant"
            lines.append(f"{role}: {msg['content']}")
        return "\n".join(lines)

    def _trim_history(self, max_turns: int = 20):
        """Keep only the last ``max_turns`` messages in memory.

        ``max_turns`` counts individual messages, not user/assistant pairs.
        A value <= 0 clears the in-memory history. This needs its own branch
        because a plain ``history[-0:]`` slice would keep everything.
        """
        if max_turns <= 0:
            self.history = []
        elif len(self.history) > max_turns:
            # Rebind to a new list (not in-place) so any external reference
            # to the old list is left untouched.
            self.history = self.history[-max_turns:]


class SessionManager:
    """Registry of live ``Session`` objects keyed by ``session_id``.

    Expiry is lazy: entries idle for longer than the TTL are purged only
    when ``get_or_create`` is called.
    """

    def __init__(self, ttl_minutes: int = 30):
        # TTL kept in seconds so it can be compared with time.time() deltas.
        self._ttl_seconds = ttl_minutes * 60
        self._sessions: dict[str, Session] = {}

    def get_or_create(self, session_id: str) -> Session:
        """Return the session for ``session_id``, building it on first use.

        Stale sessions are purged first. A session not currently in memory is
        rebuilt from the stored chat history. Whether found or rebuilt, the
        returned session's ``last_active`` is refreshed.
        """
        self._cleanup_expired()
        registry = self._sessions
        if session_id not in registry:
            registry[session_id] = self._rehydrate(session_id)
        current = registry[session_id]
        current.last_active = time.time()
        return current

    @staticmethod
    def _rehydrate(session_id: str) -> Session:
        """Build a new Session seeded with its persisted messages (role/content only)."""
        restored = Session(session_id)
        restored.history.extend(
            {"role": row["role"], "content": row["content"]}
            for row in get_chat_history(session_id)
        )
        # Apply the in-memory window to whatever the DB returned.
        restored._trim_history()
        return restored

    def _cleanup_expired(self):
        """Drop every session whose idle time strictly exceeds the TTL."""
        now = time.time()
        # Iterate over a snapshot so entries can be deleted while looping.
        for sid, sess in list(self._sessions.items()):
            if now - sess.last_active > self._ttl_seconds:
                del self._sessions[sid]


# Global singleton: one shared SessionManager created at import time with the
# constructor's default TTL (30 minutes). Callers presumably import this
# instance instead of constructing their own — confirm against call sites.
session_manager = SessionManager()