Spaces:
Sleeping
Sleeping
| """In-memory deterministic session memory for ShadowOps decisions.""" | |
from __future__ import annotations

import json
import logging
import math
import re
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable
# Absolute directory of this module; the persisted memory file lives beneath it.
BACKEND_DIR = Path(__file__).resolve().parent
# Default JSON file used by SessionMemory when persistence is enabled.
DEFAULT_MEMORY_PATH = BACKEND_DIR.joinpath("data", "session_memory.json")
| def _parse_timestamp(value: Any) -> float: | |
| if value is None: | |
| return 0.0 | |
| if isinstance(value, (int, float)): | |
| return float(value) | |
| text = str(value).strip() | |
| if not text: | |
| return 0.0 | |
| if text.isdigit(): | |
| return float(text) | |
| with_z = text.replace("Z", "+00:00") | |
| try: | |
| parsed = datetime.fromisoformat(with_z) | |
| except ValueError: | |
| return 0.0 | |
| if parsed.tzinfo is None: | |
| parsed = parsed.replace(tzinfo=timezone.utc) | |
| return parsed.timestamp() | |
| def _norm(value: str) -> str: | |
| return " ".join(str(value).replace("_", " ").replace("-", " ").lower().split()) | |
@dataclass
class ActionMemoryRecord:
    """One supervised action remembered for risk scoring.

    ``to_mapping`` and ``from_mapping`` convert to and from the plain-dict
    shape used for JSON persistence.
    """

    actor: str
    session_id: str
    # Target service; from_mapping falls back to ``domain`` when absent.
    service: str
    domain: str
    environment: str
    # Kept exactly as supplied (epoch number, ISO string, ...); not parsed here.
    timestamp: Any
    decision: str
    risk_score: float
    action_summary: str
    indicators: list[str] = field(default_factory=list)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any]) -> "ActionMemoryRecord":
        """Build a record from a loosely-typed dict, filling gaps with defaults.

        Missing or falsy text fields fall back to placeholders: actor
        "unknown", session_id "default", service (then domain) "unknown",
        environment "production", decision (or ``supervisor_decision``)
        "UNKNOWN". action_summary falls back to ``raw_payload`` or "".
        A missing or ``None`` ``risk_score`` counts as 0.0. ``indicators`` may
        be a single string or an iterable; every entry is stored as ``str``.

        Raises:
            ValueError, TypeError: if ``risk_score`` cannot be converted with
                ``float()``.
        """
        raw_indicators = payload.get("indicators") or []
        if isinstance(raw_indicators, str):
            # list("iam_admin") would split a lone indicator into characters.
            raw_indicators = [raw_indicators]
        return cls(
            actor=str(payload.get("actor") or "unknown"),
            session_id=str(payload.get("session_id") or "default"),
            service=str(payload.get("service") or payload.get("domain") or "unknown"),
            domain=str(payload.get("domain") or "unknown"),
            environment=str(payload.get("environment") or "production"),
            timestamp=payload.get("timestamp", 0),
            decision=str(payload.get("decision") or payload.get("supervisor_decision") or "UNKNOWN"),
            # ``or 0.0`` also covers an explicit null, which float() rejects.
            risk_score=float(payload.get("risk_score") or 0.0),
            action_summary=str(payload.get("action_summary") or payload.get("raw_payload") or ""),
            # Downstream code joins indicators as text, so normalise them to str.
            indicators=[str(item) for item in raw_indicators],
        )

    def to_mapping(self) -> dict[str, Any]:
        """Return a JSON-friendly dict that ``from_mapping`` can read back."""
        return {
            "actor": self.actor,
            "session_id": self.session_id,
            "service": self.service,
            "domain": self.domain,
            "environment": self.environment,
            "timestamp": self.timestamp,
            "decision": self.decision,
            "risk_score": self.risk_score,
            "action_summary": self.action_summary,
            "indicators": list(self.indicators),
        }
