Spaces:
Sleeping
Sleeping
File size: 5,218 Bytes
"""firewall.py — detect -> mask -> (call) -> restore.
Orchestrates Layer 1 (detectors.py, always on) and Layer 2 (model_ner.py,
optional/graceful) per ARCHITECTURE in the build spec. Implements the merge
rule (union spans, resolve overlaps by confidence/length, never overlapping
masks), a reversible in-memory vault, and a stub LLM round-trip.
"""
from __future__ import annotations
from detectors import detect_all
# Span types (from the secret-side detectors or the NER model) that receive
# the SECRET placeholder category; every other type is treated as PII.
_SECRET_TYPES = {
    "AWS_KEY",
    "GITHUB_TOKEN",
    "JWT",
    "PRIVATE_KEY",
    "SECRET",
    "SLACK_TOKEN",
}


def _category(type_name: str) -> str:
    """Map a detected span type to its placeholder category.

    Returns ``"SECRET"`` when ``type_name`` is listed in ``_SECRET_TYPES``
    and ``"PII"`` for any other type.
    """
    if type_name in _SECRET_TYPES:
        return "SECRET"
    return "PII"
def merge_spans(spans: list[dict]) -> list[dict]:
    """Union detector spans into a set of non-overlapping spans.

    Each span is a dict with at least ``start``, ``end`` (half-open character
    offsets, i.e. ``text[start:end]``) and ``confidence``. Candidates are
    considered in priority order -- higher confidence first, then longer,
    then earlier start -- and each is accepted only if it does not overlap
    a span that was already accepted.

    Empty or inverted spans (``end <= start``) are discarded up front.
    Without this filter an empty span can be accepted next to a real span
    with the same start, which makes ``mask`` move its cursor backwards and
    re-emit the raw (supposedly masked) text. An inverted span can also
    "overlap" and so block a genuine span from being masked.

    Args:
        spans: candidate spans from all detection layers.

    Returns:
        The accepted spans, sorted by start position.
    """
    candidates = [s for s in spans if s["end"] > s["start"]]
    # Greedy acceptance in priority order; quadratic in the number of spans,
    # which stays small for prompt-sized text.
    ordered = sorted(
        candidates,
        key=lambda s: (-s["confidence"], -(s["end"] - s["start"]), s["start"]),
    )
    accepted: list[dict] = []
    for span in ordered:
        if not any(span["start"] < a["end"] and a["start"] < span["end"] for a in accepted):
            accepted.append(span)
    return sorted(accepted, key=lambda s: s["start"])
def mask(text: str, spans: list[dict]) -> tuple[str, dict, list[dict]]:
    """Substitute every detected span with a typed, reversible placeholder.

    The spans are first merged with ``merge_spans`` and then walked left to
    right. Each one becomes ``[<CATEGORY>_<TYPE>_<n>]``, where ``n`` is a
    1-based running count per span type.

    Args:
        text: the original input.
        spans: candidate spans; each needs ``start``, ``end``, ``type``,
            ``source`` and ``confidence`` keys.

    Returns:
        ``(masked_text, vault, findings)``:
        masked_text is ``text`` with every merged span replaced; vault maps
        placeholder -> the original substring; findings holds one dict per
        placeholder (type, masked_value, source, confidence) and carries no
        raw values.
    """
    vault: dict[str, str] = {}
    findings: list[dict] = []
    seen: dict[tuple[str, str], int] = {}
    pieces: list[str] = []
    cursor = 0
    for span in merge_spans(spans):
        kind = span["type"]
        category = _category(kind)
        tally_key = (category, kind)
        n = seen.get(tally_key, 0) + 1
        seen[tally_key] = n
        token = f"[{category}_{kind}_{n}]"
        start, end = span["start"], span["end"]
        vault[token] = text[start:end]
        # Keep the untouched gap before the span, then the placeholder.
        pieces += [text[cursor:start], token]
        cursor = end
        findings.append(
            {
                "type": kind,
                "masked_value": token,
                "source": span["source"],
                "confidence": span["confidence"],
            }
        )
    pieces.append(text[cursor:])
    return "".join(pieces), vault, findings
def restore(text: str, vault: dict[str, str]) -> str:
    """Swap placeholders in ``text`` back to their original values.

    All placeholders are matched in one left-to-right pass. A restored value
    that itself contains placeholder-shaped text (e.g. ``"[PII_EMAIL_2]"``)
    is therefore emitted verbatim, rather than being expanded again by a
    later key as sequential ``str.replace`` calls would do. Longer keys are
    tried first, so no placeholder can shadow another that starts with it.
    Empty keys are ignored, because replacing ``""`` would splice the value
    between every character.

    Args:
        text: an LLM reply (or any string) that may contain vault placeholders.
        vault: placeholder -> original substring, as produced by ``mask``.

    Returns:
        ``text`` with every occurrence of every vault key replaced by its value.
    """
    import re

    keys = sorted((k for k in vault if k), key=len, reverse=True)
    if not keys:
        return text
    pattern = re.compile("|".join(map(re.escape, keys)))
    # Callable replacement: original values are inserted literally, with no
    # backslash/group-reference processing.
    return pattern.sub(lambda m: vault[m.group(0)], text)
