| |
| import argparse |
| import json |
| import os |
|
|
| os.environ.setdefault("TRANSFORMERS_NO_TF", "1") |
| os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1") |
| os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1") |
| os.environ["USE_TF"] = "0" |
| os.environ["USE_FLAX"] = "0" |
| os.environ["USE_TORCH"] = "1" |
|
|
| import torch |
| from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline |
|
|
| import regex as re |
|
|
|
|
| TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE) |
| EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE) |
| ALLOWED = { |
| "PPSN", |
| "ACCOUNT_NUMBER", |
| "BANK_ROUTING_NUMBER", |
| "CREDIT_DEBIT_CARD", |
| "PASSPORT_NUMBER", |
| "POSTCODE", |
| "PHONE_NUMBER", |
| "EMAIL", |
| "FIRST_NAME", |
| "LAST_NAME", |
| "SWIFT_BIC", |
| } |
|
|
|
|
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE into (token, start, end) triples."""
    triples = []
    for match in TOKEN_RE.finditer(text):
        triples.append((match.group(0), match.start(), match.end()))
    return triples
|
|
|
|
def normalize_label(label: str) -> str:
    """Strip a BIO prefix ("B-" / "I-") from *label* and upper-case it.

    ``None`` and empty/whitespace-only input normalize to the empty string.
    The prefix check is case-sensitive, matching conventional BIO tags.
    """
    cleaned = label.strip() if label else ""
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
|
|
|
|
def looks_like_eircode(value: str) -> bool:
    """Return True when *value*, after trimming, matches the Eircode pattern."""
    candidate = value.strip()
    return bool(EIRCODE_RE.match(candidate))
|
|
|
|
def ppsn_label_ids(model):
    """Return the sorted class ids whose label name ends with "PPSN".

    Matches both plain ("PPSN") and BIO-tagged ("B-PPSN" / "I-PPSN") label
    schemes; ids and labels are coerced since config files may store them
    as strings.
    """
    matching = [
        int(raw_id)
        for raw_id, raw_label in model.config.id2label.items()
        if str(raw_label or "").strip().endswith("PPSN")
    ]
    return sorted(matching)
|
|
|
|
def word_aligned_ppsn_spans(text: str, model, tokenizer, threshold: float):
    """Decode PPSN spans by scoring whole words rather than subword tokens.

    Each word from ``tokenize_with_spans`` is scored with the maximum
    PPSN-class probability over its subword tokens; consecutive words at or
    above *threshold* are merged into a single character-offset span.

    Returns a list of dicts with ``start``, ``end``, ``score``, ``label``
    (always "PPSN") and ``text`` keys.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []
    words = [word for word, _, _ in pieces]
    # Pre-split input keeps a 1:1 mapping between our words and the fast
    # tokenizer's word_ids(). NOTE(review): truncation=True silently drops
    # trailing words of very long inputs — confirm acceptable.
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    label_ids = ppsn_label_ids(model)
    word_scores = []
    # Per-word score = max PPSN probability across the word's subword tokens.
    # NOTE(review): rescans word_ids once per word (O(words * tokens)); fine
    # for short CLI inputs.
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue  # also skips special tokens, whose word id is None
            for label_id in label_ids:
                score = max(score, float(probs[token_index, label_id]))
        word_scores.append(score)
    spans = []
    active = None
    # Merge runs of consecutive above-threshold words into single spans;
    # the merged span keeps the maximum word score seen.
    for (_, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(active["score"], score)
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    for span in spans:
        span["label"] = "PPSN"
        span["text"] = text[span["start"]:span["end"]]
    return spans
|
|
|
|
def merge_spans(text: str, general_spans: list[dict], ppsn_spans: list[dict], other_min_score: float):
    """Combine pipeline NER spans with dedicated PPSN spans.

    Keeps general spans whose normalized label is in ALLOWED (PPSN excluded —
    it has its own decoder) and whose score meets *other_min_score*; then adds
    PPSN spans that neither look like an Eircode nor overlap a kept span.
    Returns the spans sorted by (start, end).
    """
    merged = []
    for raw in general_spans:
        label = normalize_label(raw.get("entity_group") or raw.get("entity") or "")
        if label == "PPSN" or label not in ALLOWED:
            continue
        if float(raw.get("score", 0.0)) < other_min_score:
            continue
        start = int(raw["start"])
        end = int(raw["end"])
        merged.append({
            "label": label,
            "start": start,
            "end": end,
            "score": float(raw["score"]),
            "text": text[start:end],
        })

    def _disjoint(a, b):
        # Half-open character ranges: touching spans do not overlap.
        return a["end"] <= b["start"] or b["end"] <= a["start"]

    for candidate in ppsn_spans:
        # Eircodes share a letter-digit shape with PPSNs; drop false positives.
        if looks_like_eircode(candidate["text"]):
            continue
        # General-model spans win on overlap.
        if all(_disjoint(candidate, kept) for kept in merged):
            merged.append(candidate)

    merged.sort(key=lambda item: (item["start"], item["end"]))
    return merged
|
|
|
|
def mask_text(text: str, spans: list[dict]) -> str:
    """Replace each span of *text* with a "[LABEL]" placeholder.

    Spans are applied right-to-left so earlier character offsets stay valid
    while the string shrinks/grows.
    """
    masked = text
    ordered = sorted(spans, key=lambda item: (item["start"], item["end"]), reverse=True)
    for span in ordered:
        head = masked[:span["start"]]
        tail = masked[span["end"]:]
        masked = head + "[" + span["label"] + "]" + tail
    return masked
|
|
|
|
def main():
    """CLI entry point: detect PII spans in --text and print the masked text.

    With --json, prints a full JSON report (spans, thresholds, model id)
    instead of just the masked string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=".")  # local model dir or hub id
    parser.add_argument("--text", required=True)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--ppsn-min-score", type=float, default=0.4)
    parser.add_argument("--other-min-score", type=float, default=0.5)
    parser.add_argument("--json", action="store_true")  # emit full JSON report
    args = parser.parse_args()

    # `fix_mistral_regex` is not accepted by every transformers version, so
    # fall back progressively to a plain fast-tokenizer load.
    # NOTE(review): the outer `except Exception` also retries on unrelated
    # load failures (bad path, corrupt files) — confirm that is intended.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, fix_mistral_regex=True)
    except Exception:
        try:
            tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, fix_mistral_regex=False)
        except TypeError:
            tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(args.model)
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    model.to(device)
    model.eval()  # inference only: disable dropout etc.

    # General NER pass via the aggregating pipeline; PPSN is decoded
    # separately with the word-aligned decoder and merged afterwards.
    nlp = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple", device=0 if device == "cuda" else -1)
    general = nlp(args.text)
    ppsn = word_aligned_ppsn_spans(args.text, model, tokenizer, threshold=args.ppsn_min_score)
    spans = merge_spans(args.text, general, ppsn, other_min_score=args.other_min_score)
    result = {
        "model": args.model,
        "masked_text": mask_text(args.text, spans),
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
    }
    if args.json:
        print(json.dumps(result, indent=2, ensure_ascii=False))
    else:
        print(result["masked_text"])
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|