class SessionMemory:
    """Per-session history of action records plus derived risk scores.

    Every session keeps at most ``max_actions_per_session`` records in a
    bounded deque, so the oldest record is dropped when a new one arrives.
    With ``persistence_enabled`` (the default), the store is loaded from
    ``storage_path`` on construction and rewritten as JSON after every
    mutation.
    """

    # (display name, ordered token steps) for multi-step attack chains.
    _CHAIN_SPECS = (
        ("firewall open -> IAM admin creation -> data export", ("firewall_open", "iam_admin", "data_export")),
        ("CI secret access -> workflow modification -> production deploy", ("ci_secret_access", "workflow_modification", "production_deploy")),
        ("public bucket -> external transfer -> permission escalation", ("public_bucket", "external_transfer", "iam_admin")),
        ("failed auth -> privilege escalation -> production change", ("failed_auth", "iam_admin", "production_change")),
    )
    # Whole-word "ci"; a plain substring test also fires on "decision", "policies", ...
    _CI_WORD = re.compile(r"\bci\b")

    def __init__(
        self,
        max_actions_per_session: int = 20,
        decay_window_seconds: float = 3600.0,
        *,
        persistence_enabled: bool = True,
        storage_path: Path | str = DEFAULT_MEMORY_PATH,
    ):
        """Configure limits and storage, then load persisted state if enabled.

        Args:
            max_actions_per_session: Maximum records retained per session.
            decay_window_seconds: Age, relative to the newest record being
                scored, at which a record's contribution reaches zero
                (values below 1 are treated as 1).
            persistence_enabled: Read from and write to ``storage_path`` when true.
            storage_path: JSON file used for persistence.
        """
        self.max_actions_per_session = max_actions_per_session
        self.decay_window_seconds = decay_window_seconds
        self.persistence_enabled = persistence_enabled
        self.storage_path = Path(storage_path)
        self._by_session: dict[str, deque[ActionMemoryRecord]] = self._new_store()
        if self.persistence_enabled:
            self.load()

    def _new_store(self) -> defaultdict[str, deque[ActionMemoryRecord]]:
        """Return an empty session map whose queues honour the per-session cap."""
        return defaultdict(lambda: deque(maxlen=self.max_actions_per_session))

    def load(self) -> None:
        """Replace the in-memory state with the contents of ``storage_path``.

        This is best effort:
          * A missing file (or disabled persistence) is a no-op.
          * An unreadable or malformed file is logged and leaves the memory empty.
          * Session entries that are not lists, and records that are not
            objects, are skipped.

        State is rebuilt from scratch, so calling this twice does not
        duplicate records.
        """
        if not self.persistence_enabled or not self.storage_path.exists():
            return
        try:
            payload = json.loads(self.storage_path.read_text(encoding="utf-8"))
            sessions = payload.get("sessions", {}) if isinstance(payload, dict) else {}
            loaded = self._new_store()
            for session_id, records in sessions.items():
                if not isinstance(records, list):
                    continue  # malformed entry: drop it, keep the other sessions
                queue = loaded[str(session_id)]
                for item in records[-self.max_actions_per_session:]:
                    if isinstance(item, dict):
                        queue.append(ActionMemoryRecord.from_mapping(item))
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            logging.getLogger(__name__).warning(
                "Discarding unreadable session memory at %s: %s", self.storage_path, exc
            )
            self._by_session.clear()
            return
        self._by_session = loaded

    def save(self) -> None:
        """Persist every session as JSON at ``storage_path`` (no-op when disabled).

        The payload goes to a sibling temporary file that is then renamed over
        the target. An interrupted write therefore cannot leave a truncated
        file, which ``load`` would discard entirely.
        """
        if not self.persistence_enabled:
            return
        payload = {
            "version": 1,
            "max_actions_per_session": self.max_actions_per_session,
            "sessions": {
                session_id: [record.to_mapping() for record in records]
                for session_id, records in self._by_session.items()
            },
        }
        serialized = json.dumps(payload, indent=2)
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
        tmp_path.write_text(serialized, encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def clear(self) -> None:
        """Clear all memory records and persist the empty state when enabled."""
        self._by_session.clear()
        self.save()

    def add_record(self, record: ActionMemoryRecord | dict[str, Any]) -> ActionMemoryRecord:
        """Append *record* to its session, persist, and return the stored record.

        A dict is converted with ``ActionMemoryRecord.from_mapping`` first.
        """
        if isinstance(record, dict):
            record = ActionMemoryRecord.from_mapping(record)
        self._by_session[record.session_id].append(record)
        self.save()
        return record

    def get_recent_actions(self, session_id: str, limit: int = 10) -> list[ActionMemoryRecord]:
        """Return up to ``limit`` most recent records for ``session_id``, oldest first.

        An unknown session, or a non-positive ``limit``, yields an empty list.
        """
        if limit <= 0:
            # Without this guard, records[-0:] would return the whole history.
            return []
        records = list(self._by_session.get(str(session_id), ()))
        return records[-limit:]

    def _all_records(self) -> list[ActionMemoryRecord]:
        """Flatten every session's records into one list."""
        return [record for queue in self._by_session.values() for record in queue]

    def _reference_time(self, records: Iterable[ActionMemoryRecord]) -> float:
        """Return the newest parsed timestamp among *records* (0.0 when empty)."""
        values = [_parse_timestamp(record.timestamp) for record in records]
        return max(values) if values else 0.0

    def _decayed_score(self, records: Iterable[ActionMemoryRecord]) -> float:
        """Aggregate the risk of *records* into a value in [0, 1].

        1. Each record's risk is ``risk_score`` plus up to 0.25 (0.04 per
           indicator) plus 0.10 for a BLOCK/FORK/QUARANTINE decision, capped
           at 1.0.
        2. That risk is scaled by a linear decay that reaches 0 when the record
           is ``decay_window_seconds`` older than the newest record in the
           group. Age is relative to that record, not to wall-clock time.
        3. The result is the mean of the decayed values, plus 0.08 for every
           record beyond the first, clamped to [0, 1].
        """
        rows = list(records)
        if not rows:
            return 0.0
        reference = self._reference_time(rows)
        window = max(self.decay_window_seconds, 1.0)
        weighted = []
        for record in rows:
            age = max(0.0, reference - _parse_timestamp(record.timestamp))
            decay = max(0.0, 1.0 - age / window)
            indicator_boost = min(0.25, 0.04 * len(record.indicators))
            decision_boost = 0.10 if record.decision.upper() in {"BLOCK", "FORK", "QUARANTINE"} else 0.0
            weighted.append(min(1.0, record.risk_score + indicator_boost + decision_boost) * decay)
        # NOTE(review): the repetition bonus counts every record, including
        # fully decayed ones -- confirm this is intended.
        repetition_bonus = 0.08 * (len(rows) - 1)
        return max(0.0, min(1.0, sum(weighted) / len(weighted) + repetition_bonus))

    def compute_actor_risk(self, actor: str) -> float:
        """Score all records, across sessions, whose normalized actor matches *actor*."""
        actor = _norm(actor or "unknown")
        return self._decayed_score(record for record in self._all_records() if _norm(record.actor) == actor)

    def compute_session_risk(self, session_id: str) -> float:
        """Score the records of a single session."""
        return self._decayed_score(self._by_session.get(str(session_id), ()))

    def compute_service_risk(self, service: str) -> float:
        """Score all records, across sessions, whose normalized service matches *service*."""
        service = _norm(service or "unknown")
        return self._decayed_score(record for record in self._all_records() if _norm(record.service) == service)

    def _record_tokens(self, record: ActionMemoryRecord) -> set[str]:
        """Derive canonical behaviour tokens from a record.

        The result is the record's own indicators plus tags inferred from
        keywords in the summary and indicators (e.g. ``firewall_open``,
        ``iam_admin``), all normalized to lower snake_case.
        """
        indicators = [str(item) for item in record.indicators]
        text = _norm(record.action_summary + " " + " ".join(indicators))
        tokens = set(indicators)
        if "firewall" in text or "security group" in text or "open port" in text:
            tokens.add("firewall_open")
        # "admin" also covers "administratoraccess".
        if "admin" in text or "privilege" in text:
            tokens.add("iam_admin")
        if "export" in text or "exfil" in text or "transfer" in text:
            tokens.add("data_export")
        if "secret" in text and (self._CI_WORD.search(text) or "workflow" in text):
            tokens.add("ci_secret_access")
        if "workflow" in text or "pipeline" in text:
            tokens.add("workflow_modification")
        if "deploy" in text or "production" in text:
            tokens.add("production_deploy")
        if "public" in text and ("bucket" in text or "s3" in text):
            tokens.add("public_bucket")
        if "external transfer" in text or ("external" in text and "transfer" in text):
            tokens.add("external_transfer")
        if "failed auth" in text or "failed login" in text:
            tokens.add("failed_auth")
        if "production change" in text:
            tokens.add("production_change")
        return {_norm(token).replace(" ", "_") for token in tokens}

    def detect_risky_chains(self, session_id: str) -> list[str]:
        """Return the names of known attack chains present in a session.

        A chain matches when its steps appear in order, not necessarily
        consecutively, across the session's records. Each record advances a
        given chain by at most one step. Names are returned in
        ``_CHAIN_SPECS`` order.
        """
        sequence = [self._record_tokens(record) for record in self._by_session.get(str(session_id), ())]
        matches = []
        for name, required in self._CHAIN_SPECS:
            cursor = 0
            for tokens in sequence:
                if required[cursor] in tokens:
                    cursor += 1
                    if cursor == len(required):
                        matches.append(name)
                        break
        return matches

    def summarize_memory_context(self, session_id: str) -> dict[str, Any]:
        """Build a snapshot of the session for downstream decision making.

        The actor and service come from the session's most recent record, or
        are "unknown" when the session is empty. Their risks are computed
        across all sessions.
        """
        recent = self.get_recent_actions(session_id, limit=self.max_actions_per_session)
        actor = recent[-1].actor if recent else "unknown"
        service = recent[-1].service if recent else "unknown"
        chains = self.detect_risky_chains(session_id)
        return {
            "session_id": str(session_id),
            "recent_action_count": len(recent),
            "actor": actor,
            "actor_risk": self.compute_actor_risk(actor),
            "session_risk": self.compute_session_risk(session_id),
            "service": service,
            "service_risk": self.compute_service_risk(service),
            "risky_chains": chains,
            "recent_decisions": [record.decision for record in recent[-5:]],
            "recent_indicators": sorted({indicator for record in recent for indicator in record.indicators}),
        }
# Process-wide store behind the module-level helpers in this block. Built with
# the SessionMemory defaults, so it reads DEFAULT_MEMORY_PATH (when that file
# exists) at import time.
DEFAULT_MEMORY = SessionMemory()


def add_record(record: ActionMemoryRecord | dict[str, Any]) -> ActionMemoryRecord:
    """Store *record* in the shared memory and return the stored record."""
    return DEFAULT_MEMORY.add_record(record)


def get_recent_actions(session_id: str, limit: int = 10) -> list[ActionMemoryRecord]:
    """Fetch up to *limit* recent records of *session_id* from the shared memory."""
    return DEFAULT_MEMORY.get_recent_actions(session_id, limit=limit)


def compute_actor_risk(actor: str) -> float:
    """Risk score for *actor* across every session in the shared memory."""
    return DEFAULT_MEMORY.compute_actor_risk(actor)


def compute_session_risk(session_id: str) -> float:
    """Risk score for a single session in the shared memory."""
    return DEFAULT_MEMORY.compute_session_risk(session_id)


def compute_service_risk(service: str) -> float:
    """Risk score for *service* across every session in the shared memory."""
    return DEFAULT_MEMORY.compute_service_risk(service)


def detect_risky_chains(session_id: str) -> list[str]:
    """Names of known attack chains found in a session of the shared memory."""
    return DEFAULT_MEMORY.detect_risky_chains(session_id)


def summarize_memory_context(session_id: str) -> dict[str, Any]:
    """Snapshot of a session's history and risk from the shared memory."""
    return DEFAULT_MEMORY.summarize_memory_context(session_id)


def clear_memory() -> None:
    """Empty the shared memory (persisting the empty state when enabled)."""
    DEFAULT_MEMORY.clear()