File size: 3,464 Bytes
8b4e99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""

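In-memory session manager.

- Thread-safe: every read/write of the session map happens under a
  threading.Lock. The SessionState objects handed back to callers are shared,
  mutable references and are not themselves protected by that lock.
- TTL-based lazy eviction: a session expires once SESSION_TTL_SEC seconds have
  passed since it was created or last fetched via get(); expired sessions are
  removed during store operations rather than by a background task.
- Bounded by MAX_SESSIONS: create() raises RuntimeError once that many live
  sessions exist.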

"""

from __future__ import annotations

import os
import threading
import uuid
from dataclasses import dataclass, field, fields
from datetime import datetime, timedelta
from typing import Optional

# Seconds a session may go without being fetched (SessionStore.get refreshes
# it) before SessionState.is_expired() reports it as expired.
# Overridable via the SESSION_TTL_SEC environment variable; a non-integer
# value makes the import fail with ValueError.
SESSION_TTL_SEC = int(os.getenv("SESSION_TTL_SEC", "1800"))

# Maximum number of live sessions SessionStore.create() will allow.
# Overridable via the MAX_SESSIONS environment variable.
MAX_SESSIONS = int(os.getenv("MAX_SESSIONS", "1000"))

# NOTE(review): not referenced by the code in this module; presumably a
# per-session turn cap enforced by callers — confirm.
MAX_TURNS = 50


@dataclass
class AccumulatedContext:
    """Optional campaign values accumulated for a session.

    Every field defaults to ``None``, meaning "not provided yet"; use
    :meth:`merge` to fold newly supplied values into an existing context.
    """

    campaign_name: Optional[str] = None
    industry: Optional[str] = None
    # NOTE(review): cvr/ctr/cpa are presumably conversion rate, click-through
    # rate and cost per acquisition; units/scale are not established here.
    cvr: Optional[float] = None
    ctr: Optional[float] = None
    cpa: Optional[float] = None
    # NOTE(review): presumably a base64-encoded image payload — not validated.
    image_base64: Optional[str] = None

    def merge(self, ctx: "AccumulatedContext") -> None:
        """Overlay the non-None values of *ctx* onto this context, in place.

        A field that is ``None`` in *ctx* leaves the current value untouched,
        so new values never erase earlier ones with ``None``. Falsy but
        non-None values (``0.0``, ``""``) DO overwrite.

        Args:
            ctx: The context whose provided values should be merged in.
        """
        # Driving the loop from the dataclass schema keeps merge() in sync
        # when fields are added, instead of silently skipping new ones.
        for f in fields(self):
            value = getattr(ctx, f.name)
            if value is not None:
                setattr(self, f.name, value)


@dataclass
class HistoryEntry:
    """A single conversation turn stored in a session's history."""

    # Speaker of the turn; by convention "user" or "assistant" (not enforced).
    role: str
    # Text of the turn.
    content: str
    # Creation time as an ISO-8601 string taken from the naive UTC clock, so it
    # carries no "+00:00" offset suffix.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    timestamp: str = field(
        default_factory=lambda: datetime.utcnow().isoformat(),
    )


@dataclass
class SessionState:
    """Mutable record describing one session held by ``SessionStore``.

    ``created_at`` and ``last_accessed`` are naive datetimes from
    ``datetime.utcnow``.
    NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    """

    session_id: str
    created_at: datetime = field(default_factory=datetime.utcnow)
    # Refreshed by touch(); is_expired() measures idle time from here.
    last_accessed: datetime = field(default_factory=datetime.utcnow)
    # NOTE(review): not incremented in this module; presumably maintained by callers.
    turn_count: int = 0
    accumulated_context: AccumulatedContext = field(default_factory=AccumulatedContext)
    history: list[HistoryEntry] = field(default_factory=list)
    # NOTE(review): opaque level tag; only the default value is visible here.
    current_level: str = "level1"

    def is_expired(self) -> bool:
        """Return True once more than SESSION_TTL_SEC seconds have elapsed since last access."""
        idle_for = datetime.utcnow() - self.last_accessed
        return idle_for > timedelta(seconds=SESSION_TTL_SEC)

    def touch(self) -> None:
        """Record an access now, restarting the idle-expiry window."""
        now = datetime.utcnow()
        self.last_accessed = now


class SessionStore:
    """In-memory mapping of session id -> ``SessionState``, guarded by a lock.

    Every public method holds ``self._lock`` while touching the map. Expired
    sessions are removed lazily: ``create`` and ``count`` sweep the whole map,
    ``get`` drops only the requested id.

    NOTE(review): ``create`` and ``get`` return the stored ``SessionState``
    objects themselves, so any mutation a caller makes to them happens outside
    the lock.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, SessionState] = {}
        self._lock = threading.Lock()

    def create(self) -> SessionState:
        """Create, register and return a new session with a random UUID4 id.

        Raises:
            RuntimeError: if MAX_SESSIONS live sessions already exist after
                expired ones have been evicted.
        """
        with self._lock:
            self._evict_expired()
            if len(self._sessions) >= MAX_SESSIONS:
                raise RuntimeError("MAX_SESSIONS limit reached")
            session_id = str(uuid.uuid4())
            state = SessionState(session_id=session_id)
            self._sessions[session_id] = state
            return state

    def get(self, session_id: str) -> Optional[SessionState]:
        """Return the live session for *session_id* and refresh its TTL.

        Returns None when the id is unknown or the session has expired; an
        expired session is also removed from the store.
        """
        with self._lock:
            state = self._sessions.get(session_id)
            if state is None:
                return None
            if state.is_expired():
                del self._sessions[session_id]
                return None
            state.touch()
            return state

    def save(self, state: SessionState) -> None:
        """Store *state* under its ``session_id``, replacing any existing entry.

        NOTE(review): this neither enforces MAX_SESSIONS nor refreshes
        ``last_accessed``; saving a state whose id is no longer present (e.g.
        evicted after expiry) re-inserts it and can exceed the limit.
        """
        with self._lock:
            self._sessions[state.session_id] = state

    def _evict_expired(self) -> None:
        """Remove every expired session. Caller must already hold ``self._lock``."""
        # Collect first: deleting while iterating the dict would raise.
        expired = [sid for sid, s in self._sessions.items() if s.is_expired()]
        for sid in expired:
            del self._sessions[sid]

    def count(self) -> int:
        """Return the number of live (non-expired) sessions.

        Expired sessions are evicted first so the result agrees with what
        ``get`` would actually return.
        """
        with self._lock:
            self._evict_expired()
            return len(self._sessions)


# Module-level shared instance: every importer of this module sees the same
# SessionStore (and therefore the same sessions) within a process.
store: SessionStore = SessionStore()