#!/usr/bin/env python3
"""Detect and repair Irish PII spans in free text.

Combines token-classifier scores (PyTorch model or ONNX session) with
format-specific regular expressions and checksum validators (Luhn, IBAN
mod-97, Eircode, PPSN) to produce word-aligned, validated, de-duplicated
entity spans.
"""
import regex as re
import torch

from eircode import iter_eircode_candidates, is_valid_eircode
from ppsn import is_plausible_ppsn, iter_ppsn_candidates
from raw_word_aligned import word_aligned_ppsn_spans

# Words are runs of ASCII alphanumerics; every other non-space character is
# its own single-character token.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)

# Full-string match for Irish phone numbers: +353 (optionally "+353 (0)")
# or a 0-prefixed national form, with an area/prefix group of 1-3 digits
# followed by 7-9 more digits allowing space/hyphen separators.
PHONE_RE = re.compile(
    r"^(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}"
    r"|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})$"
)

# ---------------------------------------------------------------------------
# NOTE(review): every definition in this section was corrupted in the copy of
# the file this was recovered from (extraction damage ate the literal text).
# The patterns and tables below are best-effort reconstructions of the
# standard Irish PII formats the rest of the module validates against.
# Confirm each one against the original source / test suite before release.
# ---------------------------------------------------------------------------

# Candidate finders: permissive in-text patterns; precise validation happens
# later in plausible_label_text().
PHONE_CANDIDATE_RE = re.compile(
    r"(?<!\d)(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}"
    r"|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})(?!\d)"
)
PASSPORT_CANDIDATE_RE = re.compile(r"\b[A-Za-z]{2}\s?\d{7}\b")
BANK_ROUTING_CANDIDATE_RE = re.compile(r"(?<!\d)\d{2}[- ]?\d{2}[- ]?\d{2}(?!\d)")
IBAN_IE_CANDIDATE_RE = re.compile(r"\bIE\s?\d{2}(?:[\s-]?[A-Za-z0-9]{4}){4}[\s-]?\d{2}\b")
CARD_CANDIDATE_RE = re.compile(r"(?<!\d)\d(?:[\s-]?\d){12,18}(?!\d)")
BIC_CANDIDATE_RE = re.compile(r"\b[A-Za-z]{6}[A-Za-z0-9]{2}(?:[A-Za-z0-9]{3})?\b")

# Exact-format validators.
# Irish passport: two letters followed by seven digits.
PASSPORT_RE = re.compile(r"^[A-Z]{2}\d{7}$")
# Irish sort code: six digits, optionally grouped in pairs by "-" or space.
SORT_RE = re.compile(r"^\d{2}[- ]?\d{2}[- ]?\d{2}$")
# Irish IBAN (ISO 13616): "IE" + 2 check digits + 4-letter bank code +
# 14 digits (6-digit sort code + 8-digit account) = 22 characters.
IBAN_IE_RE = re.compile(r"^IE\d{2}[A-Z]{4}\d{14}$")
# Card number written as 4-4-4-N groups with one consistent separator.
CARD_GROUPED_RE = re.compile(r"^\d{4}([ -])\d{4}\1\d{4}\1\d{1,7}$")
# SWIFT/BIC: 4-letter bank, 2-letter country, 2 alnum location, optional
# 3-alnum branch.
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
# Eircode: routing key (letter + two digits, or the special "D6W") plus a
# four-character unique identifier, optional single space between.
EIRCODE_RE = re.compile(
    r"^(?:D6W|[AC-FHKNPRTV-Y]\d{2})\s?[0-9AC-FHKNPRTV-Y]{4}$", re.IGNORECASE
)

# NOTE(review): reconstructed list — bank codes accepted for IE IBANs whose
# mod-97 check fails (see is_plausible_ie_iban); verify against the original.
KNOWN_IE_IBAN_BANK_CODES = {
    "AIBK", "BOFI", "ULSB", "IPBS", "PTSB", "KBCD",
    "REVO", "CITI", "BARC", "HSBC", "DABA",
}

# Labels validated purely by surface format.  POSTCODE is deliberately kept
# out of this set: normalize_simple_span() special-cases it.
FORMAT_LABELS = {
    "PPSN",
    "PHONE_NUMBER",
    "PASSPORT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "ACCOUNT_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
}

