# Pii-marking / test_model.py
# Prakritttt's picture
# Change workflow [delete hf model and use only wangchang]
# f0611dd
import json
import re
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
# Optional dependency: langdetect. detect_text_profile() only calls detect()
# when HAS_LANGDETECT is True, so the module still works without the package.
try:
    from langdetect import DetectorFactory, detect

    # Pin the detector seed (per the langdetect docs this makes results
    # repeatable across calls).
    DetectorFactory.seed = 0
    HAS_LANGDETECT = True
except Exception:
    # Any failure while importing simply disables the feature.
    HAS_LANGDETECT = False

# Optional dependency: spaCy. init_models() only tries to load the English
# pipeline when HAS_SPACY is True.
try:
    import spacy

    HAS_SPACY = True
except Exception:
    HAS_SPACY = False
# Hugging Face checkpoint used for Thai NER. The tokenizer, model and pipeline
# below are created eagerly, i.e. as a side effect of importing this module.
MODEL_NAME = "pythainlp/thainer-corpus-v2-base-model"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# Token-classification pipeline consumed by collect_model_entities_th(), which
# reads "start"/"end" from each result and passes it to _normalize_label().
hf_ner = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# English spaCy pipeline; stays None until init_models() loads it successfully.
spacy_en_nlp = None
# Closed set of tag names that may appear as "<TAG>" placeholders.
# mark_pii_context() uses a label unchanged when it is in this set and maps
# every other label through _to_tag().
CANONICAL_TAGS = {
    "ID_CARD",
    "TIME",
    "EMAIL",
    "LEN",
    "LOCATION",
    "ORGANIZATION",
    "PERSON",
    "PHONE",
    "TEMPERATURE",
    "URL",
    "ZIP",
    "MONEY",
    "LAW",
    "PERCENT",
}
# Per-label weight; first component of _entity_score(). When two overlapping
# spans carry different labels, the higher score wins in
# merge_entities_with_priority(). Labels missing from this table score 0.
ENTITY_PRIORITY = {
    "ID_CARD": 100,
    "EMAIL": 95,
    "URL": 93,
    "PHONE": 92,
    "PERSON": 90,
    "ORGANIZATION": 80,
    "LOCATION": 78,
    "MONEY": 75,
    "ZIP": 72,
    "PERCENT": 70,
    "TIME": 65,
    "TEMPERATURE": 60,
    "LEN": 55,
    "LAW": 50,
}
# Per-source weight; second component of _entity_score(), used as a tie-breaker
# after the label weight. Keys are the "source" strings written by the three
# collectors. Any other source (for example the "model" default used in
# mark_pii_context) scores 0.
SOURCE_PRIORITY = {
    "regex": 4,
    "wangchanberta_hf": 2,
    "spacy_en": 1,
}
# Common Thai function words (roughly: that/at, from, in, on, and, or, is,
# can/got, with, of). _is_noise_span() rejects a span that consists of exactly
# one of these when it came from a model source.
THAI_FUNCTION_WORDS = {
    "ที่",
    "จาก",
    "ใน",
    "บน",
    "และ",
    "หรือ",
    "คือ",
    "ได้",
    "กับ",
    "ของ",
}
def _normalize_label(entity: dict) -> str:
"""Return a stable label from model output."""
return (entity.get("entity_group") or entity.get("entity") or "PII").upper()
def init_models() -> None:
    """Load the optional English spaCy pipeline on first use.

    Does nothing when the pipeline is already loaded or spaCy is unavailable.
    A failed load leaves ``spacy_en_nlp`` as None, so the next call retries.
    """
    global spacy_en_nlp
    if spacy_en_nlp is not None or not HAS_SPACY:
        return
    try:
        spacy_en_nlp = spacy.load("en_core_web_sm")
    except Exception:
        # Model package missing or broken: keep running without English NER.
        spacy_en_nlp = None
