File size: 11,296 Bytes
d064478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""In-memory deterministic session memory for ShadowOps decisions."""

from __future__ import annotations

from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
import json
import math
from pathlib import Path
import re
from typing import Any, Iterable


# Directory that contains this module; the default persistence file lives
# in a "data" sub-directory beneath it.
BACKEND_DIR: Path = Path(__file__).resolve().parent
DEFAULT_MEMORY_PATH: Path = BACKEND_DIR.joinpath("data", "session_memory.json")


def _parse_timestamp(value: Any) -> float:
    """Convert a loosely-typed timestamp into POSIX seconds.

    Accepted inputs:
      * ``None`` or an empty/blank string -> ``0.0``
      * ``int``/``float`` -> the value as epoch seconds
      * numeric strings such as ``"1700000000"`` or ``"1700000000.5"``
      * ISO-8601 strings; a trailing ``Z`` is accepted and naive datetimes
        are interpreted as UTC.

    Anything that cannot be parsed, and any non-finite number (nan/inf),
    maps to ``0.0`` so one malformed record cannot break the decay math.
    """
    if value is None:
        return 0.0
    if isinstance(value, (int, float)):
        number = float(value)
        return number if math.isfinite(number) else 0.0
    text = str(value).strip()
    if not text:
        return 0.0
    # Try a plain number first. This accepts fractional/negative epochs and,
    # unlike the old ``str.isdigit`` gate, cannot crash on characters such as
    # "²" that isdigit() accepts but float() rejects.
    try:
        number = float(text)
    except ValueError:
        pass
    else:
        return number if math.isfinite(number) else 0.0
    # datetime.fromisoformat only understands a "Z" suffix from Python 3.11.
    try:
        parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
    except ValueError:
        return 0.0
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.timestamp()


def _norm(value: str) -> str:
    """Canonicalise *value* for loose, case-insensitive comparison.

    Underscores and hyphens are treated as spaces, the text is lower-cased,
    and any run of whitespace collapses to one space (ends trimmed).
    """
    text = str(value)
    for separator in ("_", "-"):
        text = text.replace(separator, " ")
    words = text.lower().split()
    return " ".join(words)


@dataclass(frozen=True)
class ActionMemoryRecord:
    """One supervised action as remembered by ``SessionMemory``.

    ``timestamp`` is stored exactly as supplied (epoch number, ISO string,
    ``None``, ...) and is only interpreted when risk is computed.
    """

    actor: str
    session_id: str
    service: str
    domain: str
    environment: str
    timestamp: Any
    decision: str
    risk_score: float
    action_summary: str
    # Free-form indicator labels attached to the action.
    indicators: list[str] = field(default_factory=list)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any]) -> "ActionMemoryRecord":
        """Build a record from a loosely-structured dict.

        Missing or falsy text fields fall back to defaults ("unknown",
        "default", "production", "UNKNOWN", ""). ``service`` falls back to
        ``domain``, ``decision`` to ``supervisor_decision``, and
        ``action_summary`` to ``raw_payload``.

        ``risk_score`` of ``None`` is treated like a missing value (0.0);
        any other value must be float-convertible, otherwise
        ``ValueError``/``TypeError`` propagates.

        A bare string ``indicators`` value is kept as a single indicator
        instead of being split into characters, and every indicator is
        coerced to ``str`` so downstream string handling cannot fail.
        """
        raw_risk = payload.get("risk_score")
        raw_indicators = payload.get("indicators") or []
        if isinstance(raw_indicators, str):
            raw_indicators = [raw_indicators]
        return cls(
            actor=str(payload.get("actor") or "unknown"),
            session_id=str(payload.get("session_id") or "default"),
            service=str(payload.get("service") or payload.get("domain") or "unknown"),
            domain=str(payload.get("domain") or "unknown"),
            environment=str(payload.get("environment") or "production"),
            timestamp=payload.get("timestamp", 0),
            decision=str(payload.get("decision") or payload.get("supervisor_decision") or "UNKNOWN"),
            risk_score=0.0 if raw_risk is None else float(raw_risk),
            action_summary=str(payload.get("action_summary") or payload.get("raw_payload") or ""),
            indicators=[str(item) for item in raw_indicators],
        )

    def to_mapping(self) -> dict[str, Any]:
        """Return a plain dict that ``from_mapping`` can turn back into an equal record."""
        return {
            "actor": self.actor,
            "session_id": self.session_id,
            "service": self.service,
            "domain": self.domain,
            "environment": self.environment,
            "timestamp": self.timestamp,
            "decision": self.decision,
            "risk_score": self.risk_score,
            "action_summary": self.action_summary,
            "indicators": list(self.indicators),
        }


class SessionMemory:
    """Bounded per-session history of supervised actions, with risk scoring.

    Each session keeps at most ``max_actions_per_session`` records; the
    oldest are dropped first. Scores are deterministic: time decay is
    measured against the newest record in the scored set, not the wall
    clock. When persistence is enabled, state is loaded from
    ``storage_path`` on construction and rewritten after every mutation.
    """

    # (display name, tokens that must appear, in order, across a session's actions)
    _CHAIN_SPECS: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("firewall open -> IAM admin creation -> data export", ("firewall_open", "iam_admin", "data_export")),
        ("CI secret access -> workflow modification -> production deploy", ("ci_secret_access", "workflow_modification", "production_deploy")),
        ("public bucket -> external transfer -> permission escalation", ("public_bucket", "external_transfer", "iam_admin")),
        ("failed auth -> privilege escalation -> production change", ("failed_auth", "iam_admin", "production_change")),
    )

    # "ci" must be a whole word; a substring test also matched "decision",
    # "specific", "policies", ...
    _CI_WORD = re.compile(r"\bci\b")

    def __init__(
        self,
        max_actions_per_session: int = 20,
        decay_window_seconds: float = 3600.0,
        *,
        persistence_enabled: bool = True,
        storage_path: Path | str = DEFAULT_MEMORY_PATH,
    ):
        """Create the memory and, if persistence is enabled, load ``storage_path``.

        Args:
            max_actions_per_session: per-session history length (deque maxlen).
            decay_window_seconds: age, relative to the newest scored record,
                at which a record's weight decays linearly to zero.
            persistence_enabled: load on construction and save on mutation.
            storage_path: JSON file used for persistence.
        """
        self.max_actions_per_session = max_actions_per_session
        self.decay_window_seconds = decay_window_seconds
        self.persistence_enabled = persistence_enabled
        self.storage_path = Path(storage_path)
        self._by_session: dict[str, deque[ActionMemoryRecord]] = defaultdict(
            lambda: deque(maxlen=self.max_actions_per_session)
        )
        if self.persistence_enabled:
            self.load()

    def load(self) -> None:
        """Replace in-memory state with the records persisted at ``storage_path``.

        Does nothing when persistence is disabled or the file does not exist.
        An unreadable or non-JSON file leaves memory empty (best effort, no
        exception). Malformed sessions or records are skipped individually
        instead of discarding everything. State is rebuilt from scratch, so
        calling ``load()`` again does not duplicate records.
        """
        if not self.persistence_enabled or not self.storage_path.exists():
            return
        self._by_session.clear()
        try:
            payload = json.loads(self.storage_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable, undecodable or corrupt JSON: start from empty memory.
            return
        sessions = payload.get("sessions", {}) if isinstance(payload, dict) else {}
        if not isinstance(sessions, dict):
            return
        for session_id, records in sessions.items():
            if not isinstance(records, list):
                continue
            queue = self._by_session[str(session_id)]
            # The deque's maxlen keeps only the newest valid records.
            for item in records:
                if not isinstance(item, dict):
                    continue
                try:
                    queue.append(ActionMemoryRecord.from_mapping(item))
                except (TypeError, ValueError, OverflowError):
                    continue

    def save(self) -> None:
        """Write all sessions to ``storage_path`` as JSON (no-op when disabled).

        The JSON is written to a sibling ``.tmp`` file and then moved over
        the target, so an interrupted write cannot leave a truncated file
        for the next ``load()`` to discard.
        """
        if not self.persistence_enabled:
            return
        payload = {
            "version": 1,
            "max_actions_per_session": self.max_actions_per_session,
            "sessions": {
                session_id: [record.to_mapping() for record in records]
                for session_id, records in self._by_session.items()
            },
        }
        serialized = json.dumps(payload, indent=2)
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
        tmp_path.write_text(serialized, encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def clear(self) -> None:
        """Clear all memory records and persist the empty state when enabled."""

        self._by_session.clear()
        self.save()

    def add_record(self, record: ActionMemoryRecord | dict[str, Any]) -> ActionMemoryRecord:
        """Append *record* (a dict is converted first) to its session, save, and return it."""
        if isinstance(record, dict):
            record = ActionMemoryRecord.from_mapping(record)
        # Key by str() so it matches load() and every lookup method, which
        # all stringify the session id.
        self._by_session[str(record.session_id)].append(record)
        self.save()
        return record

    def get_recent_actions(self, session_id: str, limit: int = 10) -> list[ActionMemoryRecord]:
        """Return up to *limit* newest records of *session_id*, oldest first.

        A non-positive *limit* yields an empty list. The previous
        ``records[-0:]`` slice returned every record for ``limit=0``.
        """
        if limit <= 0:
            return []
        records = list(self._by_session.get(str(session_id), ()))
        return records[-limit:]

    def _all_records(self) -> list[ActionMemoryRecord]:
        """Return every stored record across all sessions."""
        records: list[ActionMemoryRecord] = []
        for session_records in self._by_session.values():
            records.extend(session_records)
        return records

    def _reference_time(self, records: Iterable[ActionMemoryRecord]) -> float:
        """Return the newest parsed timestamp among *records* (0.0 if none)."""
        values = [_parse_timestamp(record.timestamp) for record in records]
        return max(values) if values else 0.0

    def _decayed_score(self, records: Iterable[ActionMemoryRecord]) -> float:
        """Aggregate *records* into a risk score in ``[0, 1]``.

        Each record's weight is ``min(1, risk_score + indicator_boost +
        decision_boost)``, scaled by a linear decay that reaches 0 at
        ``decay_window_seconds`` before the newest record. Here
        ``indicator_boost = min(0.25, 0.04 * #indicators)`` and
        ``decision_boost = 0.10`` for BLOCK/FORK/QUARANTINE decisions. The
        result is the mean weight plus 0.08 per record beyond the first,
        clamped to ``[0, 1]``.
        """
        rows = list(records)
        if not rows:
            return 0.0
        reference = self._reference_time(rows)
        weighted = []
        for record in rows:
            age = max(0.0, reference - _parse_timestamp(record.timestamp))
            decay = max(0.0, 1.0 - age / max(self.decay_window_seconds, 1.0))
            indicator_boost = min(0.25, 0.04 * len(record.indicators))
            decision_boost = 0.10 if record.decision.upper() in {"BLOCK", "FORK", "QUARANTINE"} else 0.0
            weighted.append(min(1.0, record.risk_score + indicator_boost + decision_boost) * decay)
        return max(0.0, min(1.0, sum(weighted) / max(len(weighted), 1) + 0.08 * max(0, len(rows) - 1)))

    def compute_actor_risk(self, actor: str) -> float:
        """Score all records whose actor matches *actor* (normalised comparison)."""
        actor = _norm(actor or "unknown")
        return self._decayed_score(record for record in self._all_records() if _norm(record.actor) == actor)

    def compute_session_risk(self, session_id: str) -> float:
        """Score the records stored for *session_id*."""
        return self._decayed_score(self._by_session.get(str(session_id), ()))

    def compute_service_risk(self, service: str) -> float:
        """Score all records whose service matches *service* (normalised comparison)."""
        service = _norm(service or "unknown")
        return self._decayed_score(record for record in self._all_records() if _norm(record.service) == service)

    def _record_tokens(self, record: ActionMemoryRecord) -> set[str]:
        """Return the snake_cased behaviour tokens used for chain detection.

        The tokens are the record's own indicators plus keyword-derived tokens
        inferred from its summary and indicators.
        """
        text = _norm(record.action_summary + " " + " ".join(record.indicators))
        tokens = set(record.indicators)
        if "firewall" in text or "security group" in text or "open port" in text:
            tokens.add("firewall_open")
        # "admin" also covers "administratoraccess".
        if "admin" in text or "privilege" in text:
            tokens.add("iam_admin")
        if "export" in text or "exfil" in text or "transfer" in text:
            tokens.add("data_export")
        if "secret" in text and (self._CI_WORD.search(text) or "workflow" in text):
            tokens.add("ci_secret_access")
        if "workflow" in text or "pipeline" in text:
            tokens.add("workflow_modification")
        if "deploy" in text or "production" in text:
            tokens.add("production_deploy")
        if "public" in text and ("bucket" in text or "s3" in text):
            tokens.add("public_bucket")
        if "external" in text and "transfer" in text:
            tokens.add("external_transfer")
        if "failed auth" in text or "failed login" in text:
            tokens.add("failed_auth")
        if "production change" in text:
            tokens.add("production_change")
        return {_norm(token).replace(" ", "_") for token in tokens}

    def detect_risky_chains(self, session_id: str) -> list[str]:
        """Return names of ``_CHAIN_SPECS`` whose steps occur in order in the session.

        Each step must be matched by a different, later action. The greedy
        scan finds any in-order subsequence.
        """
        sequence = [self._record_tokens(record) for record in self._by_session.get(str(session_id), ())]
        matches = []
        for name, required in self._CHAIN_SPECS:
            cursor = 0
            for tokens in sequence:
                if required[cursor] in tokens:
                    cursor += 1
                    if cursor == len(required):
                        matches.append(name)
                        break
        return matches

    def summarize_memory_context(self, session_id: str) -> dict[str, Any]:
        """Summarise a session's history and related actor/service risk.

        Actor and service are taken from the session's most recent record
        ("unknown" when the session is empty).
        """
        recent = self.get_recent_actions(session_id, limit=self.max_actions_per_session)
        actor = recent[-1].actor if recent else "unknown"
        service = recent[-1].service if recent else "unknown"
        chains = self.detect_risky_chains(session_id)
        return {
            "session_id": str(session_id),
            "recent_action_count": len(recent),
            "actor": actor,
            "actor_risk": self.compute_actor_risk(actor),
            "session_risk": self.compute_session_risk(session_id),
            "service": service,
            "service_risk": self.compute_service_risk(service),
            "risky_chains": chains,
            "recent_decisions": [record.decision for record in recent[-5:]],
            "recent_indicators": sorted({indicator for record in recent for indicator in record.indicators}),
        }


# Process-wide instance behind the module-level helper functions. Built with
# SessionMemory's defaults, so persistence is enabled and any state saved at
# DEFAULT_MEMORY_PATH is loaded when this module is imported.
DEFAULT_MEMORY: SessionMemory = SessionMemory()


def add_record(record: ActionMemoryRecord | dict[str, Any]) -> ActionMemoryRecord:
    """Store *record* in the shared ``DEFAULT_MEMORY`` and return the stored record."""
    memory = DEFAULT_MEMORY
    return memory.add_record(record=record)


def get_recent_actions(session_id: str, limit: int = 10) -> list[ActionMemoryRecord]:
    """Return the newest *limit* records for *session_id* from ``DEFAULT_MEMORY``."""
    memory = DEFAULT_MEMORY
    return memory.get_recent_actions(session_id, limit=limit)


def compute_actor_risk(actor: str) -> float:
    """Return ``DEFAULT_MEMORY``'s decayed risk score for *actor*."""
    memory = DEFAULT_MEMORY
    return memory.compute_actor_risk(actor=actor)


def compute_session_risk(session_id: str) -> float:
    """Return ``DEFAULT_MEMORY``'s decayed risk score for *session_id*."""
    memory = DEFAULT_MEMORY
    return memory.compute_session_risk(session_id=session_id)


def compute_service_risk(service: str) -> float:
    """Return ``DEFAULT_MEMORY``'s decayed risk score for *service*."""
    memory = DEFAULT_MEMORY
    return memory.compute_service_risk(service=service)


def detect_risky_chains(session_id: str) -> list[str]:
    """Return the risky action chains ``DEFAULT_MEMORY`` finds in *session_id*."""
    memory = DEFAULT_MEMORY
    return memory.detect_risky_chains(session_id=session_id)


def summarize_memory_context(session_id: str) -> dict[str, Any]:
    """Return ``DEFAULT_MEMORY``'s summary dict for *session_id*."""
    memory = DEFAULT_MEMORY
    return memory.summarize_memory_context(session_id=session_id)


def clear_memory() -> None:
    """Drop every record in ``DEFAULT_MEMORY``, persisting the empty state when enabled."""
    memory = DEFAULT_MEMORY
    memory.clear()