#!/usr/bin/env python3
"""Repair and merge PII spans for Irish-format identifiers.

General token-classification output (pipeline-style dicts with ``entity_group``
or ``entity``, ``score``, ``start`` and ``end`` keys) is combined with
word-aligned PPSN spans and with per-label word-aligned "repair" spans for the
labels in ``DEFAULT_LABEL_THRESHOLDS``. Format-constrained spans are validated
with ``plausible_label_text`` before being kept, and the final list is
de-overlapped and sorted by ``dedupe_and_sort``. Both a PyTorch model path and
an ONNX session path are provided.
"""
import regex as re
import torch

from raw_word_aligned import word_aligned_ppsn_spans

TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
PHONE_RE = re.compile(r"^(?:\+353\s?(?:\(0\))?\s?\d(?:[\s-]?\d){7,8}|0\d(?:[\s-]?\d){7,8})$")
PASSPORT_RE = re.compile(r"^[A-Z]{2}\s?\d{7}$")
SORT_RE = re.compile(r"^(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})$")
IBAN_IE_RE = re.compile(r"^IE\d{2}(?:\s?[A-Z]{4})(?:\s?\d{4}){3}\s?\d{2}$")
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
EIRCODE_RE = re.compile(
    r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE
)

# Per-label minimum word probability for word-aligned repair spans.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}
# Labels that must pass a format check and may supersede each other when longer.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)
# Tie-break order in dedupe_and_sort (lower wins); unknown labels sort last (99).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}

# Labels for which a "-" or "/" token may bridge two kept tokens at half threshold.
_SEPARATOR_JOINED_LABELS = frozenset({"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"})
_SEPARATOR_TOKENS = frozenset({"-", "/"})


def tokenize_with_spans(text: str) -> list[tuple[str, int, int]]:
    """Split ``text`` into ASCII alphanumeric runs and single punctuation marks.

    Returns ``(token, start, end)`` tuples with character offsets into ``text``.
    Whitespace is skipped. Word characters that are not ASCII alphanumerics
    (e.g. accented letters, underscore) match neither alternative of
    ``TOKEN_RE`` and are skipped as well.
    """
    return [(m.group(0), m.start(), m.end()) for m in TOKEN_RE.finditer(text)]


def normalize_label(label: str) -> str:
    """Return ``label`` stripped, upper-cased and without a ``B-``/``I-`` prefix.

    The value is upper-cased *before* the prefix check so that lowercase tag
    schemes (``"b-phone_number"``) normalise the same as uppercase ones.
    ``None`` or an empty value yields ``""``.
    """
    label = (label or "").strip().upper()
    if label.startswith(("B-", "I-")):
        label = label[2:]
    return label


def luhn_ok(value: str) -> bool:
    """Return True when the decimal digits in ``value`` pass the Luhn checksum.

    All non-digit characters (spaces, hyphens, letters, ...) are ignored, and
    the number of digits must be between 13 and 19 inclusive. Digits are
    selected with ``str.isdecimal()`` rather than ``str.isdigit()``: the latter
    also accepts characters such as superscripts or circled digits, which
    ``int()`` rejects with ``ValueError``.
    """
    digits = [int(ch) for ch in value if ch.isdecimal()]
    if not 13 <= len(digits) <= 19:
        return False
    total = 0
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 1:  # every second digit, counting from the right, is doubled
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0