def detect_text_profile(text: str) -> dict:
    """Classify *text* as Thai, English or mixed.

    The base decision uses the share of Thai versus ASCII-letter characters
    (other characters are ignored). When langdetect is available its verdict
    may override that decision, provided the ratios do not contradict it.

    Returns:
        A dict with "th_ratio", "en_ratio" and "mode" ("th", "en" or "mixed").
    """
    thai_count = sum(1 for _ in re.finditer(r"[ก-๙]", text))
    latin_count = sum(1 for _ in re.finditer(r"[A-Za-z]", text))
    # max(1, ...) avoids dividing by zero for text with no letters at all.
    denominator = max(1, thai_count + latin_count)
    th_ratio = thai_count / denominator
    en_ratio = latin_count / denominator

    if th_ratio >= 0.65:
        mode = "th"
    elif en_ratio >= 0.65:
        mode = "en"
    else:
        mode = "mixed"

    if HAS_LANGDETECT:
        try:
            language = detect(text)
        except Exception:
            # Detection is best-effort; keep the ratio-based mode.
            pass
        else:
            if language == "th":
                if th_ratio >= 0.2:
                    mode = "th"
            elif language == "en":
                if en_ratio >= 0.2:
                    mode = "en"
            elif th_ratio > 0.2 and en_ratio > 0.2:
                # Detector saw some third language but both scripts are present.
                mode = "mixed"

    return {
        "th_ratio": th_ratio,
        "en_ratio": en_ratio,
        "mode": mode,
    }
def _to_tag(label: str) -> str:
"""Map model labels to requested canonical tags."""
if "PER" in label or "PERSON" in label or "NAME" in label:
return "PERSON"
if "ORG" in label or "COMPANY" in label:
return "ORGANIZATION"
if "LOC" in label or "ADDRESS" in label or "GPE" in label:
return "LOCATION"
if "EMAIL" in label or "MAIL" in label:
return "EMAIL"
if "URL" in label or "WEB" in label:
return "URL"
if "ZIP" in label or "POSTAL" in label:
return "ZIP"
if "MONEY" in label or "PRICE" in label or "CURRENCY" in label:
return "MONEY"
if "LAW" in label or "ACT" in label:
return "LAW"
if "PERCENT" in label:
return "PERCENT"
if "TEMP" in label:
return "TEMPERATURE"
if "LEN" in label or "LENGTH" in label or "AGE" in label:
return "LEN"
if "TIME" in label:
return "TIME"
if "DATE" in label or "DATA" in label:
return "DATA"
if "ID_CARD" in label or ("ID" in label and "CARD" in label):
return "ID_CARD"
if "PHONE" in label or "TEL" in label or "MOBILE" in label:
return "PHONE"
return "PERSON"
def normalize_entity(raw: dict, source: str, text: str) -> dict | None:
    """Convert a collector result into the common entity schema.

    Args:
        raw: Dict with "start", "end" and optionally "label".
        source: Name of the collector that produced *raw*.
        text: The full input text the offsets refer to.

    Returns:
        ``{"start", "end", "label", "source"}`` with a whitespace-trimmed span
        and a canonical label, or None when the entity is empty, date-like
        ("DATA") or judged to be noise.
    """
    start, end = _trim_span(text, int(raw["start"]), int(raw["end"]))
    if start >= end:
        return None

    raw_label = raw.get("label", "")
    # Labels that are already canonical (e.g. from the regex collector) are
    # kept as-is, matching mark_pii_context(). Sending them through the
    # substring matcher can change them, because a canonical name may contain
    # another rule's keyword ("PERCENT" contains "PER").
    label = raw_label if raw_label in CANONICAL_TAGS else _to_tag(raw_label)
    if label == "DATA":
        # Business rule: drop DATE/DATA-like entities.
        return None

    if _is_noise_span(text[start:end], label, source):
        return None

    return {
        "start": start,
        "end": end,
        "label": label,
        "source": source,
    }
