Spaces:
Running
Running
| """ | |
| Short-Term / Session Memory | |
| ============================ | |
| Stores conversation context and ephemeral data as Markdown files | |
| under memory/session/<session_id>/*.md | |
| Entries expire after a configurable TTL (default 1 hour). | |
| """ | |
from __future__ import annotations

import json
import os
import time
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from .models import MemoryEntry, MemoryTier
class SessionMemory:
    """In-memory + file-backed short-term memory store.

    Entries are kept per-session in an ``OrderedDict`` (insertion order
    doubles as age for eviction) and mirrored to Markdown files under
    ``base_dir/<session_id>/<entry_id>.md`` so the cache can be rebuilt
    after a restart.
    """

    DEFAULT_TTL = 3600  # seconds -- 1 hour
    MAX_ENTRIES_PER_SESSION = 50

    def __init__(self, base_dir: str = "memory/session", ttl: int = DEFAULT_TTL):
        """Create the store rooted at *base_dir*; entries expire after *ttl* seconds."""
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = ttl
        # session_id -> OrderedDict[entry_id, MemoryEntry]
        self._cache: Dict[str, OrderedDict[str, MemoryEntry]] = {}
        self._load_from_disk()

    # -- time helpers ----------------------------------------------------

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as a timezone-aware ISO-8601 string."""
        # datetime.utcnow() is deprecated and produces naive values that
        # fromisoformat() would later misread as local time; use aware UTC.
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _to_timestamp(iso: str) -> float:
        """Convert a stored ISO timestamp to a POSIX timestamp.

        Entries written by older versions used ``datetime.utcnow()`` and
        carry no UTC offset; those naive values are interpreted as UTC here.
        (Parsing them with ``fromisoformat().timestamp()`` alone would treat
        them as *local* time, skewing TTL checks by the machine's UTC offset.)
        """
        dt = datetime.fromisoformat(iso)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.timestamp()

    # -- CRUD ------------------------------------------------------------

    def create(self, entry: MemoryEntry, session_id: str = "default") -> MemoryEntry:
        """Add a new entry to *session_id*, evicting the oldest when full."""
        entry.tier = MemoryTier.SESSION
        entry.session_id = session_id
        entry.created_at = self._now_iso()
        entry.updated_at = entry.created_at
        bucket = self._cache.setdefault(session_id, OrderedDict())
        # Evict oldest entries at capacity -- but not when this call merely
        # replaces an existing id, which does not grow the bucket.
        while len(bucket) >= self.MAX_ENTRIES_PER_SESSION and entry.id not in bucket:
            bucket.popitem(last=False)
        bucket[entry.id] = entry
        self._persist(entry, session_id)
        return entry

    def read(self, entry_id: str, session_id: str = "default") -> Optional[MemoryEntry]:
        """Retrieve a single entry by ID; bumps its access counter."""
        entry = self._cache.get(session_id, {}).get(entry_id)
        if entry:
            entry.access_count += 1
            entry.updated_at = self._now_iso()
            self._persist(entry, session_id)  # keep the on-disk copy in sync
        return entry

    def update(self, entry_id: str, session_id: str = "default", **kwargs) -> Optional[MemoryEntry]:
        """Update fields on an existing entry.

        ``id``, ``tier`` and ``created_at`` are immutable; unknown field
        names are silently ignored. Returns the entry, or None if missing.
        """
        entry = self._cache.get(session_id, {}).get(entry_id)
        if not entry:
            return None
        for key, value in kwargs.items():
            if hasattr(entry, key) and key not in ("id", "tier", "created_at"):
                setattr(entry, key, value)
        entry.updated_at = self._now_iso()
        self._persist(entry, session_id)
        return entry

    def delete(self, entry_id: str, session_id: str = "default") -> bool:
        """Remove an entry from cache and disk. Returns True if it existed."""
        bucket = self._cache.get(session_id, {})
        if entry_id not in bucket:
            return False
        del bucket[entry_id]
        path = self._entry_path(entry_id, session_id)
        if path.exists():
            path.unlink()
        return True

    def list_entries(self, session_id: str = "default", tag: Optional[str] = None) -> List[MemoryEntry]:
        """List all entries in a session, optionally filtered by *tag*."""
        entries = list(self._cache.get(session_id, OrderedDict()).values())
        if tag:
            entries = [e for e in entries if tag in e.tags]
        return entries

    def list_sessions(self) -> List[str]:
        """List all known session IDs."""
        return list(self._cache.keys())

    def clear_session(self, session_id: str = "default") -> int:
        """Drop all entries in a session (cache + files). Returns count deleted."""
        bucket = self._cache.pop(session_id, OrderedDict())
        count = len(bucket)
        session_dir = self.base_dir / session_id
        if session_dir.exists():
            for f in session_dir.glob("*.md"):
                f.unlink()
            try:
                session_dir.rmdir()
            except OSError:
                pass  # dir not empty (stray non-.md files) -- leave it
        return count

    def gc(self) -> int:
        """Garbage-collect expired entries across all sessions.

        Expiry compares the entry's UTC creation time against the current
        POSIX time. Returns the number of entries removed.
        """
        now = time.time()
        removed = 0
        for sid in list(self._cache.keys()):
            # snapshot items: delete() mutates the bucket mid-iteration
            for eid, entry in list(self._cache[sid].items()):
                if now - self._to_timestamp(entry.created_at) > self.ttl:
                    self.delete(eid, sid)
                    removed += 1
        return removed

    # -- search helpers --------------------------------------------------

    def search(self, query: str, session_id: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]:
        """Case-insensitive keyword search over title, content and tags.

        Searches one session when *session_id* is given, otherwise all
        sessions; stops after *limit* matches.
        """
        query_lower = query.lower()
        results: List[MemoryEntry] = []
        sessions = [session_id] if session_id else list(self._cache.keys())
        for sid in sessions:
            for entry in self._cache.get(sid, {}).values():
                text = f"{entry.title} {entry.content} {' '.join(entry.tags)}".lower()
                if query_lower in text:
                    results.append(entry)
                    if len(results) >= limit:
                        return results
        return results

    # -- persistence -----------------------------------------------------

    def _entry_path(self, entry_id: str, session_id: str) -> Path:
        """Path of the Markdown file backing *entry_id* (creates the session dir)."""
        d = self.base_dir / session_id
        d.mkdir(parents=True, exist_ok=True)
        return d / f"{entry_id}.md"

    def _persist(self, entry: MemoryEntry, session_id: str) -> None:
        """Write the entry's Markdown serialization to its backing file."""
        path = self._entry_path(entry.id, session_id)
        path.write_text(entry.to_markdown(), encoding="utf-8")

    def _load_from_disk(self) -> None:
        """Bootstrap the cache from existing .md files under base_dir."""
        if not self.base_dir.exists():
            return
        for session_dir in self.base_dir.iterdir():
            if not session_dir.is_dir():
                continue
            sid = session_dir.name
            bucket = self._cache.setdefault(sid, OrderedDict())
            for md_file in sorted(session_dir.glob("*.md")):
                try:
                    text = md_file.read_text(encoding="utf-8")
                    entry = MemoryEntry.from_markdown(text)
                    entry.session_id = sid
                    entry.tier = MemoryTier.SESSION
                    bucket[entry.id] = entry
                except Exception:
                    pass  # best-effort bootstrap: skip corrupt files