File size: 3,312 Bytes
2bf9c60 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | #!/usr/bin/env python3
"""Minimal runnable example for FrameByFrame/korean-pii-e5-base.
pip install "transformers>=4.40" torch safetensors
python usage.py
"""
import os, re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Hugging Face model repo to load; can be overridden via the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "FrameByFrame/korean-pii-e5-base")

# Tokenizer plus a bfloat16 token-classification model, loaded once when this
# module is imported. The model is put in eval mode and moved to the GPU when
# CUDA is available (both Module.eval() and Module.cuda() return the module).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16
).eval()
if torch.cuda.is_available():
    model = model.cuda()
_TRAILING_JOSA = ["μ΄μμ","μ΄λΌκ³ ","μ
λλ€","μ΄μΌ","μ΄λ","νν
","μκ²","μΌλ‘","μ΄κ°","μ΄λ",
"μμ","μ΄κ³ ","μμ","μ¨","λ","μ΄","κ°","μ","λ","μ","λ₯Ό","μΌ","μ","μ","μ","λ","κ»","κ³ "]
_DATE_END = re.compile(r".*(?:μΌ|[0-9])", re.S)
def _normalize(text, label, s, e):
while s < e and text[s] in " .,\t\n": s += 1
while e > s and text[e - 1] in " .,\t\n": e -= 1
if label == "private_date":
m = _DATE_END.match(text[s:e])
if m and m.end() > 0:
e = s + m.end()
elif label in ("private_person", "personal_handle", "private_address"):
for _ in range(2):
seg = text[s:e]
for j in _TRAILING_JOSA:
if seg.endswith(j) and (e - s) - len(j) >= 2:
e -= len(j); break
else:
break
return s, e
def extract_pii(text, max_length=256, stride=32):
    """Detect PII in *text* and return the spans in document order.

    Each span is a dict ``{"label", "start", "end", "text"}`` where
    ``text[start:end]`` is the matched substring after _normalize.

    *max_length* is the token budget of one model window. Longer inputs are
    scanned as several windows that overlap by *stride* tokens. Previously the
    input was truncated at *max_length* tokens, so any PII after that point
    was silently missed (and left in place by ``redact``).
    """
    out = []
    for cat, s, e in _group_spans(_classify_tokens(text, max_length, stride)):
        s, e = _normalize(text, cat, s, e)
        if text[s:e].strip():
            out.append({"label": cat, "start": s, "end": e, "text": text[s:e]})
    return out


def _classify_tokens(text, max_length, stride):
    """Tokenize *text* into overlapping windows and label every token.

    Returns a list of ``(char_start, char_end, label)`` in document order.
    Special and padding tokens are dropped. Zero-width non-special tokens are
    kept so that _group_spans still treats them as span breaks. A token lying
    in the overlap between two windows is taken only from the earlier window.
    """
    # The fast tokenizer rejects a stride that is not smaller than the number
    # of content tokens per window; keep it comfortably below that.
    stride = max(0, min(stride, max_length // 4))
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        padding=True,
        return_offsets_mapping=True,
        return_special_tokens_mask=True,
        return_tensors="pt",
    )
    # Bookkeeping outputs that are not model inputs.
    offsets = enc.pop("offset_mapping").tolist()
    special = enc.pop("special_tokens_mask").tolist()
    enc.pop("overflow_to_sample_mapping", None)
    with torch.no_grad():
        logits = model(**{k: v.to(model.device) for k, v in enc.items()}).logits
    preds = logits.argmax(-1).tolist()
    id2label = model.config.id2label

    tokens = []
    covered = 0  # char end of the furthest token kept from earlier windows
    for win_offsets, win_special, win_preds in zip(offsets, special, preds):
        cutoff = covered  # tokens starting before this were already emitted
        for (cs, ce), is_special, lid in zip(win_offsets, win_special, win_preds):
            if is_special or cs < cutoff:
                continue
            tokens.append((cs, ce, id2label[int(lid)]))
            covered = max(covered, ce)
    return tokens


def _group_spans(tokens):
    """Merge ``(start, end, label)`` tokens into ``[category, start, end]`` spans.

    Labels are "O" or "<prefix>-<category>". A span is started by a "B" or "S"
    prefix, by a category change, or by a continuation token when no span is
    open. Any other prefix extends the open span. An "O" label or a zero-width
    token closes it.
    """
    spans, active = [], None
    for cs, ce, label in tokens:
        if cs == ce or label == "O":
            if active:
                spans.append(active)
                active = None
            continue
        prefix, cat = label.split("-", 1)
        if prefix in ("B", "S") or not active or active[0] != cat:
            if active:
                spans.append(active)
            active = [cat, cs, ce]
        else:
            active[2] = ce
    if active:
        spans.append(active)
    return spans
def redact(text):
    """Return *text* with each detected PII span replaced by ``[LABEL]``.

    The label is upper-cased (e.g. ``private_person`` -> ``[PRIVATE_PERSON]``).
    Spans are substituted from the last one to the first, so the character
    offsets of spans not yet replaced stay valid.
    """
    found = extract_pii(text)
    found.sort(key=lambda span: span["start"], reverse=True)
    result = text
    for span in found:
        head, tail = result[: span["start"]], result[span["end"]:]
        result = head + f"[{span['label'].upper()}]" + tail
    return result
if __name__ == "__main__":
for t in ["κΉλ―Όμλμ λ²νΈλ 010-1234-5678μ
λλ€.",
"κ³μ’ 110-234-567890μΌλ‘ μ
κΈνκ³ minsu@example.comμΌλ‘ μλ €μ£ΌμΈμ.",
"μ΄μμ§ κ³ κ°λ μλ
μμΌμ 1985λ
3μ 12μΌμ
λλ€."]:
print(t)
for sp in extract_pii(t):
print(f" {sp['label']:16} {sp['text']!r}")
print(" REDACT:", redact(t)); print()
|