nothingworry's picture
feat: Add short-term conversation memory with TTL for MCP tools
b13e570
raw
history blame
3.43 kB
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional
# Short-term memory configuration
# -------------------------------
# Both limits are read from the environment once, when this module is
# imported, so behavior can be tuned without code changes:
# - MCP_MEMORY_MAX_ITEMS: max number of tool outputs kept per session (default: 10)
# - MCP_MEMORY_TTL_SECONDS: seconds an entry lives before expiring (default: 900 = 15 minutes)
# A value that is not a valid integer makes int() raise ValueError at import.
DEFAULT_MAX_ITEMS = int(os.environ.get("MCP_MEMORY_MAX_ITEMS", "10"))
DEFAULT_TTL_SECONDS = int(os.environ.get("MCP_MEMORY_TTL_SECONDS", "900"))
@dataclass
class MemoryEntry:
    """A single remembered tool output belonging to one session."""

    # Wall-clock time the entry was recorded, in seconds since the epoch.
    ts: float
    # Name of the tool whose output is stored.
    tool_name: str
    # The tool output itself, kept exactly as it was passed in.
    output: Any
# Module-level, in-process store for short-term memory. Each key is a session
# identifier; each value is that session's entries, oldest first.
# NOTE: For safety, this store is intentionally **not** keyed by tenant.
# It is keyed only by a logical session identifier (e.g. chat session ID).
# NOTE(review): two tenants using the same session ID would share entries, so
# this presumably relies on session IDs being unique across tenants -- confirm
# against callers.
_MEMORY: Dict[str, List[MemoryEntry]] = {}
def _now() -> float:
return time.time()
def extract_session_id(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Find the logical session identifier carried by a payload.

    The following keys are checked in this order, and the first one that
    holds a usable value wins:
    - "session_id"
    - "sessionId"
    - "conversation_id"
    - "conversationId"

    A value is usable when it is a string that is non-empty after
    surrounding whitespace is stripped. Anything else is skipped and the
    next key is tried.

    Returns:
        The whitespace-stripped session id, or None if no key qualifies.
    """
    candidate_keys = ("session_id", "sessionId", "conversation_id", "conversationId")
    for name in candidate_keys:
        raw = payload.get(name)
        if not isinstance(raw, str):
            continue
        cleaned = raw.strip()
        if cleaned:
            return cleaned
    return None
def _prune_expired(entries: List[MemoryEntry], ttl_seconds: int) -> List[MemoryEntry]:
    """
    Return the entries that are still within their time-to-live.

    An entry survives when its timestamp is at or after
    `now - ttl_seconds`. A non-empty input always yields a new list and the
    input is left unmodified. An empty input is handed back as-is without
    consulting the clock.
    """
    if entries:
        oldest_allowed = _now() - ttl_seconds
        return [item for item in entries if item.ts >= oldest_allowed]
    return entries
def add_entry(
    session_id: str,
    tool_name: str,
    output: Any,
    max_items: int = DEFAULT_MAX_ITEMS,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> None:
    """
    Store a new tool output in short-term memory for this session.

    - Drops entries older than `ttl_seconds` before adding the new one
    - Keeps only the last `max_items` entries (the most recent ones)

    Args:
        session_id: Logical session identifier; a falsy value makes the
            call a no-op.
        tool_name: Name of the tool that produced `output`.
        output: The tool output to remember, stored as-is.
        max_items: Maximum number of entries retained for the session.
            A value of zero or less retains nothing: the new output is not
            stored and anything already held for the session is discarded.
        ttl_seconds: Age in seconds after which existing entries are dropped.
    """
    if not session_id:
        return

    if max_items <= 0:
        # `entries[-0:]` is the whole list and a negative bound slices from
        # the wrong end, so a non-positive limit must be handled explicitly
        # instead of falling through to the slice below.
        _MEMORY.pop(session_id, None)
        return

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    entries.append(MemoryEntry(ts=_now(), tool_name=tool_name, output=output))

    # Enforce bounded size: keep the most recent entries. When the list is
    # already within the bound the slice simply returns all of it.
    _MEMORY[session_id] = entries[-max_items:]
def get_recent(
    session_id: str,
    limit: Optional[int] = None,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> List[Dict[str, Any]]:
    """
    Return recent, non-expired entries for this session, oldest first.

    Expired entries are pruned from the store as a side effect. A session
    left with no live entries is removed from the store entirely, so reading
    unknown or fully expired sessions does not leave empty keys behind.

    Args:
        session_id: Logical session identifier; a falsy value returns [].
        limit: When a positive integer, only the last `limit` entries are
            returned. None, zero, or a negative value returns every live
            entry.
        ttl_seconds: Age in seconds after which entries count as expired.

    Returns:
        A list of dicts, one per entry:
        {"tool": str, "timestamp": float, "output": Any}
    """
    if not session_id:
        return []

    live = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    if live:
        _MEMORY[session_id] = live  # write back pruned list
    else:
        # Do not keep (or create) an empty bucket: otherwise every session id
        # that is ever queried would occupy a key in the store forever.
        _MEMORY.pop(session_id, None)

    if limit is not None and limit > 0:
        live = live[-limit:]

    return [
        {
            "tool": e.tool_name,
            "timestamp": e.ts,
            "output": e.output,
        }
        for e in live
    ]
def clear_session(session_id: str) -> None:
    """
    Forget everything remembered for a session.

    Intended for when a chat session ends. Clearing a session that has no
    stored memory is a no-op.
    """
    # pop() with a default removes the key if present and is silent otherwise.
    _MEMORY.pop(session_id, None)