File size: 6,953 Bytes
86a0172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Short-Term / Session Memory
===========================

Stores conversation context and ephemeral data as Markdown files
under ``memory/session/<session_id>/*.md``.

Entries expire after a configurable TTL (default 1 hour).
"""

from __future__ import annotations

import json
import os
import time
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from .models import MemoryEntry, MemoryTier


class SessionMemory:
    """In-memory + file-backed short-term memory store.

    Entries are held in an in-process cache (one insertion-ordered
    ``OrderedDict`` per session) and mirrored to Markdown files under
    ``base_dir/<session_id>/<entry_id>.md`` so sessions survive restarts.
    Each session holds at most ``MAX_ENTRIES_PER_SESSION`` entries; the
    oldest entry (cache *and* file) is evicted when the cap is reached.
    """

    DEFAULT_TTL = 3600          # seconds – 1 hour
    MAX_ENTRIES_PER_SESSION = 50

    def __init__(self, base_dir: str = "memory/session", ttl: int = DEFAULT_TTL):
        """Create the store rooted at *base_dir* and preload existing files.

        Args:
            base_dir: directory that holds one subdirectory per session.
            ttl: entry lifetime in seconds, enforced by :meth:`gc`.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = ttl
        # session_id β†’ OrderedDict[entry_id, MemoryEntry]
        self._cache: Dict[str, OrderedDict[str, MemoryEntry]] = {}
        self._load_from_disk()

    # ── time helpers ─────────────────────────────────────────

    @staticmethod
    def _now_iso() -> str:
        """Current UTC time as a timezone-aware ISO-8601 string."""
        # datetime.utcnow() is deprecated and produces *naive* datetimes,
        # which datetime.timestamp() then misreads as local time.
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _created_timestamp(iso: str) -> float:
        """Convert a stored ISO timestamp to a POSIX timestamp.

        Naive timestamps written by older versions (via ``utcnow()``) are
        interpreted as UTC. Previously they fell through to
        ``datetime.timestamp()``'s local-time interpretation, which skewed
        TTL checks by the machine's UTC offset.
        """
        dt = datetime.fromisoformat(iso)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.timestamp()

    # ── CRUD ─────────────────────────────────────────────────

    def create(self, entry: MemoryEntry, session_id: str = "default") -> MemoryEntry:
        """Add a new entry to a session, evicting the oldest if full."""
        entry.tier = MemoryTier.SESSION
        entry.session_id = session_id
        entry.created_at = self._now_iso()
        entry.updated_at = entry.created_at

        bucket = self._cache.setdefault(session_id, OrderedDict())
        # Evict oldest when full — including its file, otherwise the
        # evicted entry would be reloaded on the next start and the
        # session would exceed MAX_ENTRIES_PER_SESSION.
        while len(bucket) >= self.MAX_ENTRIES_PER_SESSION:
            evicted_id, _ = bucket.popitem(last=False)
            stale = self._entry_path(evicted_id, session_id)
            if stale.exists():
                stale.unlink()
        bucket[entry.id] = entry
        self._persist(entry, session_id)
        return entry

    def read(self, entry_id: str, session_id: str = "default") -> Optional[MemoryEntry]:
        """Retrieve a single entry by ID, or ``None`` if absent.

        A successful read bumps ``access_count`` and ``updated_at`` and is
        persisted, so access statistics survive restarts.
        """
        entry = self._cache.get(session_id, {}).get(entry_id)
        if entry:
            entry.access_count += 1
            entry.updated_at = self._now_iso()
            self._persist(entry, session_id)
        return entry

    def update(self, entry_id: str, session_id: str = "default", **kwargs) -> Optional[MemoryEntry]:
        """Update mutable fields on an existing entry.

        Unknown attributes and the immutable fields ``id``, ``tier`` and
        ``created_at`` are silently ignored. Returns the updated entry, or
        ``None`` if it does not exist.
        """
        entry = self._cache.get(session_id, {}).get(entry_id)
        if not entry:
            return None
        for k, v in kwargs.items():
            if hasattr(entry, k) and k not in ("id", "tier", "created_at"):
                setattr(entry, k, v)
        entry.updated_at = self._now_iso()
        self._persist(entry, session_id)
        return entry

    def delete(self, entry_id: str, session_id: str = "default") -> bool:
        """Remove an entry from cache and disk. Returns ``True`` if it existed."""
        bucket = self._cache.get(session_id, {})
        if entry_id not in bucket:
            return False
        del bucket[entry_id]
        path = self._entry_path(entry_id, session_id)
        if path.exists():
            path.unlink()
        return True

    def list_entries(self, session_id: str = "default", tag: Optional[str] = None) -> List[MemoryEntry]:
        """List all entries in a session, optionally filtered by tag."""
        bucket = self._cache.get(session_id, OrderedDict())
        entries = list(bucket.values())
        if tag:
            entries = [e for e in entries if tag in e.tags]
        return entries

    def list_sessions(self) -> List[str]:
        """List all known session IDs."""
        return list(self._cache.keys())

    def clear_session(self, session_id: str = "default") -> int:
        """Drop all entries in a session (cache and files). Returns count deleted."""
        bucket = self._cache.pop(session_id, OrderedDict())
        count = len(bucket)
        session_dir = self.base_dir / session_id
        if session_dir.exists():
            for f in session_dir.glob("*.md"):
                f.unlink()
            try:
                # Fails (and is ignored) if non-.md files remain inside.
                session_dir.rmdir()
            except OSError:
                pass
        return count

    def gc(self) -> int:
        """Garbage-collect expired entries across all sessions.

        Returns the number of entries removed. Uses UTC-consistent
        timestamps: previously, naive ``created_at`` strings were
        interpreted as local time, skewing expiry by the UTC offset.
        """
        now = time.time()
        removed = 0
        # Iterate over snapshots: delete() mutates the buckets.
        for sid in list(self._cache.keys()):
            for eid in list(self._cache[sid].keys()):
                entry = self._cache[sid][eid]
                if now - self._created_timestamp(entry.created_at) > self.ttl:
                    self.delete(eid, sid)
                    removed += 1
        return removed

    # ── search helpers ───────────────────────────────────────

    def search(self, query: str, session_id: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]:
        """Case-insensitive substring search over title, content and tags.

        Searches one session when *session_id* is given, otherwise all
        sessions. Stops after *limit* matches.
        """
        query_lower = query.lower()
        results: List[MemoryEntry] = []

        sessions = [session_id] if session_id else list(self._cache.keys())
        for sid in sessions:
            for entry in self._cache.get(sid, {}).values():
                text = f"{entry.title} {entry.content} {' '.join(entry.tags)}".lower()
                if query_lower in text:
                    results.append(entry)
                    if len(results) >= limit:
                        return results
        return results

    # ── persistence ──────────────────────────────────────────

    def _entry_path(self, entry_id: str, session_id: str) -> Path:
        """Path of the Markdown file backing an entry (no side effects).

        Directory creation now happens only in :meth:`_persist`, so reads
        and deletes no longer create empty session directories.
        """
        return self.base_dir / session_id / f"{entry_id}.md"

    def _persist(self, entry: MemoryEntry, session_id: str):
        """Write an entry's Markdown rendering to disk, creating its dir."""
        path = self._entry_path(entry.id, session_id)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(entry.to_markdown(), encoding="utf-8")

    def _load_from_disk(self):
        """Bootstrap the cache from existing ``.md`` files.

        Files are loaded in sorted filename order; corrupt files are
        skipped deliberately (best-effort recovery, not a hard failure).
        """
        if not self.base_dir.exists():
            return
        for session_dir in self.base_dir.iterdir():
            if not session_dir.is_dir():
                continue
            sid = session_dir.name
            bucket = self._cache.setdefault(sid, OrderedDict())
            for md_file in sorted(session_dir.glob("*.md")):
                try:
                    text = md_file.read_text(encoding="utf-8")
                    entry = MemoryEntry.from_markdown(text)
                    entry.session_id = sid
                    entry.tier = MemoryTier.SESSION
                    bucket[entry.id] = entry
                except Exception:
                    pass  # skip corrupt files