# Ministral-3B-PII-Preview / postprocess.py
# Author: MaziyarPanahi
# Commit message: Add Ministral-3B-PII: GRPO-trained structured PII extraction (text-to-text)
# Commit: 5515f8a (verified)
"""
Production-style pre/post-processing for multilingual PII extraction.
This module mirrors what a real clinical PII pipeline would apply on top of a raw
model output. We keep each step small and explicit so failures are easy to audit.
Pipeline
--------
1. NFC-normalize and strip text on both inputs and entity values.
2. Filter language-specific stopwords that the model occasionally mistakes for names
(e.g. Swahili "Jina" = "name").
3. Deduplicate same-label spans where one contains another. We keep the MOST
specific (shortest) member, matching how downstream redaction systems would
prefer precise spans over loose ones.
4. For Chinese / Japanese / Korean, split a joined native name into surname + given
name when the model emitted it as one token. For Vietnamese, swap the first_name /
last_name labels when the model applied Western name order.
5. Expose a fuzzy text matcher so evaluation tolerates Slavic case inflection
(e.g. "Москве" == "Москва"), letter case, and canonically equivalent Unicode
forms (precomposed vs. combining-mark spellings, unified by NFC).
Nothing here depends on heavy NLP libraries — all heuristics are regex/string-level,
which is how most real PII pipelines bootstrap coverage for languages without a
mature NER model.
"""
from __future__ import annotations
import re
import unicodedata
# ---------------------------------------------------------------------------
# 1. Unicode normalization
# ---------------------------------------------------------------------------
def nfc(text: str) -> str:
    """Return *text* NFC-normalized, with whitespace runs collapsed and ends trimmed.

    Any non-string input (None, numbers, ...) yields an empty string, so callers
    can pass raw model fields straight through.
    """
    if isinstance(text, str):
        composed = unicodedata.normalize("NFC", text)
        # str.split() with no separator splits on the same Unicode whitespace set
        # as the regex \s and drops leading/trailing whitespace in one step.
        return " ".join(composed.split())
    return ""
# ---------------------------------------------------------------------------
# 2. Language stopwords — common words models hallucinate as names
# ---------------------------------------------------------------------------
# Per-language words/phrases the model occasionally mistakes for PII values
# (e.g. Swahili "jina" = "name"). Keys are language codes; values are lowercase
# phrases matched against the NFC-normalized, lowercased entity text.
LANGUAGE_STOPWORDS: dict[str, set[str]] = dict(
    sw={"jina", "jina langu", "simu", "simu yangu", "barua", "barua pepe", "ninaishi"},
    vi={"tôi", "email", "số điện thoại"},
    tr={"adım", "e-postam", "telefonum"},
    id={"nama", "saya", "email"},
    pt={"meu nome"},
    es={"me llamo", "mi correo"},
)
def is_stopword(text: str, language: str | None) -> bool:
    """Return True when *text* is a listed stopword for *language*.

    The text is NFC-normalized and lowercased before lookup. A missing/empty
    language, or one without a stopword list, never matches.
    """
    vocabulary = LANGUAGE_STOPWORDS.get(language) if language else None
    if vocabulary is None:
        return False
    return nfc(text).lower() in vocabulary
def filter_stopwords(entities: list[dict], language: str | None) -> list[dict]:
    """Return a new list without the entities whose text is a stopword for *language*."""
    kept: list[dict] = []
    for entity in entities:
        if is_stopword(entity.get("text", ""), language):
            continue
        kept.append(entity)
    return kept
