"""model_ner.py — Layer 2: fine-tuned token classifier (PERSON/ADDRESS/ORG). Loads the fine-tuned NER adapter from the HF Hub (pushed by train/train_lora.py) if torch is usable and the model is reachable. On any load/inference failure — no torch, no adapter pushed yet, no network — this module degrades to returning [] and the firewall keeps working on Layer 1 alone (build spec §2/§6). Availability is checked with `transformers.utils.is_torch_available()`, never a bare `import torch`. """ from __future__ import annotations import os # The fine-tuned adapter pushed by train/train_lora.py, or override via the # PRIVACYSHIELD_NER_MODEL env var. MODEL_ID = os.environ.get("PRIVACYSHIELD_NER_MODEL", "perceptron01/privacyshield-ner") _NER_TYPES = {"PERSON", "ADDRESS", "ORG"} _pipeline = None _unavailable_reason: str | None = None def _load(): global _pipeline, _unavailable_reason if _pipeline is not None or _unavailable_reason is not None: return _pipeline try: from transformers.utils import is_torch_available if not is_torch_available(): raise RuntimeError("torch is not installed/usable (need torch>=2.4)") from transformers import pipeline _pipeline = pipeline( "token-classification", model=MODEL_ID, aggregation_strategy="simple", ) except Exception as exc: # noqa: BLE001 — any failure -> graceful degradation _unavailable_reason = str(exc) _pipeline = None return _pipeline def is_available() -> bool: """True if the fine-tuned model loaded successfully.""" return _load() is not None def unavailable_reason() -> str | None: """Human-readable reason the model isn't available, or None if it is.""" _load() return _unavailable_reason def detect_entities(text: str) -> list[dict]: """Run Layer 2 NER. Returns [] if the model can't load or run.""" pipe = _load() if pipe is None: return [] try: out = pipe(text) except Exception: # noqa: BLE001 — inference failure -> graceful degradation return [] spans = [] for ent in out: label = str(ent.get("entity_group", "")).upper() if label not in _NER_TYPES: continue spans.append({ "start": ent["start"], "end": ent["end"], "type": label, "value": ent["word"], "source": "model", "confidence": round(float(ent["score"]), 2), }) return spans # Alias matching the integration doc's naming (`model_ner.detect_ner(text)`). detect_ner = detect_entities