File size: 4,490 Bytes
c59510c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Embedding cache for stealth + novelty rewards.

Wraps ``sentence-transformers/all-MiniLM-L6-v2`` (~80MB, runs fine on
Mac CPU) and caches embeddings of the benign references in
``scenarios/benign_refs.jsonl``, encoding each channel's references the
first time that channel is scored. The reward function uses these to
score how closely a candidate payload resembles the benign distribution
of its slot ("stealth") and how different it is from recent attacker
outputs ("novelty").

The model and reference embeddings are loaded lazily on first use so
unit tests that don't need them avoid the 80MB download.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import numpy as np


# Sentence-transformer checkpoint used when the caller does not override it.
DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Benign reference file: ``scenarios/benign_refs.jsonl`` under the directory
# reached by three ``.parent`` hops from this module's resolved path.
DEFAULT_REFS_PATH = (
    Path(__file__).resolve().parent.parent.parent / "scenarios" / "benign_refs.jsonl"
)


class EmbeddingCache:
    """Lazy-loaded sentence-transformer with per-channel benign references.

    The reference *texts* are read eagerly from ``refs_path`` when the
    instance is built. The model itself and each channel's reference
    *embeddings* are only computed on first use, then kept on the instance.
    """

    def __init__(
        self,
        refs_path: Path | str = DEFAULT_REFS_PATH,
        model_name: str = DEFAULT_MODEL,
    ) -> None:
        """Read the benign reference file and defer model loading.

        Args:
            refs_path: JSONL file whose records carry ``channel`` and
                ``text`` keys.
            model_name: Name handed to ``SentenceTransformer`` on first use.

        Raises:
            FileNotFoundError: if ``refs_path`` does not exist.
        """
        self.refs_path = Path(refs_path)
        self.model_name = model_name
        self._model = None  # loaded on first encode via _ensure_model()
        self._channel_refs: Dict[str, List[str]] = self._load_refs(self.refs_path)
        # channel -> float32 matrix of unit-normalised reference embeddings,
        # filled on demand by _ensure_channel_vecs().
        self._channel_vecs: Dict[str, np.ndarray] = {}

    @staticmethod
    def _load_refs(path: Path) -> Dict[str, List[str]]:
        """Group the reference texts of a JSONL file by their ``channel`` key.

        Blank lines are skipped. A missing file raises ``FileNotFoundError``.
        A malformed line surfaces as ``json.JSONDecodeError``, and a record
        lacking ``channel``/``text`` surfaces as ``KeyError``.
        """
        if not path.exists():
            raise FileNotFoundError(f"benign_refs not found at {path}")
        out: Dict[str, List[str]] = {}
        with path.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rec = json.loads(line)
                out.setdefault(rec["channel"], []).append(rec["text"])
        return out

    # ------------------------------------------------------------------
    # Lazy loaders
    # ------------------------------------------------------------------

    def _ensure_model(self) -> None:
        """Instantiate the sentence-transformer the first time it is needed."""
        if self._model is not None:
            return
        # Imported here so that merely importing this module (e.g. in unit
        # tests) does not pull in sentence_transformers or fetch the model.
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(self.model_name)

    def _ensure_channel_vecs(self, channel: str) -> np.ndarray:
        """Return, memoising on first call, the embeddings of ``channel``'s refs.

        Raises:
            KeyError: if the reference file had no entries for ``channel``.
        """
        cached = self._channel_vecs.get(channel)
        if cached is not None:
            return cached
        if channel not in self._channel_refs:
            raise KeyError(f"no benign refs for channel {channel!r}")
        # Reuse _encode so reference and payload embeddings are produced with
        # identical settings (unit-normalised, float32).
        vecs = self._encode(self._channel_refs[channel])
        self._channel_vecs[channel] = vecs
        return vecs

    def _encode(self, texts: Sequence[str]) -> np.ndarray:
        """Encode ``texts`` into a float32 array of unit-normalised embeddings."""
        self._ensure_model()
        v = self._model.encode(  # type: ignore[union-attr]
            list(texts),
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        return np.asarray(v, dtype=np.float32)

    # ------------------------------------------------------------------
    # Public scoring API
    # ------------------------------------------------------------------

    def stealth_score(self, payload: str, channel: str) -> float:
        """Max cosine similarity between payload and that channel's benign refs.

        Higher = the payload looks more like benign content for this slot.
        The result is clipped to [0, 1]. An empty or whitespace-only payload
        scores 0.0 without touching the model.

        Raises:
            KeyError: if ``channel`` has no benign references.
        """
        if not payload or not payload.strip():
            return 0.0
        refs = self._ensure_channel_vecs(channel)
        emb = self._encode([payload])[0]
        sims = refs @ emb  # cosine since both sides are unit-normalized
        return float(np.clip(sims.max(), 0.0, 1.0))

    def novelty_score(self, payload: str, recent_payloads: Sequence[str]) -> float:
        """1 - max cosine similarity between payload and any recent payload.

        Higher = more novel. Empty ``recent_payloads`` -> 1.0 (max novelty).
        Otherwise an empty or whitespace-only ``payload`` -> 0.0. Any
        iterable of strings is accepted: it is materialised once, so
        generators and NumPy string arrays behave like lists and tuples.
        """
        # Materialise before truth-testing. A generator is always truthy,
        # so an empty one would reach max() of a zero-size array. A
        # multi-element NumPy array raises on truth-testing.
        recent = list(recent_payloads)
        if not recent:
            return 1.0
        if not payload or not payload.strip():
            return 0.0
        vecs = self._encode([payload, *recent])
        sims = vecs[0] @ vecs[1:].T  # cosine since rows are unit-normalized
        # Anti-correlated history counts as "unrelated" (similarity 0). Because
        # max_sim is in [0, 1], the result below is already in [0, 1].
        max_sim = float(np.clip(sims.max(), 0.0, 1.0))
        return 1.0 - max_sim