File size: 11,146 Bytes
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""
Session Store for Messaging Platforms

Provides persistent storage for mapping platform messages to Claude CLI session IDs
and message trees for conversation continuation.
"""

import contextlib
import json
import os
import tempfile
import threading
import time
from datetime import UTC, datetime
from typing import Any

from loguru import logger


class SessionStore:
    """
    Persistent storage for message ↔ Claude session mappings and message trees.

    Uses a JSON file for storage with thread-safe operations.
    Platform-agnostic: works with any messaging platform.
    """

    def __init__(
        self,
        storage_path: str = "sessions.json",
        *,
        message_log_cap: int | None = None,
    ):
        self.storage_path = storage_path
        self._lock = threading.Lock()
        self._trees: dict[str, dict] = {}  # root_id -> tree data
        self._node_to_tree: dict[str, str] = {}  # node_id -> root_id
        # Per-chat message ID log used to support best-effort UI clearing (/clear).
        # Key: "{platform}:{chat_id}" -> list of records
        self._message_log: dict[str, list[dict[str, Any]]] = {}
        self._message_log_ids: dict[str, set[str]] = {}
        self._dirty = False
        self._save_timer: threading.Timer | None = None
        self._save_debounce_secs = 0.5
        self._message_log_cap: int | None = message_log_cap
        self._load()

    def _make_chat_key(self, platform: str, chat_id: str) -> str:
        return f"{platform}:{chat_id}"

    def _load(self) -> None:
        """Load sessions and trees from disk."""
        if not os.path.exists(self.storage_path):
            return

        try:
            with open(self.storage_path, encoding="utf-8") as f:
                data = json.load(f)

            # Load trees
            self._trees = data.get("trees", {})
            self._node_to_tree = data.get("node_to_tree", {})

            # Load message log (optional/backward compatible)
            raw_log = data.get("message_log", {}) or {}
            if isinstance(raw_log, dict):
                self._message_log = {}
                self._message_log_ids = {}
                for chat_key, items in raw_log.items():
                    if not isinstance(chat_key, str) or not isinstance(items, list):
                        continue
                    cleaned: list[dict[str, Any]] = []
                    seen: set[str] = set()
                    for it in items:
                        if not isinstance(it, dict):
                            continue
                        mid = it.get("message_id")
                        if mid is None:
                            continue
                        mid_s = str(mid)
                        if mid_s in seen:
                            continue
                        seen.add(mid_s)
                        cleaned.append(
                            {
                                "message_id": mid_s,
                                "ts": str(it.get("ts") or ""),
                                "direction": str(it.get("direction") or ""),
                                "kind": str(it.get("kind") or ""),
                            }
                        )
                    self._message_log[chat_key] = cleaned
                    self._message_log_ids[chat_key] = seen

            logger.info(
                f"Loaded {len(self._trees)} trees and "
                f"{sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}"
            )
        except Exception as e:
            logger.error(f"Failed to load sessions: {e}")

    def _snapshot(self) -> dict:
        """Snapshot current state for serialization. Caller must hold self._lock."""
        return {
            "trees": dict(self._trees),
            "node_to_tree": dict(self._node_to_tree),
            "message_log": {k: list(v) for k, v in self._message_log.items()},
        }

    def _write_data(self, data: dict) -> None:
        """Atomically write data dict to disk. Must be called WITHOUT holding self._lock."""
        abs_target = os.path.abspath(self.storage_path)
        dir_name = os.path.dirname(abs_target) or "."
        fd, tmp_path = tempfile.mkstemp(
            dir=dir_name, prefix=".sessions.", suffix=".tmp.json"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, abs_target)
        except BaseException:
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)
            raise

    def _schedule_save(self) -> None:
        """Schedule a debounced save. Caller must hold self._lock."""
        self._dirty = True
        if self._save_timer is not None:
            self._save_timer.cancel()
            self._save_timer = None
        self._save_timer = threading.Timer(
            self._save_debounce_secs, self._save_from_timer
        )
        self._save_timer.daemon = True
        self._save_timer.start()

    def _save_from_timer(self) -> None:
        """Timer callback: save if dirty. Runs in timer thread."""
        with self._lock:
            if not self._dirty:
                self._save_timer = None
                return
            snapshot = self._snapshot()
            self._dirty = False
            self._save_timer = None
        try:
            self._write_data(snapshot)
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            with self._lock:
                self._dirty = True

    def _flush_save(self) -> dict:
        """Cancel pending timer and snapshot current state. Caller must hold self._lock.
        Returns snapshot dict; caller must call _write_data(snapshot) outside the lock."""
        if self._save_timer is not None:
            self._save_timer.cancel()
            self._save_timer = None
        self._dirty = False
        return self._snapshot()

    def flush_pending_save(self) -> None:
        """Flush any pending debounced save. Call on shutdown to avoid losing data."""
        with self._lock:
            snapshot = self._flush_save()
        try:
            self._write_data(snapshot)
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            with self._lock:
                self._dirty = True

    def record_message_id(
        self,
        platform: str,
        chat_id: str,
        message_id: str,
        direction: str,
        kind: str,
    ) -> None:
        """Record a message_id for later best-effort deletion (/clear)."""
        if message_id is None:
            return

        chat_key = self._make_chat_key(str(platform), str(chat_id))
        mid = str(message_id)

        with self._lock:
            seen = self._message_log_ids.setdefault(chat_key, set())
            if mid in seen:
                return

            rec = {
                "message_id": mid,
                "ts": datetime.now(UTC).isoformat(),
                "direction": str(direction),
                "kind": str(kind),
            }
            self._message_log.setdefault(chat_key, []).append(rec)
            seen.add(mid)

            # Optional cap to prevent unbounded growth if configured.
            if self._message_log_cap is not None and self._message_log_cap > 0:
                items = self._message_log.get(chat_key, [])
                if len(items) > self._message_log_cap:
                    self._message_log[chat_key] = items[-self._message_log_cap :]
                    self._message_log_ids[chat_key] = {
                        str(x.get("message_id")) for x in self._message_log[chat_key]
                    }

            self._schedule_save()

    def get_message_ids_for_chat(self, platform: str, chat_id: str) -> list[str]:
        """Get all recorded message IDs for a chat (in insertion order)."""
        chat_key = self._make_chat_key(str(platform), str(chat_id))
        with self._lock:
            items = self._message_log.get(chat_key, [])
            return [
                str(x.get("message_id"))
                for x in items
                if x.get("message_id") is not None
            ]

    def clear_all(self) -> None:
        """Clear all stored sessions/trees/mappings and persist an empty store."""
        with self._lock:
            self._trees.clear()
            self._node_to_tree.clear()
            self._message_log.clear()
            self._message_log_ids.clear()
            snapshot = self._flush_save()
        try:
            self._write_data(snapshot)
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            with self._lock:
                self._dirty = True

    # ==================== Tree Methods ====================

    def save_tree(self, root_id: str, tree_data: dict) -> None:
        """
        Save a message tree.

        Args:
            root_id: Root node ID of the tree
            tree_data: Serialized tree data from tree.to_dict()
        """
        with self._lock:
            self._trees[root_id] = tree_data

            # Update node-to-tree mapping
            for node_id in tree_data.get("nodes", {}):
                self._node_to_tree[node_id] = root_id

            self._schedule_save()
            logger.debug(f"Saved tree {root_id}")

    def get_tree(self, root_id: str) -> dict | None:
        """Get a tree by its root ID."""
        with self._lock:
            return self._trees.get(root_id)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Register a node ID to a tree root."""
        with self._lock:
            self._node_to_tree[node_id] = root_id
            self._schedule_save()

    def remove_node_mappings(self, node_ids: list[str]) -> None:
        """Remove node IDs from the node-to-tree mapping."""
        with self._lock:
            for nid in node_ids:
                self._node_to_tree.pop(nid, None)
            self._schedule_save()

    def remove_tree(self, root_id: str) -> None:
        """Remove a tree and all its node mappings from the store."""
        with self._lock:
            tree_data = self._trees.pop(root_id, None)
            if tree_data:
                for node_id in tree_data.get("nodes", {}):
                    self._node_to_tree.pop(node_id, None)
                self._schedule_save()

    def get_all_trees(self) -> dict[str, dict]:
        """Get all stored trees (public accessor)."""
        with self._lock:
            return dict(self._trees)

    def get_node_mapping(self) -> dict[str, str]:
        """Get the node-to-tree mapping (public accessor)."""
        with self._lock:
            return dict(self._node_to_tree)

    def sync_from_tree_data(
        self, trees: dict[str, dict], node_to_tree: dict[str, str]
    ) -> None:
        """Sync internal tree state from external data and persist."""
        with self._lock:
            self._trees = trees
            self._node_to_tree = node_to_tree
            self._schedule_save()