def plausible_label_text(label: str, value: str) -> bool:
    """Return whether ``value`` (after stripping) has a valid format for ``label``.

    PHONE_NUMBER, PASSPORT_NUMBER, BANK_ROUTING_NUMBER, SWIFT_BIC and POSTCODE
    are checked against their regular expressions. ACCOUNT_NUMBER accepts an
    Irish IBAN or exactly 8 digits once spaces are removed. CREDIT_DEBIT_CARD
    must pass ``luhn_ok``. Any other label is always accepted.
    """
    value = value.strip()
    if label == "PHONE_NUMBER":
        return PHONE_RE.match(value) is not None
    if label == "PASSPORT_NUMBER":
        return PASSPORT_RE.match(value) is not None
    if label == "BANK_ROUTING_NUMBER":
        return SORT_RE.match(value) is not None
    if label == "ACCOUNT_NUMBER":
        compact = value.replace(" ", "")
        return IBAN_IE_RE.match(value) is not None or (compact.isdigit() and len(compact) == 8)
    if label == "CREDIT_DEBIT_CARD":
        return luhn_ok(value)
    if label == "SWIFT_BIC":
        return BIC_RE.match(value) is not None
    if label == "POSTCODE":
        return EIRCODE_RE.match(value) is not None
    return True


def label_ids_from_mapping(id2label, label: str) -> list[int]:
    """Return every class id whose normalised label in ``id2label`` equals ``label``.

    Both ``B-X`` and ``I-X`` (and bare ``X``) ids are returned for label ``X``.
    """
    target = label.upper()
    return [
        int(raw_id)
        for raw_id, raw_label in id2label.items()
        if normalize_label(str(raw_label)) == target
    ]


def label_ids(model, label: str) -> list[int]:
    """Return the class ids for ``label`` from ``model.config.id2label``."""
    return label_ids_from_mapping(model.config.id2label, label)


def _max_scores_per_word(word_ids, probs, ids, n_words: int) -> list[float]:
    """Return, per word, the max probability over its sub-tokens and ``ids``.

    ``word_ids[t]`` is the word index of sub-token ``t`` (``None`` for special
    tokens). ``probs`` is indexed as ``probs[token, class]``. Words with no
    surviving sub-token (e.g. cut off by truncation) score 0.0. Runs in one
    pass over the sub-tokens instead of one pass per word.
    """
    scores = [0.0] * n_words
    for token_index, wid in enumerate(word_ids):
        if wid is None or wid >= n_words:
            continue
        for label_id in ids:
            value = float(probs[token_index, label_id])
            if value > scores[wid]:
                scores[wid] = value
    return scores


def _torch_word_probs(text: str, model, tokenizer):
    """Tokenize ``text`` into words and run ``model`` once over them.

    Returns ``(pieces, word_ids, probs)``, where ``probs`` holds softmax
    probabilities for the first batch item, moved to the CPU so that per-element
    reads do not each synchronise with an accelerator. When ``text`` has no
    tokens, returns ``([], [], None)`` without calling the tokenizer or model.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, [], None
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1).cpu()
    return pieces, word_ids, probs


def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Return ``(pieces, scores)``: word tokens of ``text`` and their ``label`` score.

    Each score is the maximum probability of any ``label`` class id across the
    word's sub-tokens, from a PyTorch token-classification ``model``.
    """
    pieces, word_ids, probs = _torch_word_probs(text, model, tokenizer)
    if not pieces:
        return pieces, []
    return pieces, _max_scores_per_word(word_ids, probs, label_ids(model, label), len(pieces))


def _onnx_word_probs(text: str, session, tokenizer):
    """ONNX counterpart of ``_torch_word_probs``; returns ``(pieces, word_ids, probs)``."""
    from onnx_token_classifier import _run_onnx, _softmax

    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, [], None
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    return pieces, word_ids, probs


