| |
| import regex as re |
| import torch |
|
|
| from raw_word_aligned import word_aligned_ppsn_spans |
|
|
|
|
# Word tokenizer: runs of ASCII alphanumerics, or any single non-word,
# non-space character (whitespace itself is never emitted as a token).
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# Irish phone number: "+353" (optionally followed by "(0)") plus 8-9 digits,
# or a domestic 0-prefixed number; digits may be separated by spaces/hyphens.
PHONE_RE = re.compile(r"^(?:\+353\s?(?:\(0\))?\s?\d(?:[\s-]?\d){7,8}|0\d(?:[\s-]?\d){7,8})$")
# Passport number: two uppercase letters, optional space, seven digits.
PASSPORT_RE = re.compile(r"^[A-Z]{2}\s?\d{7}$")
# Bank sort code: six digits, either contiguous or grouped in pairs
# separated by a space or hyphen.
SORT_RE = re.compile(r"^(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})$")
# Irish IBAN: "IE" + 2 check digits + 4-letter bank code + 14 digits,
# optionally spaced in the conventional 4-character groups.
IBAN_IE_RE = re.compile(r"^IE\d{2}(?:\s?[A-Z]{4})(?:\s?\d{4}){3}\s?\d{2}$")
# SWIFT/BIC: 4-letter institution code, 2-letter country, 2-char location,
# optional 3-char branch suffix.
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
# Eircode: routing key (valid letter + 2 digits, or the special key "D6W"),
# optional space, then a 4-char unique identifier; case-insensitive.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)
|
|
# Per-label score thresholds for the word-aligned "repair" detectors; callers
# may override any entry (override keys are upper-cased before merging).
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}