def sanitize(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True) -> dict:
    """Run the full DETECT -> MASK pipeline.

    Layer 1 (``detect_all``, which receives the ``pii``/``secrets``
    switches) always runs. Layer 2 (``model_ner.detect_entities``) runs only
    when ``use_model`` is true. Any failure while importing or calling it is
    reported as ``"unavailable"``, and masking continues with the Layer 1
    spans. Model spans are filtered by placeholder category using the same
    ``pii``/``secrets`` switches, so disabling a category disables it for
    both layers.

    Args:
        text: the raw input to sanitize.
        pii: mask PII-category spans.
        secrets: mask SECRET-category spans.
        use_model: also consult the Layer 2 NER model.

    Returns:
        A dict with these keys:

        - ``sanitized_text``: the masked text, also exposed under ``sanitized``.
        - ``vault``: placeholder -> raw value. It holds sensitive data and
          must never be logged.
        - ``findings``: raw-value-free records.
        - ``blocked``: the number of findings.
        - ``model_status``: ``"off"`` (model not requested), ``"on"`` (the
          model returned at least one span) or ``"unavailable"`` (the model
          returned nothing or failed).
    """
    # Copy so the list returned by detect_all is never mutated in place.
    spans = list(detect_all(text, pii=pii, secrets=secrets))
    model_status = "off"
    if use_model:
        try:
            from model_ner import detect_entities
            model_spans = list(detect_entities(text))
        except Exception as exc:
            # Layer 2 is optional by design, so degrade to Layer 1 only. Log
            # the exception *type* only: its message may quote the raw input.
            import logging
            logging.getLogger(__name__).debug(
                "Layer 2 NER unavailable (%s); using Layer 1 only",
                type(exc).__name__,
            )
            model_status = "unavailable"
        else:
            model_status = "on" if model_spans else "unavailable"
            enabled = {cat for cat, on in (("PII", pii), ("SECRET", secrets)) if on}
            spans.extend(s for s in model_spans if _category(s["type"]) in enabled)
    sanitized_text, vault, findings = mask(text, spans)
    return {
        "sanitized_text": sanitized_text,
        "sanitized": sanitized_text,
        "vault": vault,
        "findings": findings,
        "blocked": len(findings),
        "model_status": model_status,
    }
def call_llm_stub(sanitized_text: str) -> str:
    """Offline stand-in for an LLM call.

    Wraps the already-sanitized prompt in a fixed templated reply, so the
    round-trip demo needs no API key, network access or secrets.

    Args:
        sanitized_text: the masked prompt (interpolated via an f-string).

    Returns:
        The templated reply containing ``sanitized_text``.
    """
    opening = "Thanks for the message. Here's a summary of what I received "
    closing = "(This is a stub response — no data left your machine.)"
    return f"{opening}(sanitized):\n\n{sanitized_text}\n\n{closing}"
def round_trip(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True, llm_fn=None) -> dict:
    """Mask ``text``, send the masked prompt to an LLM callable, and un-mask its reply.

    Args:
        text, pii, secrets, use_model: forwarded unchanged to ``sanitize``.
        llm_fn: callable that receives the sanitized prompt string and
            returns a reply. If it is falsy (the default ``None``),
            ``call_llm_stub`` is used instead.

    Returns:
        The ``sanitize`` result dict, extended with ``llm_response`` (the reply
        as returned by the callable) and ``restored_response`` (that reply
        with the vault placeholders swapped back). The dict still carries the
        vault, i.e. raw values.
    """
    result = sanitize(text, pii=pii, secrets=secrets, use_model=use_model)
    responder = llm_fn if llm_fn else call_llm_stub
    reply = responder(result["sanitized_text"])
    result.update(
        llm_response=reply,
        restored_response=restore(reply, result["vault"]),
    )
    return result
def audit_log(result: dict) -> dict:
    """Build a JSON-safe audit record from a ``sanitize``/``round_trip`` result.

    Only the ``blocked`` count and, for each finding, its type, placeholder,
    source and confidence are copied. The vault and any raw values are never
    included.

    Args:
        result: a dict with at least ``blocked`` and ``findings`` keys.

    Returns:
        ``{"blocked": ..., "leaked": 0, "findings": [...]}``.
    """
    safe_fields = ("type", "masked_value", "source", "confidence")
    blocked = result["blocked"]
    redacted = [
        {field: finding[field] for field in safe_fields}
        for finding in result["findings"]
    ]
    # NOTE(review): "leaked" is a constant, not computed from the result --
    # confirm this is the intended meaning.
    return {"blocked": blocked, "leaked": 0, "findings": redacted}
|