""" Short-Term / Session Memory ============================ Stores conversation context and ephemeral data as Markdown files under memory/session//*.md Entries expire after a configurable TTL (default 1 hour). """ from __future__ import annotations import json import os import time from collections import OrderedDict from datetime import datetime from pathlib import Path from typing import Dict, List, Optional from .models import MemoryEntry, MemoryTier class SessionMemory: """In-memory + file-backed short-term memory store.""" DEFAULT_TTL = 3600 # seconds – 1 hour MAX_ENTRIES_PER_SESSION = 50 def __init__(self, base_dir: str = "memory/session", ttl: int = DEFAULT_TTL): self.base_dir = Path(base_dir) self.base_dir.mkdir(parents=True, exist_ok=True) self.ttl = ttl # session_id → OrderedDict[entry_id, MemoryEntry] self._cache: Dict[str, OrderedDict[str, MemoryEntry]] = {} self._load_from_disk() # ── CRUD ───────────────────────────────────────────────── def create(self, entry: MemoryEntry, session_id: str = "default") -> MemoryEntry: """Add a new entry to a session.""" entry.tier = MemoryTier.SESSION entry.session_id = session_id entry.created_at = datetime.utcnow().isoformat() entry.updated_at = entry.created_at bucket = self._cache.setdefault(session_id, OrderedDict()) # evict oldest when full while len(bucket) >= self.MAX_ENTRIES_PER_SESSION: bucket.popitem(last=False) bucket[entry.id] = entry self._persist(entry, session_id) return entry def read(self, entry_id: str, session_id: str = "default") -> Optional[MemoryEntry]: """Retrieve a single entry by ID.""" bucket = self._cache.get(session_id, {}) entry = bucket.get(entry_id) if entry: entry.access_count += 1 entry.updated_at = datetime.utcnow().isoformat() self._persist(entry, session_id) return entry def update(self, entry_id: str, session_id: str = "default", **kwargs) -> Optional[MemoryEntry]: """Update fields on an existing entry.""" bucket = self._cache.get(session_id, {}) entry = bucket.get(entry_id) if not entry: return None for k, v in kwargs.items(): if hasattr(entry, k) and k not in ("id", "tier", "created_at"): setattr(entry, k, v) entry.updated_at = datetime.utcnow().isoformat() self._persist(entry, session_id) return entry def delete(self, entry_id: str, session_id: str = "default") -> bool: """Remove an entry.""" bucket = self._cache.get(session_id, {}) if entry_id not in bucket: return False del bucket[entry_id] path = self._entry_path(entry_id, session_id) if path.exists(): path.unlink() return True def list_entries(self, session_id: str = "default", tag: Optional[str] = None) -> List[MemoryEntry]: """List all entries in a session, optionally filtered by tag.""" bucket = self._cache.get(session_id, OrderedDict()) entries = list(bucket.values()) if tag: entries = [e for e in entries if tag in e.tags] return entries def list_sessions(self) -> List[str]: """List all known session IDs.""" return list(self._cache.keys()) def clear_session(self, session_id: str = "default") -> int: """Drop all entries in a session. Returns count deleted.""" bucket = self._cache.pop(session_id, OrderedDict()) count = len(bucket) session_dir = self.base_dir / session_id if session_dir.exists(): for f in session_dir.glob("*.md"): f.unlink() try: session_dir.rmdir() except OSError: pass return count def gc(self) -> int: """Garbage-collect expired entries across all sessions. Returns count removed.""" now = time.time() removed = 0 for sid in list(self._cache.keys()): for eid in list(self._cache[sid].keys()): entry = self._cache[sid][eid] created_ts = datetime.fromisoformat(entry.created_at).timestamp() if now - created_ts > self.ttl: self.delete(eid, sid) removed += 1 return removed # ── search helpers ─────────────────────────────────────── def search(self, query: str, session_id: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]: """Simple keyword search across session memories.""" query_lower = query.lower() results: List[MemoryEntry] = [] sessions = [session_id] if session_id else list(self._cache.keys()) for sid in sessions: for entry in self._cache.get(sid, {}).values(): text = f"{entry.title} {entry.content} {' '.join(entry.tags)}".lower() if query_lower in text: results.append(entry) if len(results) >= limit: return results return results # ── persistence ────────────────────────────────────────── def _entry_path(self, entry_id: str, session_id: str) -> Path: d = self.base_dir / session_id d.mkdir(parents=True, exist_ok=True) return d / f"{entry_id}.md" def _persist(self, entry: MemoryEntry, session_id: str): path = self._entry_path(entry.id, session_id) path.write_text(entry.to_markdown(), encoding="utf-8") def _load_from_disk(self): """Bootstrap cache from existing .md files.""" if not self.base_dir.exists(): return for session_dir in self.base_dir.iterdir(): if not session_dir.is_dir(): continue sid = session_dir.name bucket = self._cache.setdefault(sid, OrderedDict()) for md_file in sorted(session_dir.glob("*.md")): try: text = md_file.read_text(encoding="utf-8") entry = MemoryEntry.from_markdown(text) entry.session_id = sid entry.tier = MemoryTier.SESSION bucket[entry.id] = entry except Exception: pass # skip corrupt files