# accujuris-api/app/services/ner_redaction_service.py
#
# PII redaction via NER: IndicNER (ai4bharat/IndicNER) as the first pass,
# English NER (dslim/bert-base-NER) as a second pass over the pass-1 output.
import logging
import re
import threading
from typing import Any, Optional
from app.config import settings
logger = logging.getLogger(__name__)
_MALAYALAM_BLOCK = re.compile(r"[\u0d00-\u0d7f]")
# ── Indic NER pipeline (ai4bharat/IndicNER) ──────────────────────────
_NER_PIPELINE: Any = None
_NER_MODEL_NAME: Optional[str] = None
_NER_INIT_ERROR: Optional[str] = None
_NER_LOCK = threading.Lock()
# ── English NER pipeline (dslim/bert-base-NER) ──────────────────────
_EN_NER_PIPELINE: Any = None
_EN_NER_MODEL_NAME: Optional[str] = None
_EN_NER_INIT_ERROR: Optional[str] = None
_EN_NER_LOCK = threading.Lock()
def _canonical_entity_label(raw_label: str) -> Optional[str]:
if not raw_label:
return None
normalized = raw_label.upper().strip()
for prefix in ("B-", "I-", "S-", "E-"):
if normalized.startswith(prefix):
normalized = normalized[len(prefix):]
break
if normalized in {"PER", "PERSON"}:
return "PERSON"
if normalized in {"LOC", "LOCATION", "GPE", "PLACE"}:
return "LOCATION"
return None
def _target_labels() -> set[str]:
out: set[str] = set()
for token in (settings.PII_REDACTION_TARGET_LABELS or "PER,LOC").split(","):
label = _canonical_entity_label(token)
if label:
out.add(label)
if not out:
out = {"PERSON", "LOCATION"}
return out
def _model_candidates() -> list[str]:
candidates = []
if settings.PII_REDACTION_MODEL:
candidates.append(settings.PII_REDACTION_MODEL.strip())
if settings.PII_REDACTION_FALLBACK_MODEL:
candidates.append(settings.PII_REDACTION_FALLBACK_MODEL.strip())
seen = set()
deduped = []
for model_name in candidates:
if model_name and model_name not in seen:
seen.add(model_name)
deduped.append(model_name)
return deduped
# ── Indic pipeline loader ─────────────────────────────────────────────
def _get_or_load_pipeline():
global _NER_PIPELINE, _NER_MODEL_NAME, _NER_INIT_ERROR
if _NER_PIPELINE is not None or _NER_INIT_ERROR is not None:
return _NER_PIPELINE, _NER_MODEL_NAME, _NER_INIT_ERROR
with _NER_LOCK:
if _NER_PIPELINE is not None or _NER_INIT_ERROR is not None:
return _NER_PIPELINE, _NER_MODEL_NAME, _NER_INIT_ERROR
try:
from transformers import pipeline
except Exception as exc:
_NER_INIT_ERROR = f"transformers import failed: {exc}"
logger.warning("IndicNER unavailable: %s", _NER_INIT_ERROR)
return None, None, _NER_INIT_ERROR
model_errors: list[str] = []
for model_name in _model_candidates():
try:
pipe = pipeline(
task="token-classification",
model=model_name,
tokenizer=model_name,
aggregation_strategy="simple",
)
_NER_PIPELINE = pipe
_NER_MODEL_NAME = model_name
logger.info("Loaded IndicNER model: %s", model_name)
return _NER_PIPELINE, _NER_MODEL_NAME, None
except Exception as exc:
model_errors.append(f"{model_name}: {exc}")
logger.warning("Failed to load IndicNER model '%s': %s", model_name, exc)
_NER_INIT_ERROR = "; ".join(model_errors) if model_errors else "No IndicNER model configured"
return None, None, _NER_INIT_ERROR
# ── English pipeline loader ──────────────────────────────────────────
def _get_or_load_english_pipeline():
global _EN_NER_PIPELINE, _EN_NER_MODEL_NAME, _EN_NER_INIT_ERROR
if _EN_NER_PIPELINE is not None or _EN_NER_INIT_ERROR is not None:
return _EN_NER_PIPELINE, _EN_NER_MODEL_NAME, _EN_NER_INIT_ERROR
with _EN_NER_LOCK:
if _EN_NER_PIPELINE is not None or _EN_NER_INIT_ERROR is not None:
return _EN_NER_PIPELINE, _EN_NER_MODEL_NAME, _EN_NER_INIT_ERROR
model_name = (settings.PII_REDACTION_ENGLISH_MODEL or "").strip()
if not model_name:
_EN_NER_INIT_ERROR = "No English NER model configured"
return None, None, _EN_NER_INIT_ERROR
try:
from transformers import pipeline
except Exception as exc:
_EN_NER_INIT_ERROR = f"transformers import failed: {exc}"
logger.warning("English NER unavailable: %s", _EN_NER_INIT_ERROR)
return None, None, _EN_NER_INIT_ERROR
try:
pipe = pipeline(
task="token-classification",
model=model_name,
tokenizer=model_name,
aggregation_strategy="simple",
)
_EN_NER_PIPELINE = pipe
_EN_NER_MODEL_NAME = model_name
logger.info("Loaded English NER model: %s", model_name)
return _EN_NER_PIPELINE, _EN_NER_MODEL_NAME, None
except Exception as exc:
_EN_NER_INIT_ERROR = f"{model_name}: {exc}"
logger.warning("Failed to load English NER model '%s': %s", model_name, exc)
return None, None, _EN_NER_INIT_ERROR
def _chunk_text(text: str, max_chars: int = 900):
cursor = 0
text_len = len(text)
while cursor < text_len:
end = min(text_len, cursor + max_chars)
if end < text_len:
split_at = text.rfind("\n", cursor, end)
if split_at == -1:
split_at = text.rfind(" ", cursor, end)
if split_at > cursor + (max_chars // 3):
end = split_at + 1
chunk = text[cursor:end]
if chunk:
yield cursor, chunk
cursor = end
def _extract_spans_with_pipeline(text: str, pipe: Any) -> list[dict[str, Any]]:
"""Run a given NER pipeline on text and return entity spans."""
targets = _target_labels()
spans: list[dict[str, Any]] = []
for base_offset, chunk in _chunk_text(text):
if not chunk.strip():
continue
try:
predictions = pipe(chunk)
except Exception as exc:
logger.warning("NER inference failed for a chunk: %s", exc)
continue
for pred in predictions:
label = _canonical_entity_label(
str(pred.get("entity_group") or pred.get("entity") or "")
)
if label not in targets:
continue
start = pred.get("start")
end = pred.get("end")
if start is None or end is None:
continue
try:
start_i = int(start) + base_offset
end_i = int(end) + base_offset
except Exception:
continue
if start_i < 0 or end_i <= start_i or end_i > len(text):
continue
spans.append(
{
"start": start_i,
"end": end_i,
"label": label,
"score": float(pred.get("score", 0.0) or 0.0),
}
)
return spans
def _extract_spans(text: str) -> tuple[list[dict[str, Any]], Optional[str], Optional[str]]:
"""Extract entity spans using the Indic NER pipeline."""
pipe, model_name, init_error = _get_or_load_pipeline()
if pipe is None:
return [], None, init_error
spans = _extract_spans_with_pipeline(text, pipe)
return spans, model_name, None
def _extract_spans_english(text: str) -> tuple[list[dict[str, Any]], Optional[str], Optional[str]]:
"""Extract entity spans using the English NER pipeline."""
pipe, model_name, init_error = _get_or_load_english_pipeline()
if pipe is None:
return [], None, init_error
spans = _extract_spans_with_pipeline(text, pipe)
return spans, model_name, None
def _resolve_overlaps(spans: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not spans:
return []
sorted_spans = sorted(
spans,
key=lambda item: (item["start"], -(item["end"] - item["start"]), -item["score"]),
)
resolved: list[dict[str, Any]] = []
for span in sorted_spans:
if not resolved:
resolved.append(span)
continue
prev = resolved[-1]
if span["start"] < prev["end"]:
prev_size = prev["end"] - prev["start"]
span_size = span["end"] - span["start"]
if (span_size, span["score"]) > (prev_size, prev["score"]):
resolved[-1] = span
continue
resolved.append(span)
return resolved
def _replacement_for_label(label: str) -> str:
if label == "PERSON":
return settings.PII_REDACTION_PERSON_TOKEN or "X"
if label == "LOCATION":
return settings.PII_REDACTION_LOCATION_TOKEN or "Y"
return settings.PII_REDACTION_OTHER_TOKEN or "Z"
def _apply_redaction(text: str, spans: list[dict[str, Any]], counts: dict[str, int]) -> str:
"""Replace entity spans with redaction tokens and update counts."""
resolved_spans = _resolve_overlaps(spans)
parts = []
cursor = 0
for span in resolved_spans:
parts.append(text[cursor:span["start"]])
parts.append(_replacement_for_label(span["label"]))
cursor = span["end"]
if span["label"] in counts:
counts[span["label"]] += 1
parts.append(text[cursor:])
return "".join(parts)
def redact_sensitive_entities(text: str) -> tuple[str, dict[str, Any]]:
if not isinstance(text, str):
return "", {"enabled": False, "applied": False, "reason": "non_string_input", "counts": {"PERSON": 0, "LOCATION": 0}}
counts = {"PERSON": 0, "LOCATION": 0}
if not settings.PII_REDACTION_ENABLED:
return text, {"enabled": False, "applied": False, "reason": "redaction_disabled", "counts": counts}
if not text.strip():
return text, {"enabled": True, "applied": False, "reason": "empty_text", "counts": counts}
models_used = []
# ── Pass 1: Indic NER ─────────────────────────────────────────────
indic_spans, indic_model, indic_error = _extract_spans(text)
if indic_error:
logger.warning("Indic NER init error: %s", indic_error)
if indic_spans:
text = _apply_redaction(text, indic_spans, counts)
models_used.append(indic_model)
# ── Pass 2: English NER ───────────────────────────────────────────
en_spans, en_model, en_error = _extract_spans_english(text)
if en_error:
logger.warning("English NER init error: %s", en_error)
if en_spans:
text = _apply_redaction(text, en_spans, counts)
models_used.append(en_model)
applied = counts["PERSON"] > 0 or counts["LOCATION"] > 0
model_str = " + ".join(models_used) if models_used else (indic_model or en_model or "unknown")
if not applied and indic_error and en_error:
_save_ner_log(text, model_str, applied=False)
return text, {
"enabled": True,
"applied": False,
"model": model_str,
"reason": f"indic: {indic_error}; english: {en_error}",
"counts": counts,
}
if not applied:
_save_ner_log(text, model_str, applied=False)
return text, {
"enabled": True,
"applied": False,
"model": model_str,
"reason": "no_entities_found",
"counts": counts,
}
_save_ner_log(text, model_str, applied=True)
return text, {
"enabled": True,
"applied": True,
"model": model_str,
"reason": None,
"counts": counts,
}
def _save_ner_log(text: str, model_name: Optional[str], applied: bool) -> None:
"""Helper to save NER redacted text to log file"""
try:
from pathlib import Path
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
log_file = log_dir / "ocr_output.log"
status = "Redacted" if applied else "Unredacted (No entities/Error)"
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== NER Output ({model_name}) - {status} ===\n")
f.write(text)
f.write("\n===============================\n\n")
except Exception as e:
logger.error(f"Failed to write NER output to log file: {e}")