# Labels whose surface text can be format-validated (the threshold table's
# keys); POSTCODE is validated too but handled separately where needed.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)
# Tie-break priority when de-duplicating overlapping spans: lower value wins;
# unknown labels fall back to 99 in dedupe_and_sort.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
|
|
|
|
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE into (token, start, end) triples."""
    triples = []
    for match in TOKEN_RE.finditer(text):
        triples.append((match.group(0), match.start(), match.end()))
    return triples
|
|
|
|
def normalize_label(label: str) -> str:
    """Strip a BIO prefix ("B-"/"I-") from *label* and upper-case it.

    None or empty input yields "".
    """
    cleaned = (label or "").strip()
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
|
|
|
|
def luhn_ok(value: str) -> bool:
    """Return True when *value* holds 13-19 digits that pass the Luhn check.

    Non-digit characters (spaces, hyphens, …) are ignored.
    """
    digits = [int(ch) for ch in value if ch.isdigit()]
    if not 13 <= len(digits) <= 19:
        return False
    checksum = 0
    for position, digit in enumerate(reversed(digits)):
        # Every second digit from the right is doubled (9 subtracted if >9).
        if position % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit
    return checksum % 10 == 0
|
|
|
|
def plausible_label_text(label: str, value: str) -> bool:
    """Format-level sanity check: does *value* plausibly match *label*?

    Unknown labels are never rejected (returns True).
    """
    value = value.strip()
    if label == "ACCOUNT_NUMBER":
        # Either an Irish IBAN, or a plain 8-digit domestic account number.
        compact = value.replace(" ", "")
        return IBAN_IE_RE.match(value) is not None or (compact.isdigit() and len(compact) == 8)
    if label == "CREDIT_DEBIT_CARD":
        return luhn_ok(value)
    pattern = {
        "PHONE_NUMBER": PHONE_RE,
        "PASSPORT_NUMBER": PASSPORT_RE,
        "BANK_ROUTING_NUMBER": SORT_RE,
        "SWIFT_BIC": BIC_RE,
        "POSTCODE": EIRCODE_RE,
    }.get(label)
    if pattern is not None:
        return pattern.match(value) is not None
    return True
|
|
|
|
def label_ids_from_mapping(id2label, label: str):
    """Return every id in *id2label* whose BIO-stripped name equals *label*."""
    target = label.upper()
    return [
        int(raw_id)
        for raw_id, raw_label in id2label.items()
        if normalize_label(str(raw_label)) == target
    ]
|
|
|
|
def label_ids(model, label: str):
    """Convenience wrapper: look up *label*'s ids from the model's config."""
    mapping = model.config.id2label
    return label_ids_from_mapping(mapping, label)
|
|
|
|
def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Score each word of *text* with the model's probability for *label*.

    Returns (pieces, scores): pieces is tokenize_with_spans(text) and
    scores[i] is the maximum probability any sub-token of word i received
    for any label id mapped to *label* (0.0 when the word was truncated
    away or the label is unknown to the model).
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    # One pass over tokens instead of rescanning every token for every word
    # (was O(words * tokens)); special tokens have word_id None and are skipped.
    scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        if wid is None:
            continue
        for label_id in ids:
            value = float(probs[token_index, label_id])
            if value > scores[wid]:
                scores[wid] = value
    return pieces, scores
|
|
|
|
def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """Score each word of *text* for *label* using an ONNX session.

    Mirrors word_scores_for_label but runs the exported model through the
    helpers in onnx_token_classifier; returns (pieces, scores) with one
    score per word.
    """
    # Imported lazily so the torch-only code path has no ONNX dependency.
    from onnx_token_classifier import _run_onnx, _softmax

    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    # One pass over tokens instead of rescanning every token for every word
    # (was O(words * tokens)); special tokens have word_id None and are skipped.
    scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        if wid is None:
            continue
        for label_id in ids:
            value = float(probs[token_index, label_id])
            if value > scores[wid]:
                scores[wid] = value
    return pieces, scores
|
|
|
|
def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Build contiguous character spans for *label* from per-word scores.

    A word joins a span when its score reaches *threshold*; for separator
    characters inside phone/sort-code/card values the bar is halved, but
    only while a span is already open.  Adjacent kept words (gap <= 1 char)
    merge into one span.  Spans failing plausible_label_text are discarded.
    """
    separator_labels = {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"}
    merged = []
    current = None
    for (token, tok_start, tok_end), tok_score in zip(pieces, scores):
        above = tok_score >= threshold
        if label in separator_labels and token in {"-", "/"}:
            above = current is not None and tok_score >= threshold / 2.0
        if above:
            if current is not None and tok_start - current["end"] <= 1:
                current["end"] = tok_end
            else:
                if current is not None:
                    merged.append(current)
                current = {"start": tok_start, "end": tok_end, "label": label}
        elif current is not None:
            merged.append(current)
            current = None
    if current is not None:
        merged.append(current)
    results = []
    for candidate in merged:
        snippet = text[candidate["start"] : candidate["end"]]
        if plausible_label_text(label, snippet):
            results.append(
                {
                    "label": label,
                    "start": candidate["start"],
                    "end": candidate["end"],
                    "text": snippet,
                }
            )
    return results
|
|
|
|
def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Detect *label* spans: score words with the torch model, then threshold."""
    word_pieces, word_scores = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, word_pieces, word_scores)
|
|
|
|
def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Detect *label* spans via the ONNX session, then threshold word scores."""
    word_pieces, word_scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, word_pieces, word_scores)
