Spaces:
Sleeping
Sleeping
File size: 3,432 Bytes
b13e570 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional
# Short-term memory configuration
# -------------------------------
# Both limits are read from the environment once, when this module is
# imported, so they can be tuned without code changes:
#   MCP_MEMORY_MAX_ITEMS   - max number of tool outputs kept per session
#                            (default: 10)
#   MCP_MEMORY_TTL_SECONDS - how long an entry lives before it expires
#                            (default: 900 = 15 minutes)
# A value that int() cannot parse raises ValueError at import time.
DEFAULT_MAX_ITEMS = int(os.environ.get("MCP_MEMORY_MAX_ITEMS", "10"))
DEFAULT_TTL_SECONDS = int(os.environ.get("MCP_MEMORY_TTL_SECONDS", "900"))
@dataclass
class MemoryEntry:
    """A single remembered tool output together with when it was stored."""

    # Time the entry was stored, in seconds since the epoch; compared
    # against the TTL cutoff when expired entries are pruned.
    ts: float
    # Name of the tool that produced `output`.
    tool_name: str
    # The tool's result, held as-is (it is not copied or serialized here).
    output: Any
# NOTE: For safety, this store is intentionally **not** keyed by tenant.
# It is keyed only by a logical session identifier (e.g. chat session ID).
# Each value is that session's list of MemoryEntry objects; new entries are
# appended at the end, so a list runs from oldest to newest.
# This is a plain module-level dict, so its contents exist only in the memory
# of the current process.
# NOTE(review): with no tenant component in the key, session IDs presumably
# need to be unique across tenants - confirm with callers.
_MEMORY: Dict[str, List[MemoryEntry]] = {}
def _now() -> float:
    """
    Return the current time as seconds since the epoch.

    Thin wrapper around time.time(); the module uses it both to stamp new
    entries and to compute the expiry cutoff.
    """
    return time.time()
def extract_session_id(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Extract a logical session identifier from the payload.

    Supported keys (first match wins):
    - "session_id"
    - "sessionId"
    - "conversation_id"
    - "conversationId"

    A key only matches when its value is a string that is non-empty after
    surrounding whitespace is stripped; non-string values and blank strings
    are skipped and the next key is tried.

    Args:
        payload: Mapping to search. None or an empty mapping yields None.

    Returns:
        Normalized (whitespace-stripped) session_id string, or None if not
        present.
    """
    # Guard against a missing/empty payload so callers do not have to check
    # before asking for a session id (None has no .get()).
    if not payload:
        return None

    for key in ("session_id", "sessionId", "conversation_id", "conversationId"):
        candidate = payload.get(key)
        if not isinstance(candidate, str):
            continue
        candidate = candidate.strip()
        if candidate:
            return candidate
    return None
def _prune_expired(entries: List[MemoryEntry], ttl_seconds: int) -> List[MemoryEntry]:
    """
    Return the entries that are still within their time-to-live.

    An entry survives when its timestamp is no older than `ttl_seconds`
    before the current time. A non-empty input produces a new list; an
    empty input is handed back unchanged (the same list object).
    """
    if entries:
        oldest_allowed = _now() - ttl_seconds
        return [item for item in entries if item.ts >= oldest_allowed]
    return entries
def add_entry(
    session_id: str,
    tool_name: str,
    output: Any,
    max_items: int = DEFAULT_MAX_ITEMS,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> None:
    """
    Store a new tool output in short-term memory for this session.

    - Keeps only the last `max_items` entries
    - Drops entries older than `ttl_seconds`

    Args:
        session_id: Logical session identifier; a falsy value is a no-op.
        tool_name: Name of the tool that produced `output`.
        output: The tool result to remember (stored as-is).
        max_items: Maximum number of entries retained for the session.
            Zero or a negative value means nothing is retained: the new
            output is not stored and the session's existing entries are
            dropped.
        ttl_seconds: Entries older than this many seconds are discarded
            before the new one is added.
    """
    if not session_id:
        return

    if max_items <= 0:
        # A non-positive cap cannot be expressed with the slice below:
        # entries[-0:] is the *whole* list, and a negative cap would trim
        # from the wrong end. Retain nothing instead.
        _MEMORY.pop(session_id, None)
        return

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    entries.append(MemoryEntry(ts=_now(), tool_name=tool_name, output=output))

    # Enforce bounded size: keep the most recent entries. Slicing a list
    # shorter than max_items simply yields all of it.
    _MEMORY[session_id] = entries[-max_items:]
def get_recent(
    session_id: str,
    limit: Optional[int] = None,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> List[Dict[str, Any]]:
    """
    Return recent, non-expired entries for this session.

    Reading also prunes the store: expired entries are dropped, and a
    session left with no live entries is removed entirely, so looking up
    unknown or fully expired sessions does not leave empty lists behind.

    Args:
        session_id: Logical session identifier; a falsy value yields [].
        limit: When a positive integer, only the last `limit` entries are
            returned. None, zero, or a negative value applies no limit.
        ttl_seconds: Entries older than this many seconds count as expired.

    Returns:
        A list of dicts, in stored order, each shaped like:
        {"tool": str, "timestamp": float, "output": Any}
    """
    if not session_id:
        return []

    live = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    if live:
        _MEMORY[session_id] = live  # write back pruned list
    else:
        # Do not create or keep a key for a session with nothing in it;
        # otherwise every queried session id would grow the store forever.
        _MEMORY.pop(session_id, None)

    if limit is not None and limit > 0:
        live = live[-limit:]

    return [
        {
            "tool": entry.tool_name,
            "timestamp": entry.ts,
            "output": entry.output,
        }
        for entry in live
    ]
def clear_session(session_id: str) -> None:
    """
    Explicitly clear all short-term memory for a session.

    Useful when a chat session ends. Clearing a session that has no stored
    entries is a no-op.
    """
    # pop() with a default removes the key if present and is silent otherwise.
    _MEMORY.pop(session_id, None)
|