| |
| import argparse |
| import json |
|
|
| from onnx_token_classifier import ( |
| load_onnx_token_classifier, |
| looks_like_eircode, |
| normalize_label, |
| simple_aggregate_spans_onnx, |
| word_aligned_ppsn_spans_onnx, |
| ) |
|
|
|
|
# Entity labels (as returned by normalize_label) that may appear in the masked
# output. PPSN is listed for completeness, but main() drops PPSN spans coming
# from the general decoder and uses the dedicated word-aligned PPSN decoder.
ALLOWED = {
    # Personal identifiers
    "PPSN",
    "PASSPORT_NUMBER",
    "FIRST_NAME",
    "LAST_NAME",
    # Financial identifiers
    "ACCOUNT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
    # Contact details
    "POSTCODE",
    "PHONE_NUMBER",
    "EMAIL",
}
|
|
|
|
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by a ``[LABEL]`` placeholder.

    Each span is a dict with integer ``start``/``end`` character offsets into
    *text* (half-open, i.e. ``text[start:end]``) and a ``label``. Spans may be
    supplied in any order.

    Overlapping or nested spans are merged into a single masked region. That
    region carries the label of the span that starts first; on a tie, the
    longest span wins. This guarantees that no character covered by any span
    leaks into the output, and that one placeholder is never spliced into
    another.

    Args:
        text: The original, unmasked text.
        spans: Span dicts with ``start``, ``end`` and ``label`` keys.

    Returns:
        The masked text.
    """
    pieces: list[str] = []
    cursor = 0  # offset in *text* up to which output has been emitted or masked
    # Ascending start; for equal starts, the longer span comes first so its
    # label is the one used for the merged region.
    for span in sorted(spans, key=lambda item: (item["start"], -item["end"])):
        start, end = span["start"], span["end"]
        if start < cursor:
            # Overlaps a region that is already masked: widen that region
            # instead of inserting a second placeholder inside it.
            cursor = max(cursor, end)
            continue
        pieces.append(text[cursor:start])
        pieces.append(f"[{span['label']}]")
        # max() keeps the cursor sane even for a malformed span with end < start.
        cursor = max(start, end)
    pieces.append(text[cursor:])
    return "".join(pieces)
|
|
|
|
def _spans_overlap(a: dict, b: dict) -> bool:
    """Return True if the half-open ranges ``[start, end)`` of *a* and *b* intersect."""
    return a["start"] < b["end"] and b["start"] < a["end"]


def _select_spans(general: list[dict], ppsn: list[dict]) -> list[dict]:
    """Combine general-decoder spans and PPSN-decoder spans into one sorted list.

    General spans are kept when their normalized label is in ALLOWED, except
    PPSN. PPSN spans come only from the dedicated *ppsn* list instead.

    A PPSN candidate is dropped if either:
    - ``looks_like_eircode`` accepts its ``text``, or
    - it overlaps any span already kept (including earlier PPSN spans).

    Args:
        general: Spans from the general aggregating decoder. Each is read for
            ``label`` (and, via the overlap check, ``start``/``end``).
        ppsn: Spans from the word-aligned PPSN decoder. Each is read for
            ``text``, ``start`` and ``end``.

    Returns:
        The kept spans, sorted by ``(start, end)``.
    """
    spans = []
    for span in general:
        label = normalize_label(span["label"])
        # NOTE(review): filtering uses the normalized label, but the span keeps
        # its raw label, which is what mask_text writes into the output.
        # Confirm that is intended.
        if label in ALLOWED and label != "PPSN":
            spans.append(span)
    for span in ppsn:
        # Presumably guards against Irish postcodes mis-tagged as PPSN
        # (inferred from the helper's name; verify in onnx_token_classifier).
        if looks_like_eircode(span["text"]):
            continue
        if any(_spans_overlap(span, kept) for kept in spans):
            continue
        spans.append(span)
    spans.sort(key=lambda item: (item["start"], item["end"]))
    return spans


def main(argv=None):
    """Command-line entry point: mask PII entities in ``--text`` using an ONNX model.

    The CLI does the following:
    1. Loads the model from ``--model``.
    2. Runs two decoders over the text:
       - the general span aggregator, with ``--other-min-score``;
       - the word-aligned PPSN decoder, with ``--ppsn-min-score``.
    3. Merges their spans.
    4. Prints either the masked text or, with ``--json``, a JSON report
       containing the masked text, the spans and the settings used.

    Args:
        argv: Optional argument list. When None (the default), arguments are
            read from ``sys.argv`` exactly as before.
    """
    parser = argparse.ArgumentParser(
        description="Mask PII entities in text with an ONNX token classifier."
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model location passed to load_onnx_token_classifier (default: %(default)s).",
    )
    parser.add_argument("--text", required=True, help="Text to mask.")
    parser.add_argument(
        "--ppsn-min-score",
        type=float,
        default=0.4,
        help="Threshold for the word-aligned PPSN decoder (default: %(default)s).",
    )
    parser.add_argument(
        "--other-min-score",
        type=float,
        default=0.5,
        help="Minimum score for spans from the general decoder (default: %(default)s).",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print a JSON report instead of only the masked text.",
    )
    args = parser.parse_args(argv)

    session, tokenizer, config = load_onnx_token_classifier(args.model)
    general = simple_aggregate_spans_onnx(
        args.text, session, tokenizer, config, min_score=args.other_min_score
    )
    ppsn = word_aligned_ppsn_spans_onnx(
        args.text, session, tokenizer, config, threshold=args.ppsn_min_score
    )
    spans = _select_spans(general, ppsn)

    result = {
        "model": args.model,
        "masked_text": mask_text(args.text, spans),
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
        "backend": "onnx",
    }
    if args.json:
        print(json.dumps(result, indent=2, ensure_ascii=False))
    else:
        print(result["masked_text"])
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|