Spaces:
Running
Running
| import json | |
| import re | |
| from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline | |
| try: | |
| from langdetect import DetectorFactory, detect | |
| DetectorFactory.seed = 0 | |
| HAS_LANGDETECT = True | |
| except Exception: | |
| HAS_LANGDETECT = False | |
| try: | |
| import spacy | |
| HAS_SPACY = True | |
| except Exception: | |
| HAS_SPACY = False | |
# Hugging Face checkpoint id used for the Thai token-classification model.
MODEL_NAME = "pythainlp/thainer-corpus-v2-base-model"
# NOTE: tokenizer, model and pipeline are created eagerly, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# NER pipeline used by collect_model_entities_th(); callers in this file read
# "start", "end" and "entity_group"/"entity" from each returned item.
hf_ner = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# spaCy English pipeline; stays None until init_models() manages to load it.
spacy_en_nlp = None
# Tag names that mark_pii_context() emits as-is; any other label is first
# passed through _to_tag().
CANONICAL_TAGS = {
    "ID_CARD",
    "TIME",
    "EMAIL",
    "LEN",
    "LOCATION",
    "ORGANIZATION",
    "PERSON",
    "PHONE",
    "TEMPERATURE",
    "URL",
    "ZIP",
    "MONEY",
    "LAW",
    "PERCENT",
}
# Per-label rank; first component of _entity_score(). Higher wins when
# overlapping entities carry different labels. Unknown labels score 0.
ENTITY_PRIORITY = {
    "ID_CARD": 100,
    "EMAIL": 95,
    "URL": 93,
    "PHONE": 92,
    "PERSON": 90,
    "ORGANIZATION": 80,
    "LOCATION": 78,
    "MONEY": 75,
    "ZIP": 72,
    "PERCENT": 70,
    "TIME": 65,
    "TEMPERATURE": 60,
    "LEN": 55,
    "LAW": 50,
}
# Per-source rank; second component of _entity_score(). Unknown sources
# score 0.
SOURCE_PRIORITY = {
    "regex": 4,
    "wangchanberta_hf": 2,
    "spacy_en": 1,
}
# Thai words that _is_noise_span() rejects when a model source ("model" or
# "wangchanberta_hf") returns one of them as an entire entity span.
THAI_FUNCTION_WORDS = {
    "ที่",
    "จาก",
    "ใน",
    "บน",
    "และ",
    "หรือ",
    "คือ",
    "ได้",
    "กับ",
    "ของ",
}
| def _normalize_label(entity: dict) -> str: | |
| """Return a stable label from model output.""" | |
| return (entity.get("entity_group") or entity.get("entity") or "PII").upper() | |
def init_models() -> None:
    """Load the optional spaCy English model on first use.

    Does nothing when spaCy is not installed or the model is already loaded.
    A failed load leaves ``spacy_en_nlp`` as ``None`` so the rest of the
    pipeline keeps working without English NER.
    """
    global spacy_en_nlp
    if not HAS_SPACY or spacy_en_nlp is not None:
        return
    try:
        spacy_en_nlp = spacy.load("en_core_web_sm")
    except Exception:
        # Best effort: the model is optional.
        spacy_en_nlp = None
def detect_text_profile(text: str) -> dict:
    """Classify *text* as Thai, English or mixed.

    The share of Thai-range characters versus ASCII letters gives a first
    guess (>= 0.65 of either side decides, otherwise "mixed"). When
    langdetect is available its answer may override that guess, provided the
    matching character share is at least 0.2.

    Returns a dict with ``th_ratio``, ``en_ratio`` and ``mode``
    ("th", "en" or "mixed").
    """
    thai_count = sum(1 for _ in re.finditer(r"[ก-๙]", text))
    latin_count = sum(1 for _ in re.finditer(r"[A-Za-z]", text))
    # max(1, ...) avoids dividing by zero for text without any letters.
    denominator = max(1, thai_count + latin_count)
    th_ratio = thai_count / denominator
    en_ratio = latin_count / denominator

    if th_ratio >= 0.65:
        mode = "th"
    elif en_ratio >= 0.65:
        mode = "en"
    else:
        mode = "mixed"

    if HAS_LANGDETECT:
        try:
            guess = detect(text)
        except Exception:
            # Detection is advisory only; keep the ratio-based mode.
            pass
        else:
            if guess == "th":
                if th_ratio >= 0.2:
                    mode = "th"
            elif guess == "en":
                if en_ratio >= 0.2:
                    mode = "en"
            elif th_ratio > 0.2 and en_ratio > 0.2:
                mode = "mixed"

    return {"th_ratio": th_ratio, "en_ratio": en_ratio, "mode": mode}
