PrivacyShield / model_ner.py
perceptron01's picture
Upload 6 files
e431b8d verified
Raw
History Blame Contribute Delete
2.66 kB
"""model_ner.py — Layer 2: fine-tuned token classifier (PERSON/ADDRESS/ORG).
Loads the fine-tuned NER adapter from the HF Hub (pushed by
train/train_lora.py) if torch is usable and the model is reachable. On any
load/inference failure — no torch, no adapter pushed yet, no network — this
module degrades to returning [] and the firewall keeps working on Layer 1
alone (build spec §2/§6).
Availability is checked with `transformers.utils.is_torch_available()`,
never a bare `import torch`.
"""
from __future__ import annotations
import os
# Identifier of the token-classification model to load. Defaults to the
# adapter pushed by train/train_lora.py; set the PRIVACYSHIELD_NER_MODEL
# environment variable to point at a different one.
MODEL_ID = os.getenv("PRIVACYSHIELD_NER_MODEL", "perceptron01/privacyshield-ner")

# Entity labels this layer reports; any other label from the model is dropped.
_NER_TYPES = {"PERSON", "ADDRESS", "ORG"}

# Lazy-load state shared by the functions below. Both start as None ("not
# attempted yet"); after the first attempt exactly one of them is set:
# the loaded pipeline on success, or the failure message otherwise.
_unavailable_reason: str | None = None
_pipeline = None
def _load():
    """Load the token-classification pipeline once and cache the outcome.

    The first call tries to build a ``transformers`` pipeline for
    ``MODEL_ID``. Both outcomes are cached in module globals, so later calls
    never retry: on success ``_pipeline`` holds the pipeline, on failure
    ``_unavailable_reason`` holds a non-empty description of what went wrong.

    Returns:
        The loaded pipeline, or ``None`` if it could not be loaded.
    """
    global _pipeline, _unavailable_reason
    # A previous attempt (successful or not) is final for this process.
    if _pipeline is not None or _unavailable_reason is not None:
        return _pipeline
    try:
        # Imported lazily so that merely importing this module never requires
        # transformers/torch to be installed.
        from transformers.utils import is_torch_available

        if not is_torch_available():
            raise RuntimeError("torch is not installed/usable (need torch>=2.4)")
        from transformers import pipeline

        _pipeline = pipeline(
            "token-classification",
            model=MODEL_ID,
            aggregation_strategy="simple",
        )
    except Exception as exc:  # noqa: BLE001 — any failure -> graceful degradation
        _pipeline = None
        # Exceptions raised without a message stringify to "", which would
        # read as "no reason" to callers; fall back to the exception type.
        _unavailable_reason = str(exc) or type(exc).__name__
    return _pipeline
def is_available() -> bool:
    """Report whether the fine-tuned model is loaded (loading it on first use)."""
    loaded = _load()
    return loaded is not None
def unavailable_reason() -> str | None:
    """Explain why the model can't be used; ``None`` when it loaded fine."""
    # Make sure a load has been attempted so the cached reason is populated.
    _load()
    reason = _unavailable_reason
    return reason
def detect_entities(text: str) -> list[dict]:
    """Run Layer 2 NER over *text* and return the recognised entity spans.

    Each returned dict has the keys ``start`` and ``end`` (offsets as
    reported by the pipeline), ``type`` (one of ``_NER_TYPES``), ``value``
    (the pipeline's ``word`` for the entity), ``source`` (always
    ``"model"``) and ``confidence`` (the score rounded to two decimals).

    Returns [] if the model can't load or run. Individual entities that are
    malformed or carry no offsets are skipped instead of raising, so one bad
    prediction never takes the whole layer down.
    """
    pipe = _load()
    if pipe is None:
        return []
    try:
        out = pipe(text)
    except Exception:  # noqa: BLE001 — inference failure -> graceful degradation
        return []

    spans = []
    for ent in out:
        try:
            label = str(ent.get("entity_group", "")).upper()
            if label not in _NER_TYPES:
                continue
            start, end = ent["start"], ent["end"]
            if start is None or end is None:
                # The pipeline reports no offsets when the tokenizer cannot
                # provide them; a span without offsets can't be located.
                continue
            spans.append({
                "start": start,
                "end": end,
                "type": label,
                "value": ent["word"],
                "source": "model",
                "confidence": round(float(ent["score"]), 2),
            })
        except (AttributeError, KeyError, TypeError, ValueError):
            # Unexpected entity shape (missing key, non-numeric score, ...):
            # drop this entity and keep the rest.
            continue
    return spans
# Alias for `detect_entities`: the integration doc refers to this entry point
# as `model_ner.detect_ner(text)`, so the same function is exposed under
# that name as well.
detect_ner = detect_entities