File size: 5,218 Bytes
e431b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""firewall.py — detect -> mask -> (call) -> restore.

Orchestrates Layer 1 (detectors.py, always on) and Layer 2 (model_ner.py,
optional/graceful) per ARCHITECTURE in the build spec. Implements the merge
rule (union spans, resolve overlaps by confidence/length, never overlapping
masks), a reversible in-memory vault, and a stub LLM round-trip.
"""
from __future__ import annotations

from detectors import detect_all

# Types produced by the secret-side detectors / NER -> placeholder category
_SECRET_TYPES = {
    "AWS_KEY", "GITHUB_TOKEN", "SLACK_TOKEN", "JWT", "PRIVATE_KEY", "SECRET",
}


def _category(type_name: str) -> str:
    return "SECRET" if type_name in _SECRET_TYPES else "PII"


def merge_spans(spans: list[dict]) -> list[dict]:
    """Combine detector spans into a non-overlapping set.

    Spans are dicts with at least "start", "end" (exclusive) and
    "confidence". Candidates are considered best-first: higher confidence,
    then longer length, then earlier start. A candidate is kept only if it
    does not overlap anything already kept. Touching spans (one ends where
    the next starts) do not count as overlapping.

    Returns the kept spans ordered by start position.
    """

    def priority(span: dict) -> tuple:
        length = span["end"] - span["start"]
        return (-span["confidence"], -length, span["start"])

    def clashes(candidate: dict, other: dict) -> bool:
        # Half-open interval intersection test.
        return candidate["start"] < other["end"] and other["start"] < candidate["end"]

    kept: list[dict] = []
    for candidate in sorted(spans, key=priority):
        if any(clashes(candidate, other) for other in kept):
            continue
        kept.append(candidate)

    kept.sort(key=lambda s: s["start"])
    return kept


def mask(text: str, spans: list[dict]) -> tuple[str, dict, list[dict]]:
    """Mask every detected span with a numbered, typed placeholder.

    Overlapping input spans are first resolved with merge_spans. Each kept
    span is replaced by "[<CATEGORY>_<TYPE>_<n>]", where n counts spans of
    that type within this call, starting at 1. Two occurrences of the same
    value therefore receive distinct placeholders.

    Returns:
        (masked_text, vault, findings): vault maps each placeholder to the
        original substring; findings record type / placeholder / source /
        confidence for each span and never include the raw value.
    """
    pieces: list[str] = []
    vault: dict[str, str] = {}
    findings: list[dict] = []
    seen: dict[tuple[str, str], int] = {}
    cursor = 0

    for span in merge_spans(spans):
        kind = span["type"]
        category = _category(kind)
        index = seen.get((category, kind), 0) + 1
        seen[(category, kind)] = index
        token = f"[{category}_{kind}_{index}]"

        begin, stop = span["start"], span["end"]
        vault[token] = text[begin:stop]
        # Copy the untouched text since the previous span, then the placeholder.
        pieces += [text[cursor:begin], token]
        cursor = stop

        findings.append(
            {
                "type": kind,
                "masked_value": token,
                "source": span["source"],
                "confidence": span["confidence"],
            }
        )

    pieces.append(text[cursor:])
    return "".join(pieces), vault, findings


def restore(text: str, vault: dict[str, str]) -> str:
    """Swap placeholders in *text* back to their original values.

    All placeholders are substituted in a single left-to-right pass. An
    original value that happens to contain placeholder-like text is inserted
    verbatim and is never re-expanded by another vault entry. Empty
    placeholder keys are ignored, because replacing "" would insert the value
    between every character.
    """
    import re

    # Longest first, so that if one placeholder were a prefix of another the
    # longer one wins in the regex alternation.
    placeholders = sorted((p for p in vault if p), key=len, reverse=True)
    if not placeholders:
        return text
    pattern = re.compile("|".join(map(re.escape, placeholders)))
    # A function replacement inserts the value literally; a string
    # replacement would interpret backslashes in the original value.
    return pattern.sub(lambda m: vault[m.group(0)], text)


def sanitize(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True) -> dict:
    """Run the full DETECT -> MASK pipeline.

    Layer 1 (``detectors.detect_all``) always runs. Layer 2
    (``model_ner.detect_entities``) runs only when ``use_model`` is true and is
    best-effort. If it cannot be imported or raises, masking proceeds with the
    Layer 1 spans alone.

    Returns a dict with:
        sanitized_text / sanitized: the masked text (two keys, same value).
        vault: placeholder -> original substring.
        findings: per-span metadata without raw values.
        blocked: number of findings.
        model_status: "off" when use_model is false; "on" when the model
            returned at least one span; "unavailable" when the import failed,
            the call raised, or it returned no spans.
    """
    spans = detect_all(text, pii=pii, secrets=secrets)

    model_status = "off"
    if use_model:
        try:
            from model_ner import detect_entities
            model_spans = list(detect_entities(text))
        except ImportError:
            # Layer 2 is optional; a missing model module is an expected state.
            model_status = "unavailable"
        except Exception as exc:
            # Keep the graceful fallback, but record why Layer 2 dropped out.
            # Only the exception type is logged: messages/tracebacks from the
            # model could echo the raw (unmasked) input text.
            import logging
            logging.getLogger(__name__).warning(
                "NER model failed (%s); continuing with Layer 1 detections only",
                type(exc).__name__,
            )
            model_status = "unavailable"
        else:
            model_status = "on" if model_spans else "unavailable"
            # Build a new list instead of extending detect_all's result in place.
            spans = [*spans, *model_spans]

    sanitized_text, vault, findings = mask(text, spans)
    return {
        "sanitized_text": sanitized_text,
        "sanitized": sanitized_text,
        "vault": vault,
        "findings": findings,
        "blocked": len(findings),
        "model_status": model_status,
    }


def call_llm_stub(sanitized_text: str) -> str:
    """Offline stand-in for an LLM call.

    Wraps the sanitized prompt in a fixed reply template and returns it. The
    round-trip demo therefore needs no API key or network access.
    """
    header = "Thanks for the message. Here's a summary of what I received (sanitized):"
    footer = "(This is a stub response — no data left your machine.)"
    return f"{header}\n\n{sanitized_text}\n\n{footer}"


def round_trip(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True, llm_fn=None) -> dict:
    """Sanitize *text*, send it to an LLM callable, and un-mask the reply.

    ``llm_fn`` receives only the sanitized text; when it is falsy,
    ``call_llm_stub`` is used. The returned dict is the ``sanitize`` result
    extended with "llm_response" (the raw reply) and "restored_response" (the
    reply with vault placeholders swapped back to the original values).
    """
    outcome = sanitize(text, pii=pii, secrets=secrets, use_model=use_model)
    ask = llm_fn if llm_fn else call_llm_stub
    reply = ask(outcome["sanitized_text"])
    outcome.update(
        llm_response=reply,
        restored_response=restore(reply, outcome["vault"]),
    )
    return outcome


def audit_log(result: dict) -> dict:
    """Build a JSON-safe audit record from a sanitize/round_trip result.

    Only the blocked count and, per finding, its type, placeholder, source and
    confidence are copied. The vault and any raw values are never included.

    NOTE(review): "leaked" is always reported as 0; it is not computed from
    the result here.
    """
    audited_fields = ("type", "masked_value", "source", "confidence")
    blocked = result["blocked"]
    entries = [
        {field: finding[field] for field in audited_fields}
        for finding in result["findings"]
    ]
    return {"blocked": blocked, "leaked": 0, "findings": entries}