# ---------------------------------------------------------------------------
# 3. Same-label overlap deduplication
# ---------------------------------------------------------------------------
def dedupe_overlapping(entities: list[dict]) -> list[dict]:
"""Drop longer same-label spans that fully contain a shorter same-label span.
A clinical downstream prefers specific entities (first_name=An) to loose ones
(first_name='Nguyễn Văn An'). When the model emits both, we keep the shorter.
Different-label overlaps are left untouched.
"""
by_label: dict[str, list[dict]] = {}
for e in entities:
by_label.setdefault(e.get("label", ""), []).append(e)
kept: list[dict] = []
for label, group in by_label.items():
# Sort by length ascending; a span survives only if no shorter same-label
# span is a substring of it.
group_sorted = sorted(group, key=lambda x: len(nfc(x.get("text", ""))))
shorter_texts: list[str] = []
for e in group_sorted:
t = nfc(e.get("text", "")).lower()
if not t:
continue
if any(s and s in t and s != t for s in shorter_texts):
continue # a shorter same-label already covers this
kept.append(e)
shorter_texts.append(t)
return kept
# ---------------------------------------------------------------------------
# 4. CJK name splitting
# ---------------------------------------------------------------------------
# A small gazetteer of two-character (compound) Chinese surnames. Entries are
# listed in both Simplified and Traditional script so "zh" text written in
# either script is split correctly. Extend as needed.
CHINESE_TWO_CHAR_SURNAMES = {
    # Simplified script
    "欧阳", "司马", "诸葛", "上官", "夏侯", "东方", "皇甫", "尉迟", "公孙",
    "慕容", "长孙", "宇文", "司徒", "鲜于", "司空", "轩辕", "令狐", "钟离",
    # Traditional forms of the entries above whose spelling differs
    "歐陽", "司馬", "諸葛", "東方", "尉遲", "公孫", "長孫", "鮮于", "軒轅", "鍾離",
}
# Common Japanese surnames, consulted as a prefix gazetteer when splitting a
# joined name. Mostly two kanji, plus 佐々木 (three characters) and 林 (one).
# Tiny set sufficient for the demo; a real system would use a larger dictionary.
JAPANESE_COMMON_SURNAMES = set(
    "佐藤 鈴木 高橋 田中 伊藤 渡辺 山本 中村 小林 "
    "加藤 吉田 山田 佐々木 山口 斎藤 松本 井上 木村 "
    "林 清水".split()
)
_CJK_RE = re.compile(r"^[\u3400-\u9fff\u3040-\u30ff\uac00-\ud7af]+$")
def _is_cjk(text: str) -> bool:
return bool(text) and bool(_CJK_RE.match(text))
def _split_korean_name(text: str) -> tuple[str, str] | None:
# Korean: 1-char surname + 2-char given name is the overwhelming pattern.
# Only split at 3+ chars; 2-char strings are likely a surname or given alone.
if len(text) == 3:
return text[0], text[1:]
if len(text) == 4:
return text[:2], text[2:]
return None
def _split_chinese_name(text: str) -> tuple[str, str] | None:
    """Split a 3-4 character Chinese full name into (surname, given_name).

    Shorter strings are left alone: a 2-character string is almost always a
    given name on its own (e.g. "小明"). Longer strings are also left alone. A
    leading compound surname from CHINESE_TWO_CHAR_SURNAMES takes two
    characters; otherwise the surname is the first character.
    """
    if not 3 <= len(text) <= 4:
        return None
    surname_len = 2 if text[:2] in CHINESE_TWO_CHAR_SURNAMES else 1
    return text[:surname_len], text[surname_len:]
def _split_japanese_name(text: str) -> tuple[str, str] | None:
    """Split a Japanese full name into (surname, given_name), or return None.

    Strings under 3 characters are context-ambiguous (a given name such as
    "太郎" or a surname such as "田中") and are left alone. Otherwise the longest
    gazetteer surname prefix (3, then 2 characters) wins, provided a given name
    remains. Without a hit, 4-character names split 2 + 2 (typical kanji full
    name) and all other lengths split after the first character.

    NOTE(review): 1-character gazetteer entries (林) are never looked up here;
    3- and 5+-character names fall back to a 1-character surname anyway, but a
    4-character name such as 林健太郎 splits 2 + 2.
    """
    if len(text) < 3:
        return None
    surname_len = next(
        (n for n in (3, 2) if text[:n] in JAPANESE_COMMON_SURNAMES and len(text) > n),
        None,
    )
    if surname_len is None:
        surname_len = 2 if len(text) == 4 else 1
    return text[:surname_len], text[surname_len:]
def split_cjk_name(text: str, language: str) -> tuple[str, str] | None:
    """Split a joined CJK full name into (surname, given_name) for *language*.

    The text is NFC-normalized first. None is returned when the result is not
    purely CJK/kana/Hangul, when *language* is not "ko", "ja" or "zh", or when
    the per-language splitter declines to split.
    """
    normalized = nfc(text)
    if not _is_cjk(normalized):
        return None
    dispatch = (
        ("ko", _split_korean_name),
        ("ja", _split_japanese_name),
        ("zh", _split_chinese_name),
    )
    for code, splitter in dispatch:
        if language == code:
            return splitter(normalized)
    return None
