from __future__ import annotations

import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional

# Short-term memory configuration
# -------------------------------
# These environment variables let you tune behavior without code changes:
# - MCP_MEMORY_MAX_ITEMS: max number of tool outputs to keep per session (default: 10)
# - MCP_MEMORY_TTL_SECONDS: how long entries live before expiring (default: 900 = 15 minutes)
# Both values are read once, at import time.
DEFAULT_MAX_ITEMS = int(os.getenv("MCP_MEMORY_MAX_ITEMS", "10"))
DEFAULT_TTL_SECONDS = int(os.getenv("MCP_MEMORY_TTL_SECONDS", "900"))


@dataclass
class MemoryEntry:
    """A single remembered tool output."""

    # Creation time as returned by `_now()` (i.e. `time.time()` seconds).
    ts: float
    # Name of the tool that produced `output`.
    tool_name: str
    # The tool's output, stored by reference (never copied).
    output: Any


# NOTE: For safety, this store is intentionally **not** keyed by tenant.
# It is keyed only by a logical session identifier (e.g. chat session ID).
_MEMORY: Dict[str, List[MemoryEntry]] = {}


def _now() -> float:
    """Return the current wall-clock time in seconds (`time.time()`)."""
    return time.time()


def extract_session_id(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Extract a logical session identifier from the payload.

    Supported keys (first match wins):
    - "session_id"
    - "sessionId"
    - "conversation_id"
    - "conversationId"

    A key only matches when its value is a string that is non-empty after
    stripping surrounding whitespace; otherwise the next key is tried.

    Returns:
        Normalized (stripped) session_id string or None if not present.
    """
    for key in ("session_id", "sessionId", "conversation_id", "conversationId"):
        value = payload.get(key)
        if not isinstance(value, str):
            continue
        value = value.strip()
        if value:
            return value
    return None


def _prune_expired(entries: List[MemoryEntry], ttl_seconds: int) -> List[MemoryEntry]:
    """
    Return the entries whose timestamp is within `ttl_seconds` of now.

    An empty input list is returned as-is; otherwise a new list is built and
    the input list is left untouched.
    """
    if not entries:
        return entries
    cutoff = _now() - ttl_seconds
    return [e for e in entries if e.ts >= cutoff]


def add_entry(
    session_id: str,
    tool_name: str,
    output: Any,
    max_items: int = DEFAULT_MAX_ITEMS,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> None:
    """
    Store a new tool output in short-term memory for this session.

    - Keeps only the last `max_items` entries
    - Drops entries older than `ttl_seconds`

    Does nothing when `session_id` is empty. When `max_items` is zero or
    negative nothing can be retained, so the session's memory is dropped.
    """
    if not session_id:
        return

    # Guard against `entries[-0:]`, which is the *whole* list: without this
    # check a limit of 0 would silently disable the size bound.
    if max_items <= 0:
        _MEMORY.pop(session_id, None)
        return

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    entries.append(MemoryEntry(ts=_now(), tool_name=tool_name, output=output))

    # Enforce bounded size: keep the most recent entries
    if len(entries) > max_items:
        entries = entries[-max_items:]

    _MEMORY[session_id] = entries


def get_recent(
    session_id: str,
    limit: Optional[int] = None,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> List[Dict[str, Any]]:
    """
    Return recent, non-expired entries for this session, oldest first.

    Expired entries are pruned from the store as a side effect. If `limit`
    is a positive integer only the last `limit` entries are returned; when
    it is None or not positive, all non-expired entries are returned.

    Each entry is a dict:
        {"tool": str, "timestamp": float, "output": Any}
    """
    if not session_id:
        return []

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)

    if entries:
        _MEMORY[session_id] = entries  # write back pruned list
    else:
        # Do not leave an empty list behind: lookups for unknown or fully
        # expired sessions would otherwise grow the store without bound.
        _MEMORY.pop(session_id, None)

    if limit is not None and limit > 0:
        entries = entries[-limit:]

    return [
        {
            "tool": e.tool_name,
            "timestamp": e.ts,
            "output": e.output,
        }
        for e in entries
    ]


def clear_session(session_id: str) -> None:
    """
    Explicitly clear all short-term memory for a session.

    Useful when a chat session ends. Unknown session ids are ignored.
    """
    _MEMORY.pop(session_id, None)