def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX counterpart of ``word_scores_for_label``; label ids come from ``config.id2label``."""
    pieces, word_ids, probs = _onnx_word_probs(text, session, tokenizer)
    if not pieces:
        return pieces, []
    ids = label_ids_from_mapping(config.id2label, label)
    return pieces, _max_scores_per_word(word_ids, probs, ids, len(pieces))


def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Group consecutive high-scoring words into validated ``label`` spans.

    A word is kept when its score is at least ``threshold``. Kept words separated
    by at most one character join into one span. For labels in
    ``_SEPARATOR_JOINED_LABELS``, a ``-`` or ``/`` token is kept only when it
    directly extends an open span and scores at least ``threshold / 2``.
    Trailing separators are trimmed from the span, so "087 123 4567 -" still
    yields the phone number. Each span is returned only when
    ``plausible_label_text`` accepts its text.

    Returns a list of ``{"label", "start", "end", "text"}`` dicts.
    """
    bridges = label in _SEPARATOR_JOINED_LABELS
    ranges: list[tuple[int, int]] = []
    # Open span as [start, end of last kept token, end of last kept non-separator token].
    current: list[int] | None = None
    for (word, start, end), score in zip(pieces, scores):
        is_separator = bridges and word in _SEPARATOR_TOKENS
        contiguous = current is not None and start - current[1] <= 1
        if is_separator:
            keep = contiguous and score >= threshold / 2.0
        else:
            keep = score >= threshold
        if keep and contiguous:
            current[1] = end
            if not is_separator:
                current[2] = end
            continue
        if current is not None:
            ranges.append((current[0], current[2]))
            current = None
        if keep:
            # Only non-separators get here: separators require an adjacent open span.
            current = [start, end, end]
    if current is not None:
        ranges.append((current[0], current[2]))

    out = []
    for start, end in ranges:
        value = text[start:end]
        if plausible_label_text(label, value):
            out.append({"label": label, "start": start, "end": end, "text": value})
    return out


def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Return validated word-aligned ``label`` spans for ``text`` using a PyTorch model."""
    pieces, scores = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Return validated word-aligned ``label`` spans for ``text`` using an ONNX session."""
    pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def pipeline_to_spans(text: str, outputs: list[dict], min_score: float) -> list[dict]:
    """Convert pipeline-style entity dicts to span dicts, dropping low scores.

    The label is read from ``entity_group`` (falling back to ``entity``) and
    normalised. Entries with an empty label or with ``score < min_score``
    (missing score counts as 0.0) are skipped.
    """
    spans = []
    for output in outputs:
        label = normalize_label(output.get("entity_group") or output.get("entity") or "")
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        start, end = int(output["start"]), int(output["end"])
        spans.append(
            {"label": label, "start": start, "end": end, "score": score, "text": text[start:end]}
        )
    return spans


def overlaps(a: dict, b: dict) -> bool:
    """Return True when half-open spans ``[start, end)`` of ``a`` and ``b`` intersect."""
    return not (a["end"] <= b["start"] or b["end"] <= a["start"])


def span_length(span: dict) -> int:
    """Return the character length of ``span``."""
    return int(span["end"]) - int(span["start"])


def normalize_simple_span(span: dict):
    """Normalise one pipeline span, or return None when its text is implausible.

    A PHONE_NUMBER whose text passes the Luhn check is relabelled as
    CREDIT_DEBIT_CARD. Format labels and POSTCODE must pass
    ``plausible_label_text``.
    """
    label = normalize_label(span["label"])
    value = span["text"]
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    if (label in FORMAT_LABELS or label == "POSTCODE") and not plausible_label_text(label, value):
        return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
    }


def dedupe_and_sort(spans: list[dict]) -> list[dict]:
    """Return non-overlapping spans, greedily chosen in a fixed order.

    Spans are ordered by start offset, then longest first, then
    ``OUTPUT_PRIORITY``. Any span overlapping an already-kept span is dropped.
    """
    ordered = sorted(
        spans,
        key=lambda span: (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        ),
    )
    kept = []
    for span in ordered:
        if any(overlaps(span, other) for other in kept):
            continue
        kept.append(span)
    return kept


def _resolve_thresholds(label_thresholds: dict[str, float] | None) -> dict[str, float]:
    """Merge caller overrides (keys upper-cased) into ``DEFAULT_LABEL_THRESHOLDS``."""
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    return thresholds


