#!/usr/bin/env python3
"""Irish-PII span detection and repair.

Combines model token-classification scores with format scanners (PPSN,
Eircode, IBAN, sort codes, cards, phones, passports, BICs) to produce and
reconcile labelled character spans over free text.
"""
import re

import torch

from eircode import iter_eircode_candidates, is_valid_eircode
from irish_core_generated_scanner_spec import SCANNER_SPEC
from ppsn import is_plausible_ppsn, iter_ppsn_candidates
from raw_word_aligned import word_aligned_ppsn_spans

# Word-level tokenizer: alphanumeric runs, or single punctuation characters.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)

# Characters trimmed from the tail of scanner candidates (whitespace, NBSP, '-').
TRAILING_TRIM_CHARS = set(" \t\r\n\u00A0-")

# Bank codes accepted for Irish IBANs even when the mod-97 check fails.
KNOWN_IE_IBAN_BANK_CODES = {
    "AIBK",
    "BOFI",
    "IPBS",
    "IRCE",
    "ULSB",
    "PTSB",
    "EBSI",
    "DABA",
    "CITI",
    "TRWI",
    "REVO",
}

# Per-label minimum model score used by the repair passes.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}

# Labels whose surface text must pass a format validator.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)

# Tie-break priority when overlapping spans are deduplicated (lower wins).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}


def tokenize_with_spans(text: str):
    """Return (word, start, end) triples for every TOKEN_RE match in *text*."""
    return [(m.group(0), m.start(), m.end()) for m in TOKEN_RE.finditer(text)]


def is_ascii_digit(ch: str) -> bool:
    """True for ASCII '0'..'9' only (unlike str.isdigit)."""
    return "0" <= ch <= "9"


def is_ascii_letter(ch: str) -> bool:
    """True for ASCII 'A'..'Z' / 'a'..'z' only."""
    upper = ch.upper()
    return "A" <= upper <= "Z"


def is_ascii_alnum(ch: str) -> bool:
    """True for an ASCII letter or digit."""
    return is_ascii_digit(ch) or is_ascii_letter(ch)


def is_word_boundary(text: str, index: int) -> bool:
    """True when *index* is out of range or holds a non-alphanumeric char."""
    if index < 0 or index >= len(text):
        return True
    return not text[index].isalnum()


def normalize_compact(value: str, uppercase: bool = True) -> str:
    """Strip non-alphanumerics from *value*; optionally uppercase the rest."""
    chars = []
    for ch in value.strip():
        if ch.isalnum():
            chars.append(ch.upper() if uppercase else ch)
    return "".join(chars)


def normalize_label(label: str) -> str:
    """Uppercase *label* after removing a BIO 'B-'/'I-' prefix, if any."""
    label = (label or "").strip()
    if label.startswith("B-") or label.startswith("I-"):
        label = label[2:]
    return label.upper()


def luhn_ok(value: str) -> bool:
    """Luhn checksum over the digits of *value*; requires 13-19 digits."""
    digits = "".join(ch for ch in value if ch.isdigit())
    if not (13 <= len(digits) <= 19):
        return False
    total = 0
    double = False
    for ch in reversed(digits):
        number = int(ch)
        if double:
            number *= 2
            if number > 9:
                number -= 9
        total += number
        double = not double
    return total % 10 == 0


def iban_mod97_ok(value: str) -> bool:
    """Strict IE IBAN check: IE + 2 check digits + 4 letters + 14 digits, mod-97 == 1."""
    compact = normalize_compact(value)
    if len(compact) != 22 or not compact.startswith("IE"):
        return False
    if not compact[2:4].isdigit():
        return False
    if not all(is_ascii_letter(ch) for ch in compact[4:8]):
        return False
    if not compact[8:].isdigit():
        return False
    # ISO 13616: move the first four chars to the end, map letters to 10..35,
    # and take the whole number modulo 97 (computed digit-by-digit).
    rearranged = compact[4:] + compact[:4]
    remainder = 0
    for ch in rearranged:
        if ch.isdigit():
            digits = ch
        else:
            digits = str(ord(ch) - ord("A") + 10)
        for digit in digits:
            remainder = (remainder * 10 + int(digit)) % 97
    return remainder == 1


def is_plausible_ie_iban(value: str) -> bool:
    """Lenient IE IBAN check: mod-97 pass, or a known Irish bank code."""
    compact = normalize_compact(value)
    if len(compact) != 22 or not compact.startswith("IE"):
        return False
    if not compact[2:4].isdigit():
        return False
    if not all(is_ascii_letter(ch) for ch in compact[4:8]):
        return False
    if not compact[8:].isdigit():
        return False
    if iban_mod97_ok(compact):
        return True
    return compact[4:8] in KNOWN_IE_IBAN_BANK_CODES


