temsa's picture
Upload folder using huggingface_hub
d49b0d0 verified
#!/usr/bin/env python3
import argparse
import json
from onnx_token_classifier import (
load_onnx_token_classifier,
looks_like_eircode,
normalize_label,
simple_aggregate_spans_onnx,
word_aligned_ppsn_spans_onnx,
)
# Normalized entity labels that main() accepts when filtering detected spans.
# "PPSN" is listed here, but main() excludes it from the general-decoder
# spans and takes PPSN spans from the dedicated word-aligned decoder instead.
ALLOWED = {
    # Personal names and contact details
    "FIRST_NAME",
    "LAST_NAME",
    "EMAIL",
    "PHONE_NUMBER",
    "POSTCODE",
    # Government / identity documents
    "PPSN",
    "PASSPORT_NUMBER",
    # Financial identifiers
    "ACCOUNT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
}
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by its ``[LABEL]`` placeholder.

    Args:
        text: The original, unmasked string.
        spans: Dicts carrying ``"start"`` and ``"end"`` character offsets
            into *text* (end exclusive) and a ``"label"`` used as the
            placeholder name. Order does not matter.

    Returns:
        The masked string. Text outside all spans is copied unchanged.

    Overlapping spans are handled without leaking masked characters:
    spans are visited in ``(start, end)`` order, a span lying entirely
    inside an already-masked region is dropped, and a span that only
    partially overlaps is clipped so it starts where the previous mask
    ended. Non-overlapping input produces one placeholder per span.
    """
    pieces = []
    cursor = 0  # Offset in *text* up to which output has been produced.
    for span in sorted(spans, key=lambda item: (item["start"], item["end"])):
        start, end = span["start"], span["end"]
        if start < cursor:
            if end <= cursor:
                # Entirely covered by an earlier mask: nothing left to hide.
                continue
            # Partial overlap: only the uncovered tail still needs masking.
            start = cursor
        # Slices are always taken from the untouched original text, so
        # offsets never drift as placeholders of different length are added.
        pieces.append(text[cursor:start])
        pieces.append(f"[{span['label']}]")
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)
def main():
    """Run the CLI: detect spans in ``--text`` with the ONNX model and print the result.

    Spans come from two decoders run over the same text: the general
    aggregate decoder (thresholded by ``--other-min-score``) and the
    word-aligned PPSN decoder (thresholded by ``--ppsn-min-score``).
    General spans are kept when their normalized label is in ``ALLOWED``
    and is not ``"PPSN"``. PPSN spans are then added unless
    ``looks_like_eircode`` flags their text (presumably an Irish postcode
    false positive -- confirm against the helper) or they overlap a span
    that has already been kept.

    Prints the masked text, or the full result as JSON when ``--json`` is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default=".")
    cli.add_argument("--text", required=True)
    cli.add_argument("--ppsn-min-score", type=float, default=0.4)
    cli.add_argument("--other-min-score", type=float, default=0.5)
    cli.add_argument("--json", action="store_true")
    args = cli.parse_args()

    session, tokenizer, config = load_onnx_token_classifier(args.model)
    general = simple_aggregate_spans_onnx(
        args.text, session, tokenizer, config, min_score=args.other_min_score
    )
    ppsn = word_aligned_ppsn_spans_onnx(
        args.text, session, tokenizer, config, threshold=args.ppsn_min_score
    )

    # The general decoder contributes every allowed label except PPSN.
    non_ppsn_labels = ALLOWED - {"PPSN"}
    spans = [
        found
        for found in general
        if normalize_label(found["label"]) in non_ppsn_labels
    ]

    # PPSN candidates are checked against everything kept so far,
    # including PPSN spans accepted earlier in this same loop.
    for candidate in ppsn:
        if looks_like_eircode(candidate["text"]):
            continue
        collides = any(
            candidate["end"] > kept["start"] and kept["end"] > candidate["start"]
            for kept in spans
        )
        if not collides:
            spans.append(candidate)

    spans.sort(key=lambda item: (item["start"], item["end"]))

    result = {
        "model": args.model,
        "masked_text": mask_text(args.text, spans),
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
        "backend": "onnx",
    }

    if not args.json:
        print(result["masked_text"])
        return
    print(json.dumps(result, indent=2, ensure_ascii=False))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()