# Common Vietnamese family names (title-cased), used to detect names whose
# first/last labels were assigned in Western order.
VIETNAMESE_COMMON_SURNAMES = set(
    "Nguyễn Trần Lê Phạm Hoàng Huỳnh Phan Vũ Võ "
    "Đặng Bùi Đỗ Hồ Ngô Dương Lý Trịnh Đoàn Mai".split()
)
def _looks_like_vietnamese_surname(text: str) -> bool:
    """Return True if *text* is a listed Vietnamese surname, ignoring case.

    The text is NFC-normalized and whitespace-trimmed first. Matching is
    case-insensitive (like the module's other comparisons), so upper-case
    record spellings such as "NGUYỄN" are recognized too.
    """
    candidate = nfc(text).lower()
    if not candidate:
        return False
    return any(candidate == surname.lower() for surname in VIETNAMESE_COMMON_SURNAMES)
def swap_vietnamese_name_order(entities: list[dict], language: str | None) -> list[dict]:
    """Swap first_name/last_name labels when Western order was applied to Vietnamese.

    Vietnamese writes names as <family> <middle> <given>. Models trained on
    Western ordering call the first token `first_name` and the last token
    `last_name`, which is the opposite of the Vietnamese convention.

    We only swap when we can confirm the mistake: some `first_name` value is a
    known Vietnamese surname AND no `last_name` value is one. The second
    condition matters because some listed surnames can also be given names
    (e.g. "Mai"), so a correctly labeled first_name="Mai", last_name="Nguyễn"
    must not be flipped. When both labels hold known surnames the order is
    ambiguous and nothing is swapped.

    The swap applies to every first_name/last_name in the list. The input list
    is returned as-is when no swap happens; otherwise a new list is returned in
    which only relabeled entities are copied.
    """
    if language != "vi":
        return entities

    def _has_surname_under(label: str) -> bool:
        return any(
            e.get("label") == label
            and _looks_like_vietnamese_surname(str(e.get("text", "")))
            for e in entities
        )

    if not _has_surname_under("first_name") or _has_surname_under("last_name"):
        return entities

    swapped: list[dict] = []
    for e in entities:
        lbl = e.get("label")
        if lbl == "first_name":
            swapped.append({**e, "label": "last_name"})
        elif lbl == "last_name":
            swapped.append({**e, "label": "first_name"})
        else:
            swapped.append(e)
    return swapped
def expand_cjk_names(entities: list[dict], language: str | None) -> list[dict]:
    """Add split (surname, given) entities for joined CJK names so matching is generous.

    Only runs for "zh", "ja" and "ko". For every entity whose lowercased label
    is a name label and whose text split_cjk_name can split, a last_name
    (surname) and a first_name (given name) entity are appended. A new entity is
    added only if no entity with the same (lowercased text, label) exists yet.
    The original entities are kept, unchanged and first.
    """
    if language not in {"zh", "ja", "ko"}:
        return entities
    name_labels = {"first_name", "last_name", "name", "full_name", "person_name"}
    result = list(entities)
    seen = {(nfc(e.get("text", "")).lower(), e.get("label", "")) for e in entities}
    candidates = (
        e for e in entities if str(e.get("label", "")).lower() in name_labels
    )
    for entity in candidates:
        parts = split_cjk_name(nfc(entity.get("text", "")), language)
        if not parts:
            continue
        surname, given = parts
        for piece, piece_label in ((surname, "last_name"), (given, "first_name")):
            key = (piece.lower(), piece_label)
            if key in seen:
                continue
            seen.add(key)
            result.append({"text": piece, "label": piece_label})
    return result
# ---------------------------------------------------------------------------
# 5. Fuzzy text matching (Slavic case tolerance, substring, NFC)
# ---------------------------------------------------------------------------
# Language codes for which fuzzy_text_match tolerates case-ending differences.
SLAVIC_LANGS = set("ru uk pl cs bg sk sr hr".split())
def _common_prefix_len(a: str, b: str) -> int:
n = 0
for x, y in zip(a, b):
if x == y:
n += 1
else:
break
return n
def fuzzy_text_match(a: str, b: str, language: str | None = None) -> bool:
    """Decide whether two entity text values should be treated as the same.

    Both inputs are NFC-normalized, whitespace-collapsed and lowercased; an
    empty value never matches. The pair matches when any of these hold:

    - the normalized strings are equal;
    - one is a plain character-level substring of the other (no word-boundary
      check), the shorter has at least 3 characters and is at least half as long
      as the longer ("seattle" vs "seattle, wa" matches; "an" vs "anna" does not);
    - *language* is in SLAVIC_LANGS and the strings share a common prefix of at
      least max(3, shorter_length - 2) characters, absorbing case endings
      ("москва" vs "москве").
    """
    left = nfc(a).lower()
    right = nfc(b).lower()
    if not (left and right):
        return False
    if left == right:
        return True

    shorter, longer = sorted([left, right], key=len)
    # "Seattle, WA" contains "Seattle"; inflected forms like "Москве"/"Москва"
    # are not substrings and are handled by the Slavic prefix rule below.
    if shorter in longer and len(shorter) >= 3 and (len(shorter) / len(longer)) >= 0.5:
        return True

    if language not in SLAVIC_LANGS:
        return False
    # Москва / Москве / Москвы share the root "москв".
    return _common_prefix_len(left, right) >= max(3, len(shorter) - 2)
