File size: 3,432 Bytes
b13e570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional

# Short-term memory configuration
# -------------------------------
# These environment variables let you tune behavior without code changes:
# - MCP_MEMORY_MAX_ITEMS: max number of tool outputs to keep per session (default: 10)
# - MCP_MEMORY_TTL_SECONDS: how long entries live before expiring (default: 900 = 15 minutes)

DEFAULT_MAX_ITEMS = int(os.getenv("MCP_MEMORY_MAX_ITEMS", "10"))
DEFAULT_TTL_SECONDS = int(os.getenv("MCP_MEMORY_TTL_SECONDS", "900"))


@dataclass
class MemoryEntry:
    """A single remembered tool result belonging to one session.

    Instances are appended to a session's list in the module store and are
    discarded once ``ts`` falls outside the configured TTL window.
    """

    # Seconds since the epoch at which the entry was recorded.
    ts: float
    # Name of the tool whose output this is.
    tool_name: str
    # The tool's result, held by reference (not copied).
    output: Any


# NOTE: For safety, this store is intentionally **not** keyed by tenant.
# It is keyed only by a logical session identifier (e.g. chat session ID).
_MEMORY: Dict[str, List[MemoryEntry]] = {}


def _now() -> float:
    return time.time()


def extract_session_id(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Pull a logical session identifier out of ``payload``.

    The keys ``session_id``, ``sessionId``, ``conversation_id`` and
    ``conversationId`` are consulted in that order. The first one whose value
    is a string that is still non-empty after stripping whitespace wins;
    non-string values are skipped.

    Returns:
        The stripped identifier, or None if no key yields one.
    """
    # Both generators are lazy, so lookups stop at the first usable key,
    # exactly like an explicit early-return loop.
    raw_values = (
        payload.get(key)
        for key in ("session_id", "sessionId", "conversation_id", "conversationId")
    )
    stripped = (raw.strip() for raw in raw_values if isinstance(raw, str))
    return next((candidate for candidate in stripped if candidate), None)


def _prune_expired(entries: List[MemoryEntry], ttl_seconds: int) -> List[MemoryEntry]:
    if not entries:
        return entries
    cutoff = _now() - ttl_seconds
    return [e for e in entries if e.ts >= cutoff]


def add_entry(
    session_id: str,
    tool_name: str,
    output: Any,
    max_items: int = DEFAULT_MAX_ITEMS,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> None:
    """
    Store a new tool output in short-term memory for this session.

    - Drops entries older than `ttl_seconds` before appending
    - Keeps only the last `max_items` entries (including the new one); a
      `max_items` of zero or less keeps nothing and removes the session

    A falsy `session_id` (None or "") is ignored and nothing is stored.

    Args:
        session_id: Logical session key (e.g. from `extract_session_id`).
        tool_name: Name of the tool that produced `output`.
        output: Tool result; stored by reference, not copied.
        max_items: Per-session cap on retained entries.
        ttl_seconds: Maximum age, in seconds, of retained entries.
    """
    if not session_id:
        return

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    entries.append(MemoryEntry(ts=_now(), tool_name=tool_name, output=output))

    # Enforce bounded size, keeping the most recent entries. Guard
    # max_items <= 0 explicitly: `entries[-0:]` is the WHOLE list, which would
    # silently disable the cap and let the session grow without bound.
    entries = entries[-max_items:] if max_items > 0 else []

    if entries:
        _MEMORY[session_id] = entries
    else:
        # Don't keep a key around for a session that holds nothing.
        _MEMORY.pop(session_id, None)


def get_recent(
    session_id: str,
    limit: Optional[int] = None,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> List[Dict[str, Any]]:
    """
    Return recent, non-expired entries for this session in insertion order.

    Expired entries are pruned from the store as a side effect. If nothing
    survives pruning, the session key is removed rather than left holding an
    empty list, so querying unknown session IDs never grows the store.

    Args:
        session_id: Logical session key; a falsy value returns [].
        limit: If a positive int, return at most the `limit` most recent
            entries. None or a value <= 0 means no limit.
        ttl_seconds: Maximum age, in seconds, of returned entries.

    Returns:
        A list of dicts shaped like
          {"tool": str, "timestamp": float, "output": Any}
        where `output` is the stored object itself, not a copy.
    """
    if not session_id:
        return []

    entries = _prune_expired(_MEMORY.get(session_id, []), ttl_seconds)
    if entries:
        _MEMORY[session_id] = entries  # write back pruned list
    else:
        # Avoid inserting (or keeping) a key for an unknown or fully expired
        # session; an unconditional write-back would leak one key per lookup.
        _MEMORY.pop(session_id, None)

    if limit is not None and limit > 0:
        entries = entries[-limit:]

    return [
        {
            "tool": e.tool_name,
            "timestamp": e.ts,
            "output": e.output,
        }
        for e in entries
    ]


def clear_session(session_id: str) -> None:
    """
    Drop every short-term memory entry held for ``session_id``.

    Intended for when a chat session ends. Clearing a session that has
    nothing stored is a no-op.
    """
    _MEMORY.pop(session_id, None)