def normalize_irish_phone(value: str) -> str:
    """Drop separators and '(0)', and rewrite a 00353 prefix as +353."""
    compact = value.strip()
    compact = compact.replace("(0)", "0")
    chars = []
    for ch in compact:
        if ch in " -()":
            continue
        chars.append(ch)
    compact = "".join(chars)
    if compact.startswith("00353"):
        compact = "+" + compact[2:]
    return compact


def is_valid_irish_phone(value: str) -> bool:
    """Validate an Irish phone number in +353 / 00353 / 0-prefixed form."""
    compact = normalize_irish_phone(value)
    if compact.startswith("+353"):
        rest = compact[4:]
        # Tolerate a stray trunk '0' after the country code.
        if rest.startswith("0"):
            rest = rest[1:]
        if not rest.isdigit():
            return False
        if rest.startswith("8"):
            return len(rest) == 9  # mobiles (08x) have a fixed length
        return len(rest) in {8, 9}
    if not compact.startswith("0") or not compact.isdigit():
        return False
    if compact.startswith("08"):
        return len(compact) == 10
    return len(compact) in {9, 10}


def is_plausible_card(value: str) -> bool:
    """Card plausibility: Luhn pass, or digit groups in a known card layout."""
    digits = "".join(ch for ch in value if ch.isdigit())
    if not (13 <= len(digits) <= 19):
        return False
    if luhn_ok(value):
        return True
    stripped = value.strip()
    if not stripped:
        return False
    # Fall back to separator-grouped shapes (e.g. 4-4-4-4, Amex 4-6-5).
    groups = []
    current = []
    saw_sep = False
    for ch in stripped:
        if ch.isdigit():
            current.append(ch)
            continue
        if ch not in {" ", "-"}:
            return False
        saw_sep = True
        if not current:
            return False  # leading or doubled separator
        groups.append("".join(current))
        current = []
    if current:
        groups.append("".join(current))
    if not saw_sep:
        return False
    lengths = [len(group) for group in groups]
    return lengths in ([4, 4, 4, 4], [4, 4, 4, 4, 3], [4, 6, 5])


def normalize_passport(value: str) -> str:
    """Uppercase *value* with all whitespace removed."""
    chars = []
    for ch in value.strip():
        if ch.isspace():
            continue
        chars.append(ch.upper())
    return "".join(chars)


def is_valid_passport(value: str) -> bool:
    """Passport format: exactly 2 ASCII letters followed by 7 digits."""
    compact = normalize_passport(value)
    return len(compact) == 9 and all(is_ascii_letter(ch) for ch in compact[:2]) and compact[2:].isdigit()


def is_valid_sort_code(value: str) -> bool:
    """Sort code: 6 digits, or three 2-digit groups split by space/hyphen."""
    stripped = value.strip()
    if not stripped:
        return False
    if stripped.isdigit():
        return len(stripped) == 6
    groups = []
    current = []
    for ch in stripped:
        if ch.isdigit():
            current.append(ch)
            continue
        if ch not in {" ", "-"}:
            return False
        if not current:
            return False
        groups.append("".join(current))
        current = []
    if current:
        groups.append("".join(current))
    return len(groups) == 3 and all(len(group) == 2 and group.isdigit() for group in groups)


def is_valid_bic(value: str) -> bool:
    """BIC format: 8 or 11 chars, first 6 letters, remainder alphanumeric."""
    compact = normalize_compact(value)
    if len(compact) not in {8, 11}:
        return False
    if not all(is_ascii_letter(ch) for ch in compact[:6]):
        return False
    return all(is_ascii_alnum(ch) for ch in compact[6:])


def scan_candidates(
    text: str,
    *,
    start_ok,
    allowed_chars: set[str],
    min_len: int,
    max_len: int,
    validator,
):
    """Yield validator-approved spans found by a greedy character scan.

    From each boundary-aligned start char accepted by *start_ok*, extend a
    run of *allowed_chars* up to *max_len*, then shrink from the right
    (trimming TRAILING_TRIM_CHARS) until *validator* accepts a candidate of
    at least *min_len*. Yields dicts with start/end/text/normalized.
    """
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        if not start_ok(ch) or not is_word_boundary(text, i - 1):
            i += 1
            continue
        run_end = i
        while run_end < n and run_end - i < max_len and text[run_end] in allowed_chars:
            run_end += 1
        best_end = None
        end = run_end
        while end > i:
            while end > i and text[end - 1] in TRAILING_TRIM_CHARS:
                end -= 1
            if end - i < min_len:
                break
            if is_word_boundary(text, end):
                candidate = text[i:end]
                if validator(candidate):
                    best_end = end
                    break
            end -= 1
        if best_end is not None:
            value = text[i:best_end]
            yield {
                "start": i,
                "end": best_end,
                "text": value,
                "normalized": normalize_compact(value, uppercase=False),
            }
            i = best_end  # resume after the accepted span
        else:
            i += 1