# NOTE(review): reconstructed ordering — lower value wins when overlapping
# spans tie on start offset and length in dedupe_and_sort(); verify.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "CREDIT_DEBIT_CARD": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "SWIFT_BIC": 4,
    "PASSPORT_NUMBER": 5,
    "PHONE_NUMBER": 6,
    "POSTCODE": 7,
}

# NOTE(review): reconstructed default per-label score thresholds — verify
# against the original tuning.  PPSN is driven by ppsn_min_score instead.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.35,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.35,
    "CREDIT_DEBIT_CARD": 0.35,
    "SWIFT_BIC": 0.35,
    "POSTCODE": 0.35,
}


def tokenize_with_spans(text: str) -> list[tuple[str, int, int]]:
    """Split *text* into ``(token, start, end)`` triples using TOKEN_RE.

    NOTE(review): reconstructed — the original body was lost, but every
    caller consumes exactly this shape.
    """
    return [(m.group(0), m.start(), m.end()) for m in TOKEN_RE.finditer(text)]


def normalize_label(label: str) -> str:
    """Strip a BIO prefix ("B-"/"I-") and upper-case *label*; None -> ""."""
    label = (label or "").strip()
    if label.startswith("B-") or label.startswith("I-"):
        label = label[2:]
    return label.upper()


def luhn_ok(value: str) -> bool:
    """Return True when the digits in *value* (13-19 of them) pass Luhn."""
    digits = "".join(ch for ch in value if ch.isdigit())
    if not (13 <= len(digits) <= 19):
        return False
    total = 0
    double = False
    # Standard Luhn: double every second digit from the right, folding >9.
    for ch in reversed(digits):
        number = int(ch)
        if double:
            number *= 2
            if number > 9:
                number -= 9
        total += number
        double = not double
    return total % 10 == 0


def iban_mod97_ok(value: str) -> bool:
    """Return True when *value* is a structurally valid Irish IBAN whose
    ISO 7064 mod-97 check digits verify."""
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    if not IBAN_IE_RE.match(compact):
        return False
    # Move the country/check prefix to the end, map letters to 10..35, and
    # reduce mod 97 digit-by-digit to avoid a huge integer.
    rearranged = compact[4:] + compact[:4]
    remainder = 0
    for ch in rearranged:
        digits = ch if ch.isdigit() else str(ord(ch) - ord("A") + 10)
        for digit in digits:
            remainder = (remainder * 10 + int(digit)) % 97
    return remainder == 1


def is_plausible_ie_iban(value: str) -> bool:
    """True for a well-formed IE IBAN: checksum valid, or known bank code."""
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    if not IBAN_IE_RE.match(compact):
        return False
    if iban_mod97_ok(compact):
        return True
    # Tolerate a failed checksum (e.g. OCR damage) when the 4-letter bank
    # code is a known Irish institution.
    return compact[4:8] in KNOWN_IE_IBAN_BANK_CODES


def normalize_irish_phone(value: str) -> str:
    """Canonicalize a phone string: drop "(0)", separators; 00353 -> +353."""
    compact = value.strip()
    compact = compact.replace("(0)", "0")
    compact = re.sub(r"[\s\-\(\)]", "", compact)
    if compact.startswith("00353"):
        compact = "+" + compact[2:]
    return compact


def is_valid_irish_phone(value: str) -> bool:
    """Validate *value* as an Irish number after normalization.

    Mobile prefixes (08x national / +353 8x) require exactly the mobile
    length; other prefixes accept landline lengths.
    """
    compact = normalize_irish_phone(value)
    if compact.startswith("+353"):
        rest = compact[4:]
        if rest.startswith("0"):
            rest = rest[1:]
        if not rest.isdigit():
            return False
        if rest.startswith("8"):
            return len(rest) == 9
        return len(rest) in {8, 9}
    if not compact.startswith("0") or not compact.isdigit():
        return False
    if compact.startswith("08"):
        return len(compact) == 10
    return len(compact) in {9, 10}


def is_plausible_card(value: str) -> bool:
    """True when *value* has 13-19 digits and passes Luhn, or at least looks
    like a consistently grouped card number."""
    digits = "".join(ch for ch in value if ch.isdigit())
    if not (13 <= len(digits) <= 19):
        return False
    if luhn_ok(value):
        return True
    return CARD_GROUPED_RE.match(value.strip()) is not None


def normalize_passport(value: str) -> str:
    """Upper-case *value* and remove all whitespace."""
    return re.sub(r"\s+", "", value.strip().upper())