def _general_spans(text: str, outputs: list[dict], min_score: float) -> list[dict]:
    """Return normalised, plausible, non-PPSN spans from general pipeline output."""
    spans = []
    for span in pipeline_to_spans(text, outputs, min_score=min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    return spans


def _ppsn_span_dicts(text: str, ppsn_spans) -> list[dict]:
    """Convert word-aligned PPSN results into PPSN span dicts over ``text``."""
    out = []
    for span in ppsn_spans:
        start, end = int(span["start"]), int(span["end"])
        out.append(
            {
                "label": "PPSN",
                "start": start,
                "end": end,
                "score": float(span.get("score", 0.0)),
                "text": text[start:end],
            }
        )
    return out


def _label_repairs(text: str, pieces, word_ids, probs, id2label, thresholds) -> list[dict]:
    """Build repair spans for every thresholded label from one set of model probabilities."""
    repairs = []
    for label, threshold in thresholds.items():
        if pieces:
            ids = label_ids_from_mapping(id2label, label)
            scores = _max_scores_per_word(word_ids, probs, ids, len(pieces))
        else:
            scores = []
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
    return repairs


def _supersedes(candidate: dict, span: dict) -> bool:
    """Return True when ``candidate`` should replace an overlapping ``span``.

    The candidate must be strictly longer, and either share the span's label or
    both labels must be format labels.
    """
    if span_length(candidate) <= span_length(span):
        return False
    if candidate["label"] == span["label"]:
        return True
    return candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS


def _merge_repairs(spans: list[dict], repairs: list[dict]) -> list[dict]:
    """Fold repair candidates into ``spans`` one at a time.

    Overlapping spans that a candidate supersedes are removed. The candidate is
    added when it replaced something or no longer overlaps any span.
    """
    for candidate in repairs:
        kept = []
        replaced = False
        for span in spans:
            if overlaps(candidate, span) and _supersedes(candidate, span):
                replaced = True
            else:
                kept.append(span)
        spans = kept
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return spans


def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Merge general, PPSN and word-aligned repair spans for ``text`` (PyTorch model).

    Args:
        text: Input text; all offsets refer to it.
        model: Token-classification model exposing ``config.id2label``.
        tokenizer: Tokenizer supporting ``is_split_into_words`` and ``word_ids``.
        general_outputs: Pipeline-style entity dicts for ``text``.
        other_min_score: Minimum score for spans taken from ``general_outputs``.
        ppsn_min_score: Threshold passed to ``word_aligned_ppsn_spans``.
        label_thresholds: Optional per-label overrides of
            ``DEFAULT_LABEL_THRESHOLDS``; keys are upper-cased.

    Returns:
        Non-overlapping span dicts sorted by ``dedupe_and_sort``.
    """
    thresholds = _resolve_thresholds(label_thresholds)
    spans = _general_spans(text, general_outputs, other_min_score)
    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    spans.extend(_ppsn_span_dicts(text, ppsn_spans))
    # One forward pass serves every repair label; the probabilities do not depend on the label.
    pieces, word_ids, probs = _torch_word_probs(text, model, tokenizer)
    repairs = _label_repairs(text, pieces, word_ids, probs, model.config.id2label, thresholds)
    return dedupe_and_sort(_merge_repairs(spans, repairs))


def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of ``repair_irish_core_spans``.

    When ``general_outputs`` is None, it is computed with
    ``simple_aggregate_spans_onnx`` at ``other_min_score``. Label ids come from
    ``config.id2label``.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx

    thresholds = _resolve_thresholds(label_thresholds)
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    spans = _general_spans(text, general_outputs, other_min_score)
    ppsn_spans = word_aligned_ppsn_spans_onnx(
        text, session, tokenizer, config, threshold=ppsn_min_score
    )
    spans.extend(_ppsn_span_dicts(text, ppsn_spans))
    # One session run serves every repair label; the probabilities do not depend on the label.
    pieces, word_ids, probs = _onnx_word_probs(text, session, tokenizer)
    repairs = _label_repairs(text, pieces, word_ids, probs, config.id2label, thresholds)
    return dedupe_and_sort(_merge_repairs(spans, repairs))