# ---------------------------------------------------------------------------
# 6. Top-level postprocess
# ---------------------------------------------------------------------------
def _primary_language(language: str | None) -> str | None:
    """Reduce a locale tag ("zh-TW", "pt_BR", "VI") to its lowercase primary subtag.

    Non-string and blank values are returned unchanged, so every step's
    language checks behave exactly as they do for a missing language.
    """
    if not isinstance(language, str) or not language.strip():
        return language
    return re.split(r"[-_]", language.strip().lower(), maxsplit=1)[0]


def postprocess_entities(
    entities: list[dict],
    language: str | None = None,
    expand_cjk: bool = True,
    dedupe: bool = True,
    filter_stops: bool = True,
) -> list[dict]:
    """Apply the full post-processing pipeline to a list of entity dicts.

    Steps, in order:

    1. Normalize: skip non-dict entries and entries with empty text,
       NFC-normalize the text and strip/lowercase the label (a missing or None
       label becomes ""). Only ``text`` and ``label`` are carried over.
    2. Filter language stopwords (``filter_stops``). This deliberately runs
       before dedupe. Dedupe drops a longer same-label span that contains a
       shorter one, so a stopword span left in place (e.g. Swahili "simu")
       could knock out the real entity containing it and then be filtered
       itself, losing both.
    3. Expand joined CJK names into surname/given-name entities (``expand_cjk``).
    4. Swap Vietnamese first/last-name labels when Western ordering is detected.
    5. Dedupe same-label overlaps (``dedupe``).

    ``language`` may be a locale tag; it is reduced to its primary subtag
    ("zh-TW" -> "zh", "VI" -> "vi") before the steps above see it.
    """
    if not entities:
        return []
    lang = _primary_language(language)
    # Normalize text fields
    normed: list[dict] = []
    for e in entities:
        if not isinstance(e, dict):
            continue
        t = nfc(e.get("text", ""))
        if not t:
            continue
        raw_label = e.get("label")
        label = "" if raw_label is None else str(raw_label).strip().lower()
        normed.append({"text": t, "label": label})
    if filter_stops:
        normed = filter_stopwords(normed, lang)
    if expand_cjk:
        normed = expand_cjk_names(normed, lang)
    normed = swap_vietnamese_name_order(normed, lang)
    if dedupe:
        normed = dedupe_overlapping(normed)
    return normed
def preprocess_text(text: str) -> str:
    """Pre-processing applied before the model sees the input.

    Mirrors what a clinical pipeline would do to incoming free text:

    - NFC normalize.
    - Remove zero-width characters, bidirectional formatting controls
      (marks, embeddings/overrides U+202A-202E and isolates U+2066-2069), and
      non-whitespace C0 control characters (NUL, BEL, ESC, ...) plus DEL.
    - Collapse runs of spaces/tabs into one space while keeping newlines, then
      trim both ends.

    Non-string input yields "".
    """
    if not isinstance(text, str):
        return ""
    text = unicodedata.normalize("NFC", text)
    # Strip zero-width and bidi control characters that confuse tokenizers
    # (and can visually reorder text, e.g. via the isolates U+2066-2069).
    text = re.sub(r"[\u200b-\u200f\u202a-\u202e\u2060\u2066-\u2069\ufeff]", "", text)
    # Strip C0 controls that are not whitespace, plus DEL. Tab/newline/CR and
    # the whitespace-class controls are kept so word and line boundaries survive.
    text = re.sub(r"[\x00-\x08\x0e-\x1b\x7f]", "", text)
    # Collapse runs of spaces/tabs but keep newlines.
    text = re.sub(r"[ \t]+", " ", text)
    return text.strip()