File size: 3,312 Bytes
2bf9c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
"""Minimal runnable example for FrameByFrame/korean-pii-e5-base.

    pip install "transformers>=4.40" torch safetensors
    python usage.py
"""
import os, re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Checkpoint to load; the MODEL_ID environment variable overrides the default.
MODEL_ID = os.environ.get("MODEL_ID", "FrameByFrame/korean-pii-e5-base")
# The tokenizer and model are created once, at import time, and are used as
# module-level globals by extract_pii().
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Weights are requested in bfloat16.
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model.eval()  # evaluation mode: this script only runs inference
if torch.cuda.is_available():
    # Move the model to the GPU when CUDA is available; extract_pii() sends
    # its inputs to model.device, so no other code needs to change.
    model.cuda()

# Suffixes that _normalize() removes from the end of person / handle / address
# spans (presumably Korean particles and honorifics, going by the name).
# Order matters: _normalize() scans this list front to back and strips the
# first entry the span ends with, so an entry placed earlier takes precedence
# over a later one.
# NOTE(review): several entries look mis-encoded (mojibake) in this copy of
# the file -- confirm the file is stored and read as UTF-8.
_TRAILING_JOSA = ["μ΄μ—μš”","이라고","μž…λ‹ˆλ‹€","이야","μ΄λž‘","ν•œν…Œ","μ—κ²Œ","으둜","이가","μ΄λŠ”",
                  "μ—μ„œ","이고","μ˜ˆμš”","씨","λ‹˜","이","κ°€","은","λŠ”","을","λ₯Ό","μ•Ό","μ•„","에","의","λž‘","께","κ³ "]
# Greedy match from the start of a span up to its LAST "일" or ASCII digit;
# re.S lets "." cross newlines. _normalize() uses the match end to cut off
# whatever follows the date.
_DATE_END = re.compile(r".*(?:일|[0-9])", re.S)


def _normalize(text, label, s, e):
    """Tighten the character span ``text[s:e]`` around the entity itself.

    Steps:
      1. Trim spaces, periods, commas, tabs and newlines from both ends.
      2. For ``private_date`` spans, cut the span after its last match of
         ``_DATE_END`` (the last "일" or ASCII digit).
      3. For ``private_person`` / ``personal_handle`` / ``private_address``
         spans, remove up to two trailing suffixes listed in
         ``_TRAILING_JOSA``. A suffix is removed only if at least two
         characters would remain before it.

    Args:
        text: The full input string the offsets refer to.
        label: Entity category of the span (without any B-/I- prefix).
        s: Start offset (inclusive) into ``text``.
        e: End offset (exclusive) into ``text``.

    Returns:
        A ``(start, end)`` tuple with ``s <= start <= end <= e``.
    """
    strip_chars = " .,\t\n"
    while s < e and text[s] in strip_chars:
        s += 1
    while e > s and text[e - 1] in strip_chars:
        e -= 1

    if label == "private_date":
        m = _DATE_END.match(text[s:e])
        # A successful match always consumes at least one character, so the
        # match object alone is a sufficient guard.
        if m:
            e = s + m.end()
    elif label in ("private_person", "personal_handle", "private_address"):
        for _ in range(2):
            seg = text[s:e]
            for j in _TRAILING_JOSA:
                # First matching suffix wins; keep at least two characters.
                if seg.endswith(j) and (e - s) - len(j) >= 2:
                    e -= len(j)
                    break
            else:
                # No suffix matched in this round: nothing more to strip.
                break
            # Removing a suffix can expose the separator that preceded it
            # (e.g. "name<space>suffix"); trim again so the span never ends
            # in whitespace or punctuation.
            while e > s and text[e - 1] in strip_chars:
                e -= 1
    return s, e


def extract_pii(text, max_length=256):
    """Run the token classifier on *text* and return the detected PII spans.

    The input is tokenized with truncation at ``max_length`` tokens, so
    anything past that limit is not seen by the model and cannot be reported.

    Consecutive tokens are merged into one span until a token is labelled
    "O", has an empty character offset, starts a new entity (prefix "B" or
    "S"), or switches to a different category. Each merged span is then
    passed through ``_normalize``; spans that end up blank are dropped.

    Returns:
        A list of dicts with keys ``label``, ``start``, ``end`` and ``text``,
        where ``start``/``end`` are character offsets into *text*.
    """
    encoded = tokenizer(text, truncation=True, max_length=max_length,
                        return_offsets_mapping=True, return_tensors="pt")
    # The offset mapping is not a model input; remove it before the forward pass.
    offsets = encoded.pop("offset_mapping")[0].tolist()
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    label_ids = logits.argmax(-1)[0].tolist()
    id2label = model.config.id2label

    raw_spans = []
    current = None  # [category, start, end] of the span being assembled
    for (tok_start, tok_end), label_id in zip(offsets, label_ids):
        label = id2label[int(label_id)]
        # Zero-width tokens and "O" tokens both terminate any open span.
        if tok_start == tok_end or label == "O":
            if current:
                raw_spans.append(current)
                current = None
            continue
        prefix, category = label.split("-", 1)
        extends_current = (
            prefix not in ("B", "S") and current and current[0] == category
        )
        if extends_current:
            current[2] = tok_end
        else:
            if current:
                raw_spans.append(current)
            current = [category, tok_start, tok_end]
    if current:
        raw_spans.append(current)

    results = []
    for category, start, end in raw_spans:
        start, end = _normalize(text, category, start, end)
        snippet = text[start:end]
        if snippet.strip():
            results.append({"label": category, "start": start, "end": end, "text": snippet})
    return results


def redact(text, max_length=256):
    """Return *text* with each detected PII span replaced by ``[LABEL]``.

    Args:
        text: The string to redact.
        max_length: Token limit forwarded to ``extract_pii``. Text beyond this
            limit is truncated before classification and is therefore left
            unredacted; raise it for longer inputs. The default matches
            ``extract_pii``'s own default, so existing callers are unaffected.

    Returns:
        The redacted string; each placeholder is the span's label upper-cased
        and wrapped in square brackets.
    """
    spans = sorted(
        extract_pii(text, max_length=max_length),
        key=lambda s: s["start"],
        reverse=True,
    )
    # Replace from the rightmost span backwards so the offsets of spans that
    # have not been replaced yet remain valid.
    for span in spans:
        placeholder = f"[{span['label'].upper()}]"
        text = text[:span["start"]] + placeholder + text[span["end"]:]
    return text


if __name__ == "__main__":
    # Demo: print each sample, the spans found in it, and its redacted form.
    samples = [
        "κΉ€λ―Όμˆ˜λ‹˜μ˜ λ²ˆν˜ΈλŠ” 010-1234-5678μž…λ‹ˆλ‹€.",
        "κ³„μ’Œ 110-234-567890으둜 μž…κΈˆν•˜κ³  minsu@example.com으둜 μ•Œλ €μ£Όμ„Έμš”.",
        "μ΄μˆ˜μ§„ κ³ κ°λ‹˜ 생년월일은 1985λ…„ 3μ›” 12μΌμž…λ‹ˆλ‹€.",
    ]
    for sample in samples:
        print(sample)
        for hit in extract_pii(sample):
            print(f"   {hit['label']:16} {hit['text']!r}")
        print("   REDACT:", redact(sample))
        print()