"""firewall.py — detect -> mask -> (call) -> restore. Orchestrates Layer 1 (detectors.py, always on) and Layer 2 (model_ner.py, optional/graceful) per ARCHITECTURE in the build spec. Implements the merge rule (union spans, resolve overlaps by confidence/length, never overlapping masks), a reversible in-memory vault, and a stub LLM round-trip. """ from __future__ import annotations from detectors import detect_all # Types produced by the secret-side detectors / NER -> placeholder category _SECRET_TYPES = { "AWS_KEY", "GITHUB_TOKEN", "SLACK_TOKEN", "JWT", "PRIVATE_KEY", "SECRET", } def _category(type_name: str) -> str: return "SECRET" if type_name in _SECRET_TYPES else "PII" def merge_spans(spans: list[dict]) -> list[dict]: """Union spans; on overlap keep the higher-confidence / longer span. Returns a list of non-overlapping spans sorted by start position. """ # Prefer higher confidence, then longer spans, then earlier start — # greedily accept spans in that priority order, skipping anything that # would overlap an already-accepted span. ordered = sorted( spans, key=lambda s: (-s["confidence"], -(s["end"] - s["start"]), s["start"]), ) accepted: list[dict] = [] for span in ordered: if not any(span["start"] < a["end"] and a["start"] < span["end"] for a in accepted): accepted.append(span) return sorted(accepted, key=lambda s: s["start"]) def mask(text: str, spans: list[dict]) -> tuple[str, dict, list[dict]]: """Replace each span with a typed, reversible placeholder. Returns (masked_text, vault, findings) where vault maps placeholder -> original substring, and findings carry no raw values. """ merged = merge_spans(spans) vault: dict[str, str] = {} findings: list[dict] = [] counters: dict[tuple[str, str], int] = {} out: list[str] = [] last = 0 for span in merged: cat = _category(span["type"]) key = (cat, span["type"]) counters[key] = counters.get(key, 0) + 1 placeholder = f"[{cat}_{span['type']}_{counters[key]}]" vault[placeholder] = text[span["start"]:span["end"]] out.append(text[last:span["start"]]) out.append(placeholder) last = span["end"] findings.append({ "type": span["type"], "masked_value": placeholder, "source": span["source"], "confidence": span["confidence"], }) out.append(text[last:]) return "".join(out), vault, findings def restore(text: str, vault: dict[str, str]) -> str: """Swap placeholders back to their original values.""" for placeholder, original in vault.items(): text = text.replace(placeholder, original) return text def sanitize(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True) -> dict: """Run the full DETECT -> MASK pipeline. Returns a dict with sanitized_text, vault, findings, blocked count, and model_status ("on" | "off" | "unavailable") describing whether Layer 2 (the fine-tuned NER model) contributed. """ spans = detect_all(text, pii=pii, secrets=secrets) model_status = "off" if use_model: try: from model_ner import detect_entities model_spans = detect_entities(text) model_status = "on" if model_spans else "unavailable" spans += model_spans except Exception: model_status = "unavailable" sanitized_text, vault, findings = mask(text, spans) return { "sanitized_text": sanitized_text, "sanitized": sanitized_text, "vault": vault, "findings": findings, "blocked": len(findings), "model_status": model_status, } def call_llm_stub(sanitized_text: str) -> str: """A no-API-key-needed stand-in for an LLM call: echoes the sanitized prompt back inside a templated reply, so the round-trip demo never depends on network access or secrets.""" return ( "Thanks for the message. Here's a summary of what I received " f"(sanitized):\n\n{sanitized_text}\n\n" "(This is a stub response — no data left your machine.)" ) def round_trip(text: str, pii: bool = True, secrets: bool = True, use_model: bool = True, llm_fn=None) -> dict: """sanitize -> (call llm_fn or stub) -> restore originals in the reply.""" result = sanitize(text, pii=pii, secrets=secrets, use_model=use_model) llm_fn = llm_fn or call_llm_stub llm_response = llm_fn(result["sanitized_text"]) result["llm_response"] = llm_response result["restored_response"] = restore(llm_response, result["vault"]) return result def audit_log(result: dict) -> dict: """JSON-safe audit log: placeholders + types + sources only, never the vault or any raw value.""" return { "blocked": result["blocked"], "leaked": 0, "findings": [ { "type": f["type"], "masked_value": f["masked_value"], "source": f["source"], "confidence": f["confidence"], } for f in result["findings"] ], }