def regex_candidates_for_label(text: str, label: str):
    """Yield candidate span dicts (start/end/text/normalized) for *label*.

    PPSN and POSTCODE delegate to their dedicated candidate iterators; the
    remaining format labels use the permissive *_CANDIDATE_RE patterns.
    """
    label = label.upper()
    if label == "PPSN":
        yield from iter_ppsn_candidates(text)
        return
    if label == "POSTCODE":
        yield from iter_eircode_candidates(text)
        return
    pattern = {
        "PHONE_NUMBER": PHONE_CANDIDATE_RE,
        "PASSPORT_NUMBER": PASSPORT_CANDIDATE_RE,
        "BANK_ROUTING_NUMBER": BANK_ROUTING_CANDIDATE_RE,
        "ACCOUNT_NUMBER": IBAN_IE_CANDIDATE_RE,
        "CREDIT_DEBIT_CARD": CARD_CANDIDATE_RE,
        "SWIFT_BIC": BIC_CANDIDATE_RE,
    }.get(label)
    if pattern is None:
        return
    for match in pattern.finditer(text):
        yield {
            "start": match.start(),
            "end": match.end(),
            "text": match.group(0),
            "normalized": match.group(0),
        }


def plausible_label_text(label: str, value: str) -> bool:
    """Return True when *value*'s surface form is plausible for *label*.

    Unknown labels are accepted unconditionally.
    """
    value = value.strip()
    if label == "PPSN":
        return is_plausible_ppsn(value)
    if label == "PHONE_NUMBER":
        return PHONE_RE.match(value) is not None and is_valid_irish_phone(value)
    if label == "PASSPORT_NUMBER":
        return PASSPORT_RE.match(normalize_passport(value)) is not None
    if label == "BANK_ROUTING_NUMBER":
        return SORT_RE.match(value) is not None
    if label == "ACCOUNT_NUMBER":
        # Either an IE IBAN or a bare 8-digit account number.
        compact = re.sub(r"[\s-]+", "", value)
        return is_plausible_ie_iban(value) or (compact.isdigit() and len(compact) == 8)
    if label == "CREDIT_DEBIT_CARD":
        return is_plausible_card(value)
    if label == "SWIFT_BIC":
        return BIC_RE.match(value.upper()) is not None
    if label == "POSTCODE":
        return EIRCODE_RE.match(value) is not None and is_valid_eircode(value)
    return True


def label_ids_from_mapping(id2label, label: str) -> list[int]:
    """All class ids in *id2label* whose normalized label equals *label*."""
    target = label.upper()
    ids = []
    for raw_id, raw_label in id2label.items():
        if normalize_label(str(raw_label)) == target:
            ids.append(int(raw_id))
    return ids


def label_ids(model, label: str) -> list[int]:
    """Class ids for *label* from a HF model's config.id2label mapping."""
    return label_ids_from_mapping(model.config.id2label, label)


def _max_word_scores(num_words: int, word_ids, probs, ids: list[int]) -> list[float]:
    """For each word index, the max probability over its sub-tokens and *ids*.

    *probs* may be a torch tensor or numpy array indexed [token, class].
    Shared by the torch and ONNX scoring paths.
    """
    scores = []
    for word_index in range(num_words):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return scores


def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Score every word of *text* for *label* with a PyTorch token classifier.

    Returns ``(pieces, scores)`` where pieces are (word, start, end) triples.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    return pieces, _max_word_scores(len(pieces), word_ids, probs, ids)


def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX counterpart of word_scores_for_label (same return shape)."""
    from onnx_token_classifier import _run_onnx, _softmax

    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    return pieces, _max_word_scores(len(pieces), word_ids, probs, ids)


def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Merge per-word scores into contiguous spans, then format-validate them.

    Adjacent kept words (gap <= 1 char) merge into one span.  For phone,
    sort-code and card labels a "-" or "/" separator continues an active
    span at half the threshold.
    """
    spans = []
    active = None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        if label in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"} and word in {"-", "/"}:
            keep = active is not None and score >= threshold / 2.0
        if keep:
            if active is None:
                active = {"start": start, "end": end, "label": label}
            elif start - active["end"] <= 1:
                active["end"] = end
            else:
                spans.append(active)
                active = {"start": start, "end": end, "label": label}
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    out = []
    for span in spans:
        value = text[span["start"] : span["end"]]
        if plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": span["start"],
                    "end": span["end"],
                    "text": value,
                }
            )
    return out