def _is_phone_number(text: str, label: str) -> bool:
"""Skip masking for phone numbers as requested."""
digits = re.sub(r"\D", "", text)
has_phone_like_length = 9 <= len(digits) <= 12
starts_with_zero_or_plus = text.strip().startswith("0") or text.strip().startswith("+")
phone_label = any(key in label for key in ("PHONE", "TEL", "MOBILE"))
return phone_label or (has_phone_like_length and starts_with_zero_or_plus)
def _trim_span(text: str, start: int, end: int) -> tuple:
"""Trim whitespace around entity span to reduce broken tags."""
while start < end and text[start].isspace():
start += 1
while end > start and text[end - 1].isspace():
end -= 1
return start, end
def _is_noise_span(value: str, tag: str, source: str) -> bool:
"""Drop low-information spans that often appear from model noise."""
stripped = value.strip()
if not stripped:
return True
if len(stripped) == 1 and tag not in {"ZIP", "PERCENT", "MONEY", "LEN"}:
return True
if source in {"model", "wangchanberta_hf"} and stripped in THAI_FUNCTION_WORDS:
return True
if re.fullmatch(r"[\W_]+", stripped):
return True
return False
def _collect_regex_entities(text: str) -> list:
"""Force-detect patterns that are often missed by NER."""
patterns = [
("ID_CARD", r"(?<!\d)(?:\d{13}|\d-\d{4}-\d{5}-\d{2}-\d)(?!\d)"),
("EMAIL", r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"),
("URL", r"https?://[^\s]+|www\.[^\s]+"),
("PERCENT", r"\b\d+(?:\.\d+)?%"),
("MONEY", r"(?:฿|THB\s?)\d+(?:,\d{3})*(?:\.\d+)?|\b\d+(?:,\d{3})*(?:\.\d+)?\s?บาท"),
("ZIP", r"\b\d{5}\b"),
("TIME", r"\b\d{1,2}:\d{2}\b"),
("TEMPERATURE", r"\b\d+(?:\.\d+)?\s?(?:°C|°F|องศา)\b"),
("PHONE", r"(?:\+66|0)\d(?:[\s-]?\d){7,9}"),
("PERSON", r"(?:นาย|นางสาว|นาง|ดร\.?|คุณ)\s?[ก-๙]{2,}(?:\s[ก-๙]{2,}){0,2}"),
("PERSON", r"\b(?:Mr|Mrs|Ms|Miss|Dr)\.?\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,2}\b"),
]
found = []
for label, pattern in patterns:
for match in re.finditer(pattern, text, flags=re.IGNORECASE):
found.append(
{
"start": match.start(),
"end": match.end(),
"label": label,
"source": "regex",
}
)
return found
def collect_model_entities_th(text: str) -> list:
    """Run the Hugging Face NER pipeline on *text* and return raw entities.

    Each result keeps the pipeline's "start"/"end" offsets, gets its label from
    _normalize_label() and is tagged with source "wangchanberta_hf".
    """
    return [
        {
            "start": prediction["start"],
            "end": prediction["end"],
            "label": _normalize_label(prediction),
            "source": "wangchanberta_hf",
        }
        for prediction in hf_ner(text)
    ]
def collect_model_entities_en(text: str) -> list:
    """Return raw entities from the English spaCy pipeline.

    Yields an empty list when the pipeline has not been loaded. Otherwise each
    entity carries spaCy's character offsets and label, with source "spacy_en".
    """
    nlp = spacy_en_nlp
    if nlp is None:
        return []
    return [
        {
            "start": span.start_char,
            "end": span.end_char,
            "label": span.label_,
            "source": "spacy_en",
        }
        for span in nlp(text).ents
    ]
def _entity_score(entity: dict) -> tuple:
    """Return the comparison key ``(label weight, source weight, span length)``.

    Unknown labels and sources weigh 0. Larger tuples win a conflict.
    """
    label_rank = ENTITY_PRIORITY.get(entity["label"], 0)
    source_rank = SOURCE_PRIORITY.get(entity["source"], 0)
    span_length = entity["end"] - entity["start"]
    return label_rank, source_rank, span_length
def _is_overlap(a: dict, b: dict) -> bool:
return not (a["end"] <= b["start"] or b["end"] <= a["start"])
def merge_entities_with_priority(entities: list) -> list:
    """Reduce *entities* to a set of non-conflicting spans.

    Candidates are visited by start offset, best score first among equal
    starts. A candidate that overlaps an already kept span replaces it when:

    * both have the same label and the candidate is strictly longer, or
    * the labels differ and the candidate has the higher _entity_score().

    Otherwise the kept span stays. Only the first overlapping kept span is
    considered for each candidate.

    Returns:
        The surviving entities sorted by start offset.
    """
    if not entities:
        return []

    def visit_order(candidate: dict) -> tuple:
        label_rank, source_rank, length = _entity_score(candidate)
        return (candidate["start"], -label_rank, -source_rank, -length)

    kept_spans = []
    for candidate in sorted(entities, key=visit_order):
        clash_at = next(
            (
                position
                for position, kept in enumerate(kept_spans)
                if _is_overlap(candidate, kept)
            ),
            None,
        )
        if clash_at is None:
            kept_spans.append(candidate)
            continue

        incumbent = kept_spans[clash_at]
        if candidate["label"] == incumbent["label"]:
            # Same label: the longer span wins.
            candidate_wins = (candidate["end"] - candidate["start"]) > (
                incumbent["end"] - incumbent["start"]
            )
        else:
            # Different labels: compare full priority scores.
            candidate_wins = _entity_score(candidate) > _entity_score(incumbent)

        if candidate_wins:
            kept_spans[clash_at] = candidate

    kept_spans.sort(key=lambda span: span["start"])
    return kept_spans
def _merge_entities(text: str) -> list:
    """Gather entities from every applicable source and resolve conflicts.

    Regex patterns always run. The Thai model runs for "th" and "mixed" text.
    spaCy runs for "en" and "mixed" text, or whenever at least 12% of the
    letters are English. Results are normalized, with rejected ones dropped,
    and then merged by priority.
    """
    init_models()
    profile = detect_text_profile(text)
    mode = profile["mode"]

    use_thai_model = mode in {"th", "mixed"}
    # Run spaCy when text is English/mixed or still contains enough English signal.
    use_english_model = mode in {"en", "mixed"} or profile["en_ratio"] >= 0.12

    candidates = list(_collect_regex_entities(text))
    if use_thai_model:
        candidates += collect_model_entities_th(text)
    if use_english_model:
        candidates += collect_model_entities_en(text)

    normalized = (
        normalize_entity(candidate, candidate["source"], text)
        for candidate in candidates
    )
    return merge_entities_with_priority(
        [entity for entity in normalized if entity is not None]
    )
def mark_pii_context(text: str) -> dict:
    """Replace detected PII in *text* with ``<TAG>`` placeholders.

    Rules applied to each merged entity, in text order:

    * Phone numbers are reported but intentionally left unmasked.
    * Date-like entities (tag "DATA") are reported but left unmasked.
    * Everything else is replaced by ``<TAG>``. Consecutive entities with the
      same tag, separated only by whitespace (e.g. first and last name), share
      one placeholder, but each is still listed in "entities".

    Entities that start before the current cursor, or that are noise, are
    skipped. Whitespace runs in the result are collapsed to single spaces.

    Args:
        text: The text to process.

    Returns:
        A dict with "original_text", "sentence" and "tagged_text" (the same
        masked string), "entities" (one record per handled entity with its
        "action") and "context" (a printable report).
    """
    entities = _merge_entities(text)
    parts = []
    tagged_entities = []
    cursor = 0
    # Tag of the placeholder emitted immediately before the current entity, or
    # None when the previous output was not a placeholder.
    previous_tag = None

    for entity in entities:
        start, end = entity["start"], entity["end"]
        value = text[start:end]
        label = entity["label"]
        source = entity.get("source", "model")

        if start < cursor:
            continue
        if _is_noise_span(value, label, source):
            continue

        gap = text[cursor:start]
        parts.append(gap)
        if gap.strip():
            # Real text sits between the two entities: do not merge them.
            previous_tag = None

        if _is_phone_number(value, label):
            parts.append(value)
            action = "not_marked_phone"
            tag_name = "PHONE"
            previous_tag = None
        else:
            tag_name = label if label in CANONICAL_TAGS else _to_tag(label)
            if tag_name == "DATA":
                # Defensive guard: date-like entity should not be masked.
                parts.append(value)
                action = "kept_original"
                previous_tag = None
            else:
                if previous_tag != tag_name:
                    parts.append(f"<{tag_name}>")
                action = "marked"
                previous_tag = tag_name

        tagged_entities.append(
            {
                "text": value,
                "label": label,
                "tag": tag_name,
                "start": start,
                "end": end,
                "action": action,
                "source": source,
            }
        )
        cursor = end

    parts.append(text[cursor:])
    tagged_text = re.sub(r"\s+", " ", "".join(parts)).strip()
    context = (
        "PII Context Output\n"
        f"Original: {text}\n\n"
        f"Tagged: {tagged_text}\n\n"
        "Entities:\n"
        f"{json.dumps(tagged_entities, ensure_ascii=False, indent=2)}"
    )
    return {
        "original_text": text,
        "sentence": tagged_text,
        "tagged_text": tagged_text,
        "entities": tagged_entities,
        "context": context,
    }
if __name__ == "__main__":
    # Manual smoke test: mask one Thai sentence and print the result.
    demo_sentence = (
        "ฉันชื่อ นางสาวมะลิวา บุญสระดี อาศัยอยู่ที่อำเภอนางรอง จังหวัดบุรีรัมย์ "
        "เรียนจบจากมหาวิทยาลัยขอนแก่น ติดต่อได้ที่ 089-123-4567 และอีเมล abc@example.com"
    )
    print(mark_pii_context(demo_sentence)["sentence"])
# In the file test_model.py:
# pass local_files_only=True to support the offline command you are going to use.
#
# NOTE(review): this loads MODEL_NAME a second time and rebinds the module-level
# `tokenizer` and `model`. The pipeline stored in `ner` is not referenced
# anywhere else in the visible source; entity collection uses `hf_ner`, which
# was built from the first load without local_files_only. Confirm whether this
# block is meant to replace that first load.
# NOTE(review): the original indentation was lost, so it is unclear whether
# this block belongs at module level or inside the __main__ guard. Verify.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, local_files_only=True)
ner = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)