| def _to_tag(label: str) -> str: | |
| """Map model labels to requested canonical tags.""" | |
| if "PER" in label or "PERSON" in label or "NAME" in label: | |
| return "PERSON" | |
| if "ORG" in label or "COMPANY" in label: | |
| return "ORGANIZATION" | |
| if "LOC" in label or "ADDRESS" in label or "GPE" in label: | |
| return "LOCATION" | |
| if "EMAIL" in label or "MAIL" in label: | |
| return "EMAIL" | |
| if "URL" in label or "WEB" in label: | |
| return "URL" | |
| if "ZIP" in label or "POSTAL" in label: | |
| return "ZIP" | |
| if "MONEY" in label or "PRICE" in label or "CURRENCY" in label: | |
| return "MONEY" | |
| if "LAW" in label or "ACT" in label: | |
| return "LAW" | |
| if "PERCENT" in label: | |
| return "PERCENT" | |
| if "TEMP" in label: | |
| return "TEMPERATURE" | |
| if "LEN" in label or "LENGTH" in label or "AGE" in label: | |
| return "LEN" | |
| if "TIME" in label: | |
| return "TIME" | |
| if "DATE" in label or "DATA" in label: | |
| return "DATA" | |
| if "ID_CARD" in label or ("ID" in label and "CARD" in label): | |
| return "ID_CARD" | |
| if "PHONE" in label or "TEL" in label or "MOBILE" in label: | |
| return "PHONE" | |
| return "PERSON" | |
def normalize_entity(raw: dict, source: str, text: str) -> dict | None:
    """Convert a raw entity into the shared ``start/end/label/source`` shape.

    The span is trimmed of surrounding whitespace and the label is mapped
    through ``_to_tag``. ``None`` is returned when the trimmed span is empty,
    when the label maps to ``"DATA"`` (business rule: date/data-like entities
    are dropped), or when the span text is considered noise.
    """
    span_start, span_end = _trim_span(text, int(raw["start"]), int(raw["end"]))
    if span_start >= span_end:
        return None

    tag = _to_tag(raw.get("label", ""))
    if tag == "DATA":
        return None

    if _is_noise_span(text[span_start:span_end], tag, source):
        return None

    return {"start": span_start, "end": span_end, "label": tag, "source": source}