def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* from a PyTorch classifier."""
    pieces, scores = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* from an ONNX classifier."""
    pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def regex_guided_label_spans(text: str, label: str, threshold: float, pieces, scores):
    """Regex candidates for *label* kept when some overlapping word's model
    score reaches *threshold* and the trimmed text is format-plausible."""
    if not pieces:
        return []
    out = []
    for candidate in regex_candidates_for_label(text, label):
        start = int(candidate["start"])
        end = int(candidate["end"])
        # Trim surrounding whitespace from the candidate span.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        # Model support = max word score overlapping the candidate.
        support = 0.0
        for (_, piece_start, piece_end), score in zip(pieces, scores):
            if piece_end <= start or piece_start >= end:
                continue
            support = max(support, float(score))
        value = text[start:end]
        if support >= threshold and plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": start,
                    "end": end,
                    "text": value,
                    "score": support,
                }
            )
    return out


def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert HF pipeline outputs into span dicts, dropping low scores and
    entries with no recognizable label."""
    spans = []
    for output in outputs:
        label = normalize_label(output.get("entity_group") or output.get("entity") or "")
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        spans.append(
            {
                "label": label,
                "start": int(output["start"]),
                "end": int(output["end"]),
                "score": score,
                "text": text[int(output["start"]) : int(output["end"])],
            }
        )
    return spans


def overlaps(a: dict, b: dict) -> bool:
    """True when span dicts *a* and *b* share at least one character."""
    return not (a["end"] <= b["start"] or b["end"] <= a["start"])


def span_length(span: dict) -> int:
    """Character length of a span dict."""
    return int(span["end"]) - int(span["start"])


def normalize_simple_span(span: dict):
    """Normalize a raw span: relabel Luhn-valid "phones" as cards, and drop
    format-labelled spans whose text fails validation.  Returns None to drop."""
    label = normalize_label(span["label"])
    value = span["text"]
    # A "phone number" that validates as a card is almost certainly a card.
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    if label in FORMAT_LABELS or label == "POSTCODE":
        if not plausible_label_text(label, value):
            return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
    }


def dedupe_and_sort(spans: list[dict]):
    """Order spans by (start, longer first, label priority) and greedily keep
    the first span of each overlapping group."""
    ordered = sorted(
        spans,
        key=lambda span: (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        ),
    )
    kept = []
    for span in ordered:
        if any(overlaps(span, other) for other in kept):
            continue
        kept.append(span)
    return kept


def _resolve_thresholds(label_thresholds: dict[str, float] | None) -> dict[str, float]:
    """Defaults overlaid with caller thresholds (keys upper-cased)."""
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    return thresholds


def _validated_ppsn_spans(text: str, ppsn_spans) -> list[dict]:
    """Turn raw PPSN span dicts into output spans, keeping plausible ones."""
    out = []
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            out.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                }
            )
    return out


def _merge_repairs(spans: list[dict], repairs: list[dict]) -> list[dict]:
    """Fold repair candidates into *spans*.

    A candidate replaces any overlapping span it strictly out-lengths when
    they share a label, or when both carry format labels; it is appended
    when it replaced something or overlaps nothing.  (Shared by the torch
    and ONNX repair entry points.)
    """
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if (
                candidate["label"] in FORMAT_LABELS
                and span["label"] in FORMAT_LABELS
                and span_length(candidate) > span_length(span)
            ):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return spans


def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Repair general-NER output with Irish-format-aware spans (PyTorch path).

    Non-PPSN pipeline spans are normalized, PPSN spans come from the
    dedicated word-aligned detector, then word-aligned and regex-guided
    repairs are merged in before final dedupe/sort.
    """
    thresholds = _resolve_thresholds(label_thresholds)
    spans = []
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    spans.extend(
        _validated_ppsn_spans(
            text, word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
        )
    )
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label(text, model, tokenizer, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label(text, model, tokenizer, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))
    spans = _merge_repairs(spans, repairs)
    return dedupe_and_sort(spans)


def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans.

    When *general_outputs* is None the general spans are computed from the
    ONNX session itself via simple_aggregate_spans_onnx.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx

    thresholds = _resolve_thresholds(label_thresholds)
    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    spans.extend(
        _validated_ppsn_spans(
            text,
            word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score),
        )
    )
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label_onnx(text, session, tokenizer, config, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))
    spans = _merge_repairs(spans, repairs)
    return dedupe_and_sort(spans)