Redac / redac /detect.py
barath19's picture
fix: graceful regex-only fallback when GLiNER unavailable
fa40826
Raw
History Blame Contribute Delete
4.84 kB
"""PII detection.
Two complementary detectors, merged into one span list:
1. GLiNER (zero-shot NER, PII-tuned) for fuzzy entities: people, orgs,
addresses, dates of birth, etc.
2. Regex recognizers for high-precision structured identifiers that NER
models get wrong: emails, phones, IBANs, credit cards, IPs, etc.
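Detection runs locally (CPU is fine for GLiNER) and no text is sent to external services. Note that the GLiNER checkpoint is loaded with `from_pretrained`, which may download it from the Hugging Face Hub on first use unless it is already cached.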
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import List
# Labels GLiNER is asked to find. Tuned for documents / ID cards.
DEFAULT_LABELS = [
"person",
"organization",
"address",
"date of birth",
"passport number",
"driver license number",
"national id number",
"bank account number",
"phone number",
"email address",
]
# GLiNER PII-tuned checkpoint.
_GLINER_MODEL = "urchade/gliner_multi_pii-v1"
@dataclass
class Entity:
    """A single PII span found in the analysed text.

    The regex path builds these with ``end = start + len(text)``, so the span
    is ``text_being_analysed[start:end]``. GLiNER-produced entities take their
    offsets straight from the model's predictions.
    """

    start: int  # character offset where the span begins
    end: int  # character offset just past the span
    text: str  # the matched substring
    label: str  # PII category, e.g. "email address" or "person"
    score: float  # confidence; regex hits always use 1.0
    source: str  # which detector produced it: "gliner" or "regex"
# --- Regex recognizers (high precision, deterministic) -----------------------
# Ordered most-specific first; on an equal-score overlap the earlier
# recognizer wins, so the greedy phone pattern is deliberately last.
_REGEX_RECOGNIZERS = [
("email address", re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")),
("iban", re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{11,30}\b")),
("national id number", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")), # US SSN shape
("credit card number", re.compile(r"\b(?:\d[ -]?){13,16}\b")),
("ip address", re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")),
("phone number", re.compile(r"(?<!\w)(?:\+?\d{1,3}[\s.-]?)?(?:\(?\d{2,4}\)?[\s.-]?){2,4}\d{2,4}(?!\w)")),
]
# Common date shapes the greedy phone pattern would otherwise swallow.
_DATE_RE = re.compile(
r"^\d{4}[./-]\d{1,2}[./-]\d{1,2}$|^\d{1,2}[./-]\d{1,2}[./-]\d{2,4}$"
)
def _regex_entities(text: str) -> List[Entity]:
    """Run every regex recognizer over `text` and return all accepted hits.

    Hits are returned in recognizer order, then match order. Each has
    ``score=1.0`` and ``source="regex"``. Overlaps between recognizers are
    NOT resolved here.

    Surrounding whitespace is trimmed from each match, and BOTH offsets are
    shifted accordingly, so ``text[e.start:e.end] == e.text`` always holds.
    """
    out: List[Entity] = []
    for label, pattern in _REGEX_RECOGNIZERS:
        for m in pattern.finditer(text):
            raw = m.group()
            span = raw.strip()
            # Too short to be a meaningful identifier.
            if len(span) < 4:
                continue
            # Phone recognizer is greedy: ignore dates and short digit runs
            # (serial numbers and the like).
            if label == "phone number" and (
                _DATE_RE.match(span) or sum(c.isdigit() for c in span) < 7
            ):
                continue
            # Account for any leading whitespace strip() removed, not just
            # the trailing part, so the span offsets stay exact.
            start = m.start() + (len(raw) - len(raw.lstrip()))
            out.append(
                Entity(
                    start=start,
                    end=start + len(span),
                    text=span,
                    label=label,
                    score=1.0,
                    source="regex",
                )
            )
    return out
# --- GLiNER ------------------------------------------------------------------
@lru_cache(maxsize=1)
def gliner_available() -> bool:
    """Report whether the optional ``gliner`` package can be imported.

    The result is cached for the life of the process. Any exception raised
    while importing is treated as "unavailable", so the app can fall back to
    regex-only detection (and say so) instead of crashing.
    """
    try:
        import gliner  # noqa: F401
    except Exception:
        return False
    else:
        return True
@lru_cache(maxsize=1)
def _load_gliner():
    """Build the GLiNER model for ``_GLINER_MODEL`` once and reuse it.

    Raises whatever importing ``gliner`` or ``GLiNER.from_pretrained`` raises.
    Callers are expected to check ``gliner_available()`` first.

    NOTE(review): ``from_pretrained`` may fetch the checkpoint from the
    Hugging Face Hub on first use unless it is already cached locally.
    """
    import gliner

    return gliner.GLiNER.from_pretrained(_GLINER_MODEL)
# GLiNER splits its input into word / punctuation tokens with this same
# pattern. It truncates anything past the checkpoint's maximum length, with
# only a warning. For a redactor, that means PII near the end of a long text
# would silently go undetected, so long inputs are fed to the model in
# overlapping token windows instead.
# NOTE(review): 384 tokens is the max_len of the stock GLiNER checkpoints;
# the 300-token window leaves headroom. Confirm against the config of
# _GLINER_MODEL.
_GLINER_TOKEN_RE = re.compile(r"\w+(?:[-_]\w+)*|\S")
_GLINER_WINDOW_TOKENS = 300
_GLINER_WINDOW_OVERLAP = 50


def _gliner_windows(text: str, max_tokens: int, overlap: int) -> List[tuple]:
    """Split `text` into (start, end) character windows of <= `max_tokens` tokens.

    Consecutive windows share `overlap` tokens, so an entity straddling a
    window boundary is still seen whole by one of them. Text that already
    fits returns the single window ``(0, len(text))``.
    """
    max_tokens = max(1, max_tokens)
    step = max(1, max_tokens - max(0, overlap))
    spans = [m.span() for m in _GLINER_TOKEN_RE.finditer(text)]
    if len(spans) <= max_tokens:
        return [(0, len(text))]
    windows = []
    for first in range(0, len(spans), step):
        last = min(first + max_tokens, len(spans)) - 1
        windows.append((spans[first][0], spans[last][1]))
        if last == len(spans) - 1:
            break
    return windows


def _gliner_entities(
    text: str,
    labels: List[str],
    threshold: float,
    max_tokens: int = _GLINER_WINDOW_TOKENS,
    overlap: int = _GLINER_WINDOW_OVERLAP,
) -> List[Entity]:
    """Run GLiNER over `text` (window by window) and return its predictions.

    Args:
        text: Text to analyse.
        labels: Entity labels GLiNER is asked to find.
        threshold: Minimum confidence, forwarded to ``predict_entities``.
        max_tokens: Maximum GLiNER tokens per window.
        overlap: Tokens shared by consecutive windows.

    Returns:
        Entities with offsets relative to the full `text`. A span reported
        identically by two overlapping windows is kept once, with the higher
        score. Partially overlapping hits are left for ``_resolve_overlaps``.
    """
    model = _load_gliner()
    best: dict = {}
    for w_start, w_end in _gliner_windows(text, max_tokens, overlap):
        preds = model.predict_entities(text[w_start:w_end], labels, threshold=threshold)
        for p in preds:
            # Prediction offsets are relative to the window; shift them back.
            ent = Entity(
                start=w_start + p["start"],
                end=w_start + p["end"],
                text=p["text"],
                label=p["label"],
                score=float(p.get("score", 0.0)),
                source="gliner",
            )
            key = (ent.start, ent.end, ent.label)
            if key not in best or ent.score > best[key].score:
                best[key] = ent
    return list(best.values())
# --- Merge -------------------------------------------------------------------
def _resolve_overlaps(entities: List[Entity]) -> List[Entity]:
"""Keep the highest-scoring entity when spans overlap; regex (score 1.0)
wins ties, which is what we want for structured identifiers."""
ordered = sorted(entities, key=lambda e: (-e.score, e.start))
kept: List[Entity] = []
for e in ordered:
if any(not (e.end <= k.start or e.start >= k.end) for k in kept):
continue
kept.append(e)
return sorted(kept, key=lambda e: e.start)
def detect_entities(
    text: str,
    labels: List[str] | None = None,
    threshold: float = 0.45,
    use_gliner: bool = True,
) -> List[Entity]:
    """Find PII in `text`; return non-overlapping spans sorted by offset.

    Args:
        text: Text to scan. Empty or whitespace-only input yields ``[]``.
        labels: GLiNER label set. Any falsy value (None or an empty list)
            means ``DEFAULT_LABELS``.
        threshold: Minimum GLiNER confidence.
        use_gliner: Pass False for regex-only detection. GLiNER is also
            skipped when ``gliner_available()`` reports it cannot be imported.

    Returns:
        Regex and (optionally) GLiNER entities, merged by
        ``_resolve_overlaps``.
    """
    if not (text and text.strip()):
        return []
    candidates = _regex_entities(text)
    run_model = use_gliner and gliner_available()
    if run_model:
        candidates += _gliner_entities(text, labels or DEFAULT_LABELS, threshold)
    return _resolve_overlaps(candidates)