| def _is_phone_number(text: str, label: str) -> bool: | |
| """Skip masking for phone numbers as requested.""" | |
| digits = re.sub(r"\D", "", text) | |
| has_phone_like_length = 9 <= len(digits) <= 12 | |
| starts_with_zero_or_plus = text.strip().startswith("0") or text.strip().startswith("+") | |
| phone_label = any(key in label for key in ("PHONE", "TEL", "MOBILE")) | |
| return phone_label or (has_phone_like_length and starts_with_zero_or_plus) | |
| def _trim_span(text: str, start: int, end: int) -> tuple: | |
| """Trim whitespace around entity span to reduce broken tags.""" | |
| while start < end and text[start].isspace(): | |
| start += 1 | |
| while end > start and text[end - 1].isspace(): | |
| end -= 1 | |
| return start, end | |
| def _is_noise_span(value: str, tag: str, source: str) -> bool: | |
| """Drop low-information spans that often appear from model noise.""" | |
| stripped = value.strip() | |
| if not stripped: | |
| return True | |
| if len(stripped) == 1 and tag not in {"ZIP", "PERCENT", "MONEY", "LEN"}: | |
| return True | |
| if source in {"model", "wangchanberta_hf"} and stripped in THAI_FUNCTION_WORDS: | |
| return True | |
| if re.fullmatch(r"[\W_]+", stripped): | |
| return True | |
| return False | |
| def _collect_regex_entities(text: str) -> list: | |
| """Force-detect patterns that are often missed by NER.""" | |
| patterns = [ | |
| ("ID_CARD", r"(?<!\d)(?:\d{13}|\d-\d{4}-\d{5}-\d{2}-\d)(?!\d)"), | |
| ("EMAIL", r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"), | |
| ("URL", r"https?://[^\s]+|www\.[^\s]+"), | |
| ("PERCENT", r"\b\d+(?:\.\d+)?%"), | |
| ("MONEY", r"(?:฿|THB\s?)\d+(?:,\d{3})*(?:\.\d+)?|\b\d+(?:,\d{3})*(?:\.\d+)?\s?บาท"), | |
| ("ZIP", r"\b\d{5}\b"), | |
| ("TIME", r"\b\d{1,2}:\d{2}\b"), | |
| ("TEMPERATURE", r"\b\d+(?:\.\d+)?\s?(?:°C|°F|องศา)\b"), | |
| ("PHONE", r"(?:\+66|0)\d(?:[\s-]?\d){7,9}"), | |
| ("PERSON", r"(?:นาย|นางสาว|นาง|ดร\.?|คุณ)\s?[ก-๙]{2,}(?:\s[ก-๙]{2,}){0,2}"), | |
| ("PERSON", r"\b(?:Mr|Mrs|Ms|Miss|Dr)\.?\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,2}\b"), | |
| ] | |
| found = [] | |
| for label, pattern in patterns: | |
| for match in re.finditer(pattern, text, flags=re.IGNORECASE): | |
| found.append( | |
| { | |
| "start": match.start(), | |
| "end": match.end(), | |
| "label": label, | |
| "source": "regex", | |
| } | |
| ) | |
| return found | |
def collect_model_entities_th(text: str) -> list:
    """Run the Hugging Face NER pipeline on *text*.

    Every pipeline result becomes a dict holding its ``start``/``end``
    offsets, its normalized label and the source name
    ``"wangchanberta_hf"``.
    """
    return [
        {
            "start": hit["start"],
            "end": hit["end"],
            "label": _normalize_label(hit),
            "source": "wangchanberta_hf",
        }
        for hit in hf_ner(text)
    ]
def collect_model_entities_en(text: str) -> list:
    """Run the spaCy English pipeline on *text*, if it has been loaded.

    Returns an empty list when ``spacy_en_nlp`` is ``None``; otherwise one
    dict per ``doc.ents`` item with character offsets, the spaCy label and
    the source name ``"spacy_en"``.
    """
    nlp = spacy_en_nlp
    if nlp is None:
        return []
    return [
        {
            "start": span.start_char,
            "end": span.end_char,
            "label": span.label_,
            "source": "spacy_en",
        }
        for span in nlp(text).ents
    ]
def _entity_score(entity: dict) -> tuple:
    """Return the ranking key ``(label priority, source priority, span length)``.

    Labels or sources missing from the priority tables count as 0.
    """
    label_rank = ENTITY_PRIORITY.get(entity["label"], 0)
    source_rank = SOURCE_PRIORITY.get(entity["source"], 0)
    span_length = entity["end"] - entity["start"]
    return label_rank, source_rank, span_length
| def _is_overlap(a: dict, b: dict) -> bool: | |
| return not (a["end"] <= b["start"] or b["end"] <= a["start"]) | |
def merge_entities_with_priority(entities: list) -> list:
    """Reduce *entities* to a non-overlapping list ordered by start offset.

    Candidates are visited by start offset (ties: higher score first). A
    candidate that overlaps an already kept entity replaces it only when it
    wins: for equal labels the strictly longer span wins, otherwise the
    strictly greater ``_entity_score`` wins.
    """
    if not entities:
        return []

    def _visit_order(candidate: dict) -> tuple:
        label_rank, source_rank, span_length = _entity_score(candidate)
        return (candidate["start"], -label_rank, -source_rank, -span_length)

    kept_entities = []
    for candidate in sorted(entities, key=_visit_order):
        # Only the first overlapping kept entity is compared.
        clash_at = next(
            (
                position
                for position, existing in enumerate(kept_entities)
                if _is_overlap(candidate, existing)
            ),
            None,
        )
        if clash_at is None:
            kept_entities.append(candidate)
            continue

        incumbent = kept_entities[clash_at]
        if candidate["label"] == incumbent["label"]:
            candidate_len = candidate["end"] - candidate["start"]
            incumbent_len = incumbent["end"] - incumbent["start"]
            candidate_wins = candidate_len > incumbent_len
        else:
            candidate_wins = _entity_score(candidate) > _entity_score(incumbent)

        if candidate_wins:
            kept_entities[clash_at] = candidate

    kept_entities.sort(key=lambda kept: kept["start"])
    return kept_entities
def _merge_entities(text: str) -> list:
    """Gather entities from every enabled source and resolve overlaps.

    Regex matches are always collected. The Thai model runs for "th"/"mixed"
    text; spaCy runs for "en"/"mixed" text or whenever the English character
    ratio is at least 0.12. Each raw entity is normalized (dropped ones are
    skipped) before the priority merge.
    """
    init_models()
    profile = detect_text_profile(text)
    mode = profile["mode"]

    candidates = list(_collect_regex_entities(text))
    if mode in {"th", "mixed"}:
        candidates += collect_model_entities_th(text)
    # spaCy also runs on Thai-dominant text that still has some English in it.
    if mode in {"en", "mixed"} or profile["en_ratio"] >= 0.12:
        candidates += collect_model_entities_en(text)

    normalized = [
        cleaned
        for cleaned in (
            normalize_entity(candidate, candidate["source"], text)
            for candidate in candidates
        )
        if cleaned is not None
    ]
    return merge_entities_with_priority(normalized)
def mark_pii_context(text: str) -> dict:
    """Replace detected PII in *text* with ``<TAG>`` placeholders.

    Entities come from ``_merge_entities``. Phone numbers and ``DATA``
    (date-like) entities are kept verbatim; everything else is replaced by a
    placeholder named after its canonical tag. Consecutive entities with the
    same tag that are separated by nothing but whitespace share a single
    placeholder (e.g. first name + last name -> one ``<PERSON>``).

    Args:
        text: The raw input sentence.

    Returns:
        A dict with:
            ``original_text``: the input, unchanged.
            ``sentence`` / ``tagged_text``: the tagged text with runs of
                whitespace collapsed to one space and the ends stripped.
            ``entities``: one record per processed entity (``text``,
                ``label``, ``tag``, ``start``, ``end``, ``action``,
                ``source``); ``action`` is ``"marked"``,
                ``"not_marked_phone"`` or ``"kept_original"``.
            ``context``: a printable report holding the original text, the
                tagged text and the entities as JSON.
    """
    entities = _merge_entities(text)
    parts = []
    tagged_entities = []
    cursor = 0
    # Tag of the placeholder emitted for the previous entity, or None when
    # the previous entity was kept verbatim or real text came in between.
    last_placeholder_tag = None

    for entity in entities:
        start, end = entity["start"], entity["end"]
        if start < cursor:
            # Overlaps something already emitted.
            continue
        value = text[start:end]
        label = entity["label"]
        source = entity.get("source", "model")
        if _is_noise_span(value, label, source):
            continue

        gap = text[cursor:start]
        if gap.strip():
            # Real text separates this entity from the previous placeholder.
            last_placeholder_tag = None

        if _is_phone_number(value, label):
            tag_name = "PHONE"
            action = "not_marked_phone"
        else:
            tag_name = label if label in CANONICAL_TAGS else _to_tag(label)
            # Defensive guard: date-like entities must not be masked.
            action = "kept_original" if tag_name == "DATA" else "marked"

        if action == "marked":
            if last_placeholder_tag != tag_name:
                parts.append(gap)
                parts.append(f"<{tag_name}>")
            # Otherwise the previous placeholder already covers this entity;
            # the whitespace-only gap is dropped along with it.
            last_placeholder_tag = tag_name
        else:
            parts.append(gap)
            parts.append(value)
            last_placeholder_tag = None

        tagged_entities.append(
            {
                "text": value,
                "label": label,
                "tag": tag_name,
                "start": start,
                "end": end,
                "action": action,
                "source": source,
            }
        )
        cursor = end

    parts.append(text[cursor:])
    tagged_text = re.sub(r"\s+", " ", "".join(parts)).strip()
    context = (
        "PII Context Output\n"
        f"Original: {text}\n\n"
        f"Tagged: {tagged_text}\n\n"
        "Entities:\n"
        f"{json.dumps(tagged_entities, ensure_ascii=False, indent=2)}"
    )
    return {
        "original_text": text,
        "sentence": tagged_text,
        "tagged_text": tagged_text,
        "entities": tagged_entities,
        "context": context,
    }
if __name__ == "__main__":
    # Demo run: tag a Thai sample sentence and print the masked result.
    demo_text = (
        "ฉันชื่อ นางสาวมะลิวา บุญสระดี อาศัยอยู่ที่อำเภอนางรอง จังหวัดบุรีรัมย์ "
        "เรียนจบจากมหาวิทยาลัยขอนแก่น ติดต่อได้ที่ 089-123-4567 และอีเมล abc@example.com"
    )
    print(mark_pii_context(demo_text)["sentence"])
# In the file test_model.py
# Pass local_files_only=True to support the offline command you will use.
# NOTE(review): this reloads the same MODEL_NAME that is loaded near the top
# of the file, and `ner` is not referenced anywhere in the visible code. It
# looks like a snippet pasted from another file; confirm whether it belongs
# here.
# NOTE(review): the original indentation was lost, so it is unclear whether
# these statements sat inside the `if __name__ == "__main__":` block; they are
# shown at module level. Verify before relying on this.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, local_files_only=True)
ner = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)