|
|
|
|
def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert HF-pipeline-style entity dicts into span dicts.

    Entries with an empty normalized label or a score below *min_score*
    are dropped; the span text is re-sliced from *text*.
    """
    spans = []
    for item in outputs:
        entity = normalize_label(item.get("entity_group") or item.get("entity") or "")
        if not entity:
            continue
        confidence = float(item.get("score", 0.0))
        if confidence < min_score:
            continue
        begin = int(item["start"])
        finish = int(item["end"])
        spans.append(
            {
                "label": entity,
                "start": begin,
                "end": finish,
                "score": confidence,
                "text": text[begin:finish],
            }
        )
    return spans
|
|
|
|
def overlaps(a: dict, b: dict) -> bool:
    """True when the half-open ranges [start, end) of *a* and *b* intersect."""
    return a["start"] < b["end"] and b["start"] < a["end"]
|
|
|
|
def span_length(span: dict) -> int:
    """Character length of *span* (end minus start, coerced to int)."""
    start, end = int(span["start"]), int(span["end"])
    return end - start
|
|
|
|
def normalize_simple_span(span: dict):
    """Canonicalize a pipeline span; return None when format validation fails.

    Phone-labelled values that pass the card Luhn check are relabelled as
    CREDIT_DEBIT_CARD before validation.
    """
    entity = normalize_label(span["label"])
    snippet = span["text"]
    # A "phone number" that is Luhn-valid is really a card number.
    if entity == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", snippet):
        entity = "CREDIT_DEBIT_CARD"
    needs_check = entity in FORMAT_LABELS or entity == "POSTCODE"
    if needs_check and not plausible_label_text(entity, snippet):
        return None
    return {
        "label": entity,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": snippet,
    }
|
|
|
|
def dedupe_and_sort(spans: list[dict]):
    """Sort spans and greedily drop ones overlapping an already-kept span.

    Order: start ascending, then longer spans first, then label priority
    (OUTPUT_PRIORITY, unknown labels last at 99).
    """

    def sort_key(span):
        return (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        )

    result: list[dict] = []
    for span in sorted(spans, key=sort_key):
        if all(not overlaps(span, kept_span) for kept_span in result):
            result.append(span)
    return result
|
|
|
|
def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
) -> list[dict]:
    """Merge general NER output with targeted word-aligned detectors.

    Steps:
      1. Normalize/validate *general_outputs* via pipeline_to_spans and
         normalize_simple_span; PPSN spans are dropped here because step 2
         re-detects them with the dedicated detector.
      2. Append PPSN spans from word_aligned_ppsn_spans.
      3. For each label in the threshold table, re-detect spans with
         word_aligned_label_spans and let longer candidates replace
         overlapping shorter spans.
      4. De-duplicate and sort with dedupe_and_sort.
    """
    # Start from the defaults; caller overrides are upper-cased to match.
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})


    spans = []
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        # PPSN comes exclusively from the dedicated detector below.
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)


    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        spans.append(
            {
                "label": "PPSN",
                "start": int(span["start"]),
                "end": int(span["end"]),
                "score": float(span.get("score", 0.0)),
                "text": text[int(span["start"]) : int(span["end"])],
            }
        )


    # Word-aligned re-detection for every format-style label.
    repairs = []
    for label, threshold in thresholds.items():
        repairs.extend(word_aligned_label_spans(text, model, tokenizer, label, threshold))


    # A candidate replaces an overlapping span when it is strictly longer and
    # either shares that span's label or both labels are format labels.  The
    # candidate itself is appended when it replaced something, or when it
    # overlaps nothing after the replacements.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)


    return dedupe_and_sort(spans)
|
|
|
|
def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
) -> list[dict]:
    """ONNX counterpart of repair_irish_core_spans.

    Same four steps (normalize general output, add PPSN spans, apply
    word-aligned repairs, de-duplicate), but every model call goes through
    the ONNX session.  When *general_outputs* is None the general pass is
    computed here via simple_aggregate_spans_onnx.
    """
    # Imported lazily so the torch-only code path has no ONNX dependency.
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx


    # Start from the defaults; caller overrides are upper-cased to match.
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})


    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        # PPSN comes exclusively from the dedicated detector below.
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)


    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        spans.append(
            {
                "label": "PPSN",
                "start": int(span["start"]),
                "end": int(span["end"]),
                "score": float(span.get("score", 0.0)),
                "text": text[int(span["start"]) : int(span["end"])],
            }
        )


    # Word-aligned re-detection for every format-style label.
    repairs = []
    for label, threshold in thresholds.items():
        repairs.extend(word_aligned_label_spans_onnx(text, session, tokenizer, config, label, threshold))


    # A candidate replaces an overlapping span when it is strictly longer and
    # either shares that span's label or both labels are format labels.  The
    # candidate itself is appended when it replaced something, or when it
    # overlaps nothing after the replacements.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)


    return dedupe_and_sort(spans)
|
|