# agent-memory / memory / session.py
"""
Short-Term / Session Memory
============================
Stores conversation context and ephemeral data as Markdown files
under memory/session/<session_id>/*.md
Entries expire after a configurable TTL (default 1 hour).
"""
from __future__ import annotations

import json
import os
import time
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from .models import MemoryEntry, MemoryTier
class SessionMemory:
    """In-memory + file-backed short-term memory store.

    Entries live in a per-session ``OrderedDict`` cache and are mirrored to
    Markdown files under ``<base_dir>/<session_id>/<entry_id>.md`` so a
    restart can rebuild the cache.  Each session is capped at
    ``MAX_ENTRIES_PER_SESSION`` entries (oldest evicted first), and entries
    older than ``ttl`` seconds are reclaimed by :meth:`gc`.
    """

    DEFAULT_TTL = 3600  # seconds – 1 hour
    MAX_ENTRIES_PER_SESSION = 50

    def __init__(self, base_dir: str = "memory/session", ttl: int = DEFAULT_TTL):
        """Create the store rooted at *base_dir* and load any persisted entries.

        Args:
            base_dir: Directory holding one sub-directory per session.
            ttl: Seconds an entry may live before :meth:`gc` removes it.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = ttl
        # session_id → OrderedDict[entry_id, MemoryEntry]; insertion order
        # doubles as age order, making FIFO eviction a popitem(last=False).
        self._cache: Dict[str, OrderedDict[str, MemoryEntry]] = {}
        self._load_from_disk()

    # ── time helper ──────────────────────────────────────────
    @staticmethod
    def _utcnow_iso() -> str:
        """Current UTC time as a *naive* ISO-8601 string.

        Produces the exact format the deprecated ``datetime.utcnow()`` did
        (no ``+00:00`` suffix), so previously persisted files and any
        external consumers of the timestamp strings stay compatible.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    # ── CRUD ─────────────────────────────────────────────────
    def create(self, entry: MemoryEntry, session_id: str = "default") -> MemoryEntry:
        """Add *entry* to *session_id*, evicting the oldest entries if full."""
        entry.tier = MemoryTier.SESSION
        entry.session_id = session_id
        entry.created_at = self._utcnow_iso()
        entry.updated_at = entry.created_at
        bucket = self._cache.setdefault(session_id, OrderedDict())
        # Evict oldest entries until there is room for the new one.
        while len(bucket) >= self.MAX_ENTRIES_PER_SESSION:
            evicted_id, _ = bucket.popitem(last=False)
            # Keep disk in sync with the cache: without this, evicted
            # entries resurface (and overflow the cap) on the next
            # _load_from_disk().
            stale = self._entry_path(evicted_id, session_id)
            if stale.exists():
                stale.unlink()
        bucket[entry.id] = entry
        self._persist(entry, session_id)
        return entry

    def read(self, entry_id: str, session_id: str = "default") -> Optional[MemoryEntry]:
        """Retrieve a single entry by ID, bumping its access bookkeeping.

        Returns None when the entry (or session) is unknown.
        """
        bucket = self._cache.get(session_id, {})
        entry = bucket.get(entry_id)
        if entry:
            entry.access_count += 1
            entry.updated_at = self._utcnow_iso()
            self._persist(entry, session_id)
        return entry

    def update(self, entry_id: str, session_id: str = "default", **kwargs) -> Optional[MemoryEntry]:
        """Update fields on an existing entry.

        Only attributes that already exist on the entry are set; ``id``,
        ``tier`` and ``created_at`` are treated as immutable.  Returns the
        updated entry, or None when it does not exist.
        """
        bucket = self._cache.get(session_id, {})
        entry = bucket.get(entry_id)
        if not entry:
            return None
        for key, value in kwargs.items():
            if hasattr(entry, key) and key not in ("id", "tier", "created_at"):
                setattr(entry, key, value)
        entry.updated_at = self._utcnow_iso()
        self._persist(entry, session_id)
        return entry

    def delete(self, entry_id: str, session_id: str = "default") -> bool:
        """Remove an entry from cache and disk. Returns True if it existed."""
        bucket = self._cache.get(session_id, {})
        if entry_id not in bucket:
            return False
        del bucket[entry_id]
        path = self._entry_path(entry_id, session_id)
        if path.exists():
            path.unlink()
        return True

    def list_entries(self, session_id: str = "default", tag: Optional[str] = None) -> List[MemoryEntry]:
        """List all entries in a session, optionally filtered by tag."""
        bucket = self._cache.get(session_id, OrderedDict())
        entries = list(bucket.values())
        if tag:
            entries = [e for e in entries if tag in e.tags]
        return entries

    def list_sessions(self) -> List[str]:
        """List all known session IDs."""
        return list(self._cache.keys())

    def clear_session(self, session_id: str = "default") -> int:
        """Drop all entries in a session (cache + disk). Returns count deleted."""
        bucket = self._cache.pop(session_id, OrderedDict())
        count = len(bucket)
        session_dir = self.base_dir / session_id
        if session_dir.exists():
            for f in session_dir.glob("*.md"):
                f.unlink()
            try:
                session_dir.rmdir()
            except OSError:
                pass  # directory not empty (non-.md files) — leave it
        return count

    def gc(self) -> int:
        """Garbage-collect expired entries across all sessions.

        Returns the number of entries removed.

        Stored timestamps are naive UTC strings; they are pinned to UTC
        before conversion to an epoch value, because ``timestamp()`` on a
        naive datetime assumes *local* time and would skew the TTL check
        by the machine's UTC offset.
        """
        now = time.time()
        removed = 0
        for sid in list(self._cache.keys()):
            for eid in list(self._cache[sid].keys()):
                entry = self._cache[sid][eid]
                try:
                    created = datetime.fromisoformat(entry.created_at)
                except (TypeError, ValueError):
                    continue  # unparseable timestamp — keep the entry
                if created.tzinfo is None:
                    created = created.replace(tzinfo=timezone.utc)
                if now - created.timestamp() > self.ttl:
                    self.delete(eid, sid)
                    removed += 1
        return removed

    # ── search helpers ───────────────────────────────────────
    def search(self, query: str, session_id: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]:
        """Simple case-insensitive keyword search across session memories.

        Matches *query* as a substring of title, content or tags; scoped
        to *session_id* when given, otherwise all sessions.  Returns at
        most *limit* results.
        """
        query_lower = query.lower()
        results: List[MemoryEntry] = []
        sessions = [session_id] if session_id else list(self._cache.keys())
        for sid in sessions:
            for entry in self._cache.get(sid, {}).values():
                text = f"{entry.title} {entry.content} {' '.join(entry.tags)}".lower()
                if query_lower in text:
                    results.append(entry)
                    if len(results) >= limit:
                        return results
        return results

    # ── persistence ──────────────────────────────────────────
    def _entry_path(self, entry_id: str, session_id: str) -> Path:
        """Path of the Markdown file backing *entry_id* (creates the dir)."""
        d = self.base_dir / session_id
        d.mkdir(parents=True, exist_ok=True)
        return d / f"{entry_id}.md"

    def _persist(self, entry: MemoryEntry, session_id: str):
        """Write the entry's Markdown rendering to its backing file."""
        path = self._entry_path(entry.id, session_id)
        path.write_text(entry.to_markdown(), encoding="utf-8")

    def _load_from_disk(self):
        """Bootstrap cache from existing .md files (best effort)."""
        if not self.base_dir.exists():
            return
        for session_dir in self.base_dir.iterdir():
            if not session_dir.is_dir():
                continue
            sid = session_dir.name
            bucket = self._cache.setdefault(sid, OrderedDict())
            # sorted() gives deterministic (lexicographic) insertion order,
            # which stands in for age order after a restart.
            for md_file in sorted(session_dir.glob("*.md")):
                try:
                    text = md_file.read_text(encoding="utf-8")
                    entry = MemoryEntry.from_markdown(text)
                    entry.session_id = sid
                    entry.tier = MemoryTier.SESSION
                    bucket[entry.id] = entry
                except Exception:
                    pass  # skip corrupt files