File size: 2,661 Bytes
e431b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""model_ner.py — Layer 2: fine-tuned token classifier (PERSON/ADDRESS/ORG).

Loads the fine-tuned NER adapter from the HF Hub (pushed by
train/train_lora.py) if torch is usable and the model is reachable. On any
load/inference failure — no torch, no adapter pushed yet, no network — this
module degrades to returning [] and the firewall keeps working on Layer 1
alone (build spec §2/§6).

Availability is checked with `transformers.utils.is_torch_available()`,
never a bare `import torch`.
"""
from __future__ import annotations

import os

# Hub repo id of the fine-tuned adapter published by train/train_lora.py.
# Point PRIVACYSHIELD_NER_MODEL at another checkpoint to override it.
MODEL_ID = os.environ.get(
    "PRIVACYSHIELD_NER_MODEL",
    "perceptron01/privacyshield-ner",
)

# Entity groups Layer 2 reports; any other label from the model is dropped.
_NER_TYPES = {"ORG", "ADDRESS", "PERSON"}

# Lazy-load cache: the loaded pipeline, or the reason loading failed.
_pipeline = None
_unavailable_reason: str | None = None


def _load():
    """Load the Layer 2 pipeline once and cache the outcome.

    Returns the cached ``transformers`` token-classification pipeline, or None
    if loading failed. Only the first call attempts a load; later calls return
    the cached result. That includes a cached failure: its reason is stored in
    ``_unavailable_reason`` and the load is not retried.
    """
    global _pipeline, _unavailable_reason
    # A non-None reason marks a previous failed attempt, so both outcomes
    # short-circuit here.
    if _pipeline is not None or _unavailable_reason is not None:
        return _pipeline
    try:
        # Probe torch through transformers instead of a bare `import torch`.
        from transformers.utils import is_torch_available
        if not is_torch_available():
            raise RuntimeError("torch is not installed/usable (need torch>=2.4)")

        from transformers import pipeline
        # NOTE(review): no `stride` is passed, so inputs longer than the
        # model's max length may be truncated by the pipeline. Confirm that
        # long texts are chunked upstream.
        _pipeline = pipeline(
            "token-classification",
            model=MODEL_ID,
            aggregation_strategy="simple",
        )
    except Exception as exc:  # noqa: BLE001 — any failure -> graceful degradation
        # str(exc) is "" for message-less exceptions (e.g. OSError()). An empty
        # reason is falsy and would read as "model available" to callers of
        # unavailable_reason(), so fall back to the exception type name.
        _unavailable_reason = str(exc) or type(exc).__name__
        _pipeline = None
    return _pipeline


def is_available() -> bool:
    """Report whether the fine-tuned Layer 2 model is loaded and usable.

    The first call performs the (cached) load attempt.
    """
    loaded = _load()
    return loaded is not None


def unavailable_reason() -> str | None:
    """Explain why Layer 2 is disabled.

    Triggers the lazy load if it has not happened yet, then returns the
    recorded failure message, or None when the model loaded successfully.
    """
    _load()  # make sure a load attempt has populated the cached state
    reason = _unavailable_reason
    return reason


def detect_entities(text: str) -> list[dict]:
    """Run Layer 2 NER over *text* and return PERSON/ADDRESS/ORG spans.

    Each span is a dict with these keys:

    - ``start`` / ``end``: offsets as reported by the pipeline.
    - ``type``: one of ``_NER_TYPES``.
    - ``value``: the matched text.
    - ``source``: always ``"model"``.
    - ``confidence``: the pipeline score rounded to 2 decimals.

    Returns [] if the model can't load, *text* is empty or whitespace-only,
    or inference raises. Layer 2 is strictly best-effort.
    """
    pipe = _load()
    if pipe is None:
        return []
    # Nothing to tag, so skip a pointless forward pass.
    if not text or not text.strip():
        return []
    try:
        out = pipe(text)
    except Exception:  # noqa: BLE001 — inference failure -> graceful degradation
        return []

    spans = []
    for ent in out:
        label = str(ent.get("entity_group", "")).upper()
        if label not in _NER_TYPES:
            continue
        start, end = ent.get("start"), ent.get("end")
        if start is not None and end is not None:
            # `word` is the tokenizer's *decoded* text. It can be lower-cased
            # or have spaces added/removed around sub-words and punctuation.
            # Slice the input instead so `value` matches start/end exactly.
            value = text[int(start):int(end)]
        else:
            # Offsets are unavailable (e.g. slow tokenizer), so the decoded
            # word is the best we have.
            value = ent["word"]
        spans.append({
            "start": start,
            "end": end,
            "type": label,
            "value": value,
            "source": "model",
            "confidence": round(float(ent["score"]), 2),
        })
    return spans


# Alias matching the integration doc's naming (`model_ner.detect_ner(text)`).
# It is bound to the same function object as detect_entities, not a wrapper.
detect_ner = detect_entities