"""
Short-Term / Session Memory
============================
Stores conversation context and ephemeral data as Markdown files
under memory/session/<session_id>/*.md
Entries expire after a configurable TTL (default 1 hour).
"""
from __future__ import annotations

import json
import os
import time
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from .models import MemoryEntry, MemoryTier
class SessionMemory:
    """In-memory + file-backed short-term memory store.

    Entries live in an in-process cache (one insertion-ordered
    ``OrderedDict`` per session, so the oldest entry is evicted first)
    and are mirrored to disk as
    ``<base_dir>/<session_id>/<entry_id>.md``.
    """

    DEFAULT_TTL = 3600  # seconds — 1 hour
    MAX_ENTRIES_PER_SESSION = 50

    def __init__(self, base_dir: str = "memory/session", ttl: int = DEFAULT_TTL):
        """Create the store, ensure *base_dir* exists, and warm the cache.

        Args:
            base_dir: Root directory holding one subdirectory per session.
            ttl: Entry time-to-live in seconds, enforced by :meth:`gc`.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = ttl
        # session_id -> OrderedDict[entry_id, MemoryEntry]
        self._cache: Dict[str, OrderedDict[str, MemoryEntry]] = {}
        self._load_from_disk()

    # ── time helper ──────────────────────────────────────────
    @staticmethod
    def _now_iso() -> str:
        """Current UTC time as a timezone-aware ISO-8601 string.

        The original code used the (now deprecated) ``datetime.utcnow()``,
        which produces *naive* strings; ``gc()`` then parsed them with
        ``fromisoformat().timestamp()``, which interprets naive values as
        *local* time — skewing the TTL check by the UTC offset. Aware
        timestamps round-trip correctly.
        """
        return datetime.now(timezone.utc).isoformat()

    # ── CRUD ─────────────────────────────────────────────────
    def create(self, entry: MemoryEntry, session_id: str = "default") -> MemoryEntry:
        """Add a new entry to a session, evicting the oldest if full.

        Stamps the entry with tier/session/timestamps, then persists it.
        Returns the (mutated) entry for convenience.
        """
        entry.tier = MemoryTier.SESSION
        entry.session_id = session_id
        entry.created_at = self._now_iso()
        entry.updated_at = entry.created_at
        bucket = self._cache.setdefault(session_id, OrderedDict())
        # Evict oldest entries when the session is full.  BUGFIX: also
        # remove the evicted entry's file — otherwise it resurrects from
        # disk on the next _load_from_disk().
        while len(bucket) >= self.MAX_ENTRIES_PER_SESSION:
            evicted_id, _ = bucket.popitem(last=False)
            evicted_path = self._entry_path(evicted_id, session_id)
            if evicted_path.exists():
                evicted_path.unlink()
        bucket[entry.id] = entry
        self._persist(entry, session_id)
        return entry

    def read(self, entry_id: str, session_id: str = "default") -> Optional[MemoryEntry]:
        """Retrieve a single entry by ID, or ``None`` if absent.

        NOTE: reads are not pure — each hit bumps ``access_count`` and
        touches ``updated_at``, and re-persists the entry (deliberate
        access tracking, preserved from the original behavior).
        """
        bucket = self._cache.get(session_id, {})
        entry = bucket.get(entry_id)
        if entry:
            entry.access_count += 1
            entry.updated_at = self._now_iso()
            self._persist(entry, session_id)
        return entry

    def update(self, entry_id: str, session_id: str = "default", **kwargs) -> Optional[MemoryEntry]:
        """Update fields on an existing entry.

        Only attributes that already exist on the entry are set.
        Identity/placement fields are protected — BUGFIX: ``session_id``
        is now protected too; rewriting it via kwargs would desynchronize
        the entry from the cache bucket and disk directory it lives in.

        Returns:
            The updated entry, or ``None`` if it does not exist.
        """
        bucket = self._cache.get(session_id, {})
        entry = bucket.get(entry_id)
        if not entry:
            return None
        for k, v in kwargs.items():
            if hasattr(entry, k) and k not in ("id", "tier", "created_at", "session_id"):
                setattr(entry, k, v)
        entry.updated_at = self._now_iso()
        self._persist(entry, session_id)
        return entry

    def delete(self, entry_id: str, session_id: str = "default") -> bool:
        """Remove an entry from cache and disk. Returns True if it existed."""
        bucket = self._cache.get(session_id, {})
        if entry_id not in bucket:
            return False
        del bucket[entry_id]
        path = self._entry_path(entry_id, session_id)
        if path.exists():
            path.unlink()
        return True

    def list_entries(self, session_id: str = "default", tag: Optional[str] = None) -> List[MemoryEntry]:
        """List all entries in a session, optionally filtered by tag."""
        bucket = self._cache.get(session_id, OrderedDict())
        entries = list(bucket.values())
        if tag:
            entries = [e for e in entries if tag in e.tags]
        return entries

    def list_sessions(self) -> List[str]:
        """List all known session IDs."""
        return list(self._cache.keys())

    def clear_session(self, session_id: str = "default") -> int:
        """Drop all entries in a session (cache + disk). Returns count deleted."""
        bucket = self._cache.pop(session_id, OrderedDict())
        count = len(bucket)
        session_dir = self.base_dir / session_id
        if session_dir.exists():
            for f in session_dir.glob("*.md"):
                f.unlink()
            try:
                session_dir.rmdir()
            except OSError:
                # Directory not empty (e.g. non-.md files) — leave it.
                pass
        return count

    def gc(self) -> int:
        """Garbage-collect expired entries across all sessions.

        Returns:
            Number of entries removed.
        """
        now = time.time()
        removed = 0
        for sid in list(self._cache.keys()):
            for eid in list(self._cache[sid].keys()):
                entry = self._cache[sid][eid]
                try:
                    created = datetime.fromisoformat(entry.created_at)
                except (TypeError, ValueError):
                    # Unparseable timestamp: treat as expired rather than
                    # letting one corrupt entry crash the whole sweep.
                    self.delete(eid, sid)
                    removed += 1
                    continue
                # BUGFIX: legacy entries were written with naive utcnow();
                # interpret naive values as UTC.  The original code let
                # .timestamp() assume *local* time, skewing the TTL check
                # by the machine's UTC offset.
                if created.tzinfo is None:
                    created = created.replace(tzinfo=timezone.utc)
                if now - created.timestamp() > self.ttl:
                    self.delete(eid, sid)
                    removed += 1
        return removed

    # ── search helpers ───────────────────────────────────────
    def search(self, query: str, session_id: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]:
        """Case-insensitive keyword search across session memories.

        Matches *query* as a substring of title + content + tags.
        Searches one session when *session_id* is given, else all.
        Stops after *limit* hits.
        """
        query_lower = query.lower()
        results: List[MemoryEntry] = []
        sessions = [session_id] if session_id else list(self._cache.keys())
        for sid in sessions:
            for entry in self._cache.get(sid, {}).values():
                text = f"{entry.title} {entry.content} {' '.join(entry.tags)}".lower()
                if query_lower in text:
                    results.append(entry)
                    if len(results) >= limit:
                        return results
        return results

    # ── persistence ──────────────────────────────────────────
    def _entry_path(self, entry_id: str, session_id: str) -> Path:
        """Path of the .md file backing an entry (creates the session dir)."""
        d = self.base_dir / session_id
        d.mkdir(parents=True, exist_ok=True)
        return d / f"{entry_id}.md"

    def _persist(self, entry: MemoryEntry, session_id: str):
        """Write one entry's Markdown serialization to its backing file."""
        path = self._entry_path(entry.id, session_id)
        path.write_text(entry.to_markdown(), encoding="utf-8")

    def _load_from_disk(self):
        """Bootstrap the cache from existing .md files under base_dir."""
        if not self.base_dir.exists():
            return
        for session_dir in self.base_dir.iterdir():
            if not session_dir.is_dir():
                continue
            sid = session_dir.name
            bucket = self._cache.setdefault(sid, OrderedDict())
            # Sorted for a deterministic eviction order after restart.
            for md_file in sorted(session_dir.glob("*.md")):
                try:
                    text = md_file.read_text(encoding="utf-8")
                    entry = MemoryEntry.from_markdown(text)
                    entry.session_id = sid
                    entry.tier = MemoryTier.SESSION
                    bucket[entry.id] = entry
                except Exception:
                    pass  # best-effort load: skip corrupt files