def spec_candidates_for_label(text: str, label: str):
    """Yield scanner candidates for *label* as described by SCANNER_SPEC.

    Delegate specs dispatch to the PPSN/Eircode iterators; otherwise the
    spec's start predicate, char class, lengths, and validator are wired
    into scan_candidates().
    """
    label = label.upper()
    spec = SCANNER_SPEC["scanners"].get(label)
    if spec is None:
        return
    if spec["kind"] == "delegate":
        delegate_name = spec["function"]
        if delegate_name == "iter_ppsn_candidates":
            yield from iter_ppsn_candidates(text)
        elif delegate_name == "iter_eircode_candidates":
            yield from iter_eircode_candidates(text)
        return
    start_spec = SCANNER_SPEC["start_predicates"][spec["start_predicate"]]
    validators = {
        "is_valid_irish_phone": is_valid_irish_phone,
        "is_valid_passport": is_valid_passport,
        "is_valid_sort_code": is_valid_sort_code,
        "is_plausible_ie_iban": is_plausible_ie_iban,
        "is_plausible_card": is_plausible_card,
        "is_valid_bic": is_valid_bic,
    }
    if "builtin" in start_spec:
        builtin = start_spec["builtin"]
        if builtin == "ascii_letter":
            start_ok = is_ascii_letter
        elif builtin == "ascii_digit":
            start_ok = is_ascii_digit
        else:
            raise ValueError(f"Unknown builtin start predicate: {builtin}")
    else:
        allowed = set(start_spec["any_of"])
        start_ok = lambda ch, allowed=allowed: ch in allowed
    yield from scan_candidates(
        text,
        start_ok=start_ok,
        allowed_chars=set(SCANNER_SPEC["char_classes"][spec["allowed_chars"]]),
        min_len=int(spec["min_len"]),
        max_len=int(spec["max_len"]),
        validator=validators[spec["validator"]],
    )


def plausible_label_text(label: str, value: str) -> bool:
    """Format gate for *value* under *label*; unknown labels pass through."""
    value = value.strip()
    if label == "PPSN":
        return is_plausible_ppsn(value)
    if label == "PHONE_NUMBER":
        return is_valid_irish_phone(value)
    if label == "PASSPORT_NUMBER":
        return is_valid_passport(value)
    if label == "BANK_ROUTING_NUMBER":
        return is_valid_sort_code(value)
    if label == "ACCOUNT_NUMBER":
        # Either an IE IBAN or a bare 8-digit account number.
        compact = normalize_compact(value)
        return is_plausible_ie_iban(value) or (compact.isdigit() and len(compact) == 8)
    if label == "CREDIT_DEBIT_CARD":
        return is_plausible_card(value)
    if label == "SWIFT_BIC":
        return is_valid_bic(value)
    if label == "POSTCODE":
        return is_valid_eircode(value)
    return True


def label_ids_from_mapping(id2label, label: str):
    """Return all class ids in *id2label* whose normalized name equals *label*."""
    target = label.upper()
    ids = []
    for raw_id, raw_label in id2label.items():
        if normalize_label(str(raw_label)) == target:
            ids.append(int(raw_id))
    return ids


def label_ids(model, label: str):
    """Class ids for *label* taken from the model's config.id2label."""
    return label_ids_from_mapping(model.config.id2label, label)


def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Per-word max softmax probability for *label* (PyTorch path).

    Returns (pieces, scores) where pieces are tokenize_with_spans() triples
    and scores align one-to-one with them. Words lost to truncation score 0.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores


def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """Per-word max probability for *label* (ONNX path); mirrors the torch path."""
    from onnx_token_classifier import _run_onnx, _softmax

    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores


def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Merge above-threshold words into spans, then format-filter them.

    Adjacent kept words (gap <= 1 char) merge into one span. For phone /
    sort-code / card labels, '-' and '/' separators extend an active span at
    half threshold. Spans failing plausible_label_text() are dropped.
    """
    spans = []
    active = None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        if label in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"} and word in {"-", "/"}:
            keep = active is not None and score >= threshold / 2.0
        if keep:
            if active is None:
                active = {"start": start, "end": end, "label": label}
            else:
                if start - active["end"] <= 1:
                    active["end"] = end
                else:
                    spans.append(active)
                    active = {"start": start, "end": end, "label": label}
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    out = []
    for span in spans:
        value = text[span["start"] : span["end"]]
        if plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": span["start"],
                    "end": span["end"],
                    "text": value,
                    "source": "word_aligned",
                }
            )
    return out


def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* using the PyTorch scorer."""
    pieces, scores = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* using the ONNX scorer."""
    pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)


def scanner_guided_label_spans(text: str, label: str, threshold: float, pieces, scores):
    """Scanner candidates for *label* kept only with model support >= threshold.

    Support is the max word score overlapping the (whitespace-trimmed)
    candidate span; kept spans must also pass the format gate.
    """
    if not pieces:
        return []
    out = []
    for candidate in spec_candidates_for_label(text, label):
        start = int(candidate["start"])
        end = int(candidate["end"])
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        support = 0.0
        for (_, piece_start, piece_end), score in zip(pieces, scores):
            if piece_end <= start or piece_start >= end:
                continue
            support = max(support, float(score))
        value = text[start:end]
        if support >= threshold and plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": start,
                    "end": end,
                    "text": value,
                    "score": support,
                    "source": "scanner_guided",
                }
            )
    return out


def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert HF-pipeline style outputs into span dicts, dropping low scores."""
    spans = []
    for output in outputs:
        label = normalize_label(output.get("entity_group") or output.get("entity") or "")
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        spans.append(
            {
                "label": label,
                "start": int(output["start"]),
                "end": int(output["end"]),
                "score": score,
                "text": text[int(output["start"]) : int(output["end"])],
            }
        )
    return spans


def overlaps(a: dict, b: dict) -> bool:
    """True when span dicts *a* and *b* share at least one character."""
    return not (a["end"] <= b["start"] or b["end"] <= a["start"])


def span_length(span: dict) -> int:
    """Character length of a span dict."""
    return int(span["end"]) - int(span["start"])


def normalize_simple_span(span: dict):
    """Normalize a model span; relabel phone->card on card format; gate formats.

    Returns None when the text fails the format check for its (format) label.
    """
    label = normalize_label(span["label"])
    value = span["text"]
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    if label in FORMAT_LABELS or label == "POSTCODE":
        if not plausible_label_text(label, value):
            return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
        "source": span.get("source", "model"),
    }


def dedupe_and_sort(spans: list[dict]):
    """Sort by start, longer-first, then OUTPUT_PRIORITY; drop overlaps greedily."""
    ordered = sorted(
        spans,
        key=lambda span: (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        ),
    )
    kept = []
    for span in ordered:
        if any(overlaps(span, other) for other in kept):
            continue
        kept.append(span)
    return kept


def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Reconcile pipeline spans with PPSN + format-label repairs (PyTorch path).

    Seeds spans from *general_outputs* (PPSN excluded) plus word-aligned PPSN
    spans, then lets repair candidates (scanner-guided PPSN, per-label
    word-aligned and scanner-guided spans) replace overlapping spans that are
    shorter, non-scanner, or format-label shorter. Returns deduped, sorted spans.
    """
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                    "source": span.get("source", "model"),
                }
            )
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label(text, model, tokenizer, "PPSN")
    repairs.extend(scanner_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label(text, model, tokenizer, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(scanner_guided_label_spans(text, label, threshold, pieces, scores))
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            # Same label, candidate strictly longer: replace.
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            # Same label, scanner-guided candidate beats a non-scanner span.
            if (
                candidate["label"] == span["label"]
                and candidate.get("source") == "scanner_guided"
                and span.get("source") != "scanner_guided"
            ):
                replaced = True
                continue
            # Cross-label replacement allowed between format labels when longer.
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)


def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans().

    When *general_outputs* is None they are computed with
    simple_aggregate_spans_onnx(). Replacement/merge logic matches the
    PyTorch variant.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx

    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                    "source": span.get("source", "model"),
                }
            )
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label_onnx(text, session, tokenizer, config, "PPSN")
    repairs.extend(scanner_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(scanner_guided_label_spans(text, label, threshold, pieces, scores))
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if (
                candidate["label"] == span["label"]
                and candidate.get("source") == "scanner_guided"
                and span.get("source") != "scanner_guided"
            ):
                replaced = True
                continue
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)