| |
| import regex as re |
| import torch |
|
|
| from eircode import iter_eircode_candidates, is_valid_eircode |
| from ppsn import is_plausible_ppsn, iter_ppsn_candidates |
| from raw_word_aligned import word_aligned_ppsn_spans |
|
|
|
|
# Word-level tokenizer: runs of ASCII alphanumerics, or any single character
# that is neither a word character nor whitespace (punctuation).
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# Irish phone numbers. PHONE_RE is the anchored validator; the candidate
# pattern is the same body with non-alphanumeric boundary lookarounds.
# Accepts +353 (optionally with a "(0)" trunk marker) or a domestic
# 0-prefixed form, followed by 6-8 further digits with optional separators.
PHONE_RE = re.compile(r"^(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})$")
PHONE_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})(?![A-Za-z0-9])")
# Passport: two letters + seven digits, optional single space between them.
# The candidate pattern is case-insensitive; normalize_passport() uppercases
# before validating against the anchored PASSPORT_RE.
PASSPORT_RE = re.compile(r"^[A-Z]{2}\s?\d{7}$")
PASSPORT_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])[A-Z]{2}\s?\d{7}(?![A-Za-z0-9])", re.IGNORECASE)
# Bank sort code (routing number): six digits, optionally grouped 2-2-2.
SORT_RE = re.compile(r"^(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})$")
BANK_ROUTING_CANDIDATE_RE = re.compile(r"(?<!\d)(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})(?!\d)")
# Irish IBAN: "IE" + 2 check digits + 4-letter bank code + 14 digits.
# The candidate form tolerates a separator before each of the 18 trailing
# characters so spaced/hyphenated IBANs are still caught.
IBAN_IE_RE = re.compile(r"^IE\d{2}[A-Z]{4}\d{14}$")
IBAN_IE_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])IE\d{2}(?:[\s-]?[A-Z0-9]){18}(?![A-Za-z0-9])", re.IGNORECASE)
# SWIFT/BIC: 4-letter institution + 2-letter country + 2 alphanumeric
# location, with an optional 3-character branch suffix.
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
BIC_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?(?![A-Za-z0-9])", re.IGNORECASE)
# Eircode: routing key (letter + 2 digits, or the special key D6W) followed
# by a 4-character unique identifier, optional space in between.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)
# Card numbers in conventional groupings: 4-4-4-4 (optionally a fifth group)
# or the 4-6-5 Amex/Diners layout.
CARD_GROUPED_RE = re.compile(r"^(?:\d{4}(?:[ -]\d{4}){3,4}|\d{4}[ -]\d{6}[ -]\d{5})$")
# Loose card candidate: 13-19 digits, each optionally followed by one
# space/hyphen separator.
CARD_CANDIDATE_RE = re.compile(r"(?<!\d)(?:\d[ -]?){13,19}(?!\d)")
# Bank codes (IBAN positions 5-8) of institutions operating Irish accounts;
# used as a fallback acceptance list when the mod-97 checksum fails
# (see is_plausible_ie_iban).
KNOWN_IE_IBAN_BANK_CODES = {
    "AIBK",
    "BOFI",
    "IPBS",
    "IRCE",
    "ULSB",
    "PTSB",
    "EBSI",
    "DABA",
    "CITI",
    "TRWI",
    "REVO",
}
|
|
# Minimum model probability required, per label, before a word-aligned or
# regex-guided span is accepted. Callers may override individual entries via
# the label_thresholds argument of the repair_* functions.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}
|
|
# Labels whose surface form can be validated mechanically (regex/checksum);
# derived from the threshold table so the two stay in sync.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)
# Tie-break ranking used by dedupe_and_sort when overlapping spans have the
# same start and length: lower value wins; unknown labels default to 99.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
|
|
|
|
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE and return (token, start, end) triples."""
    triples = []
    for match in TOKEN_RE.finditer(text):
        triples.append((match.group(0), match.start(), match.end()))
    return triples
|
|
|
|
def normalize_label(label: str) -> str:
    """Strip a BIO prefix ("B-"/"I-") from *label* and return it uppercased.

    None/empty input yields the empty string.
    """
    cleaned = (label or "").strip()
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
|
|
|
|
def luhn_ok(value: str) -> bool:
    """Return True when the digits embedded in *value* pass the Luhn checksum.

    Non-digit characters are ignored; the digit count must be 13-19,
    otherwise the value is rejected without computing the checksum.
    """
    digits = [int(ch) for ch in value if ch.isdigit()]
    if not 13 <= len(digits) <= 19:
        return False
    checksum = 0
    # Walk from the rightmost digit; every second digit is doubled, with
    # two-digit products reduced by 9 (standard Luhn folding).
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit
    return checksum % 10 == 0
|
|
|
|
def iban_mod97_ok(value: str) -> bool:
    """Check the ISO 13616 mod-97 checksum of an Irish IBAN.

    *value* may contain spaces/hyphens; once compacted it must match
    IBAN_IE_RE, otherwise False is returned without computing anything.
    """
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    if not IBAN_IE_RE.match(compact):
        return False
    # Move the country code + check digits to the end, map each letter to
    # its two-digit number (A=10 ... Z=35), and test the result mod 97.
    numeric = "".join(
        ch if ch.isdigit() else str(ord(ch) - ord("A") + 10)
        for ch in compact[4:] + compact[:4]
    )
    return int(numeric) % 97 == 1
|
|
|
|
def is_plausible_ie_iban(value: str) -> bool:
    """True for a well-formed Irish IBAN that either passes the mod-97
    checksum or carries a bank code from KNOWN_IE_IBAN_BANK_CODES.

    The bank-code fallback (IBAN characters 5-8) tolerates transcription
    errors in the digits while still requiring the overall shape.
    """
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    if IBAN_IE_RE.match(compact) is None:
        return False
    return iban_mod97_ok(compact) or compact[4:8] in KNOWN_IE_IBAN_BANK_CODES
|
|
|
|
def normalize_irish_phone(value: str) -> str:
    """Collapse an Irish phone number to a compact dialable form.

    Rewrites a "(0)" trunk marker to a plain 0, strips whitespace, hyphens
    and parentheses, and converts a leading 00353 to +353.
    """
    compact = re.sub(r"[\s\-\(\)]", "", value.strip().replace("(0)", "0"))
    if compact.startswith("00353"):
        return "+" + compact[2:]
    return compact
|
|
|
|
def is_valid_irish_phone(value: str) -> bool:
    """Validate an Irish phone number after normalization.

    Mobile numbers (national prefix 8x) must have exactly 9 national
    digits; other numbers may have 8 or 9 (international) or a total of
    9 or 10 digits including the leading 0 (domestic).
    """
    compact = normalize_irish_phone(value)
    if compact.startswith("+353"):
        national = compact[4:]
        # A stray trunk 0 after +353 is tolerated and dropped.
        if national[:1] == "0":
            national = national[1:]
        if not national.isdigit():
            return False
        allowed = {9} if national.startswith("8") else {8, 9}
        return len(national) in allowed
    if not (compact.startswith("0") and compact.isdigit()):
        return False
    allowed = {10} if compact.startswith("08") else {9, 10}
    return len(compact) in allowed
|
|
|
|
def is_plausible_card(value: str) -> bool:
    """Heuristic card-number check.

    Requires 13-19 embedded digits, then accepts either a passing Luhn
    checksum or a conventional grouped layout (4-4-4-4(-4) or 4-6-5).
    """
    digit_count = sum(ch.isdigit() for ch in value)
    if digit_count < 13 or digit_count > 19:
        return False
    return luhn_ok(value) or CARD_GROUPED_RE.match(value.strip()) is not None
|
|
|
|
def normalize_passport(value: str) -> str:
    """Uppercase *value* and remove every whitespace character."""
    return "".join(value.upper().split())
|
|
|
|
def regex_candidates_for_label(text: str, label: str):
    """Yield candidate dicts {start, end, text, normalized} for *label*.

    PPSN and POSTCODE delegate to their dedicated candidate iterators;
    every other known label scans *text* with its candidate regex.
    Unknown labels yield nothing.
    """
    label = label.upper()
    if label == "PPSN":
        yield from iter_ppsn_candidates(text)
        return
    if label == "POSTCODE":
        yield from iter_eircode_candidates(text)
        return
    candidate_patterns = {
        "PHONE_NUMBER": PHONE_CANDIDATE_RE,
        "PASSPORT_NUMBER": PASSPORT_CANDIDATE_RE,
        "BANK_ROUTING_NUMBER": BANK_ROUTING_CANDIDATE_RE,
        "ACCOUNT_NUMBER": IBAN_IE_CANDIDATE_RE,
        "CREDIT_DEBIT_CARD": CARD_CANDIDATE_RE,
        "SWIFT_BIC": BIC_CANDIDATE_RE,
    }
    pattern = candidate_patterns.get(label)
    if pattern is None:
        return
    for match in pattern.finditer(text):
        snippet = match.group(0)
        yield {
            "start": match.start(),
            "end": match.end(),
            "text": snippet,
            "normalized": snippet,
        }
|
|
|
|
def plausible_label_text(label: str, value: str) -> bool:
    """Format-level sanity check of *value* for *label*.

    Labels without a known format always pass; *label* is compared
    exactly (callers pass already-normalized uppercase labels).
    """
    value = value.strip()
    if label == "PPSN":
        return is_plausible_ppsn(value)
    if label == "PHONE_NUMBER":
        return bool(PHONE_RE.match(value)) and is_valid_irish_phone(value)
    if label == "PASSPORT_NUMBER":
        return bool(PASSPORT_RE.match(normalize_passport(value)))
    if label == "BANK_ROUTING_NUMBER":
        return bool(SORT_RE.match(value))
    if label == "ACCOUNT_NUMBER":
        # Either a plausible Irish IBAN or a bare 8-digit account number.
        if is_plausible_ie_iban(value):
            return True
        compact = re.sub(r"[\s-]+", "", value)
        return compact.isdigit() and len(compact) == 8
    if label == "CREDIT_DEBIT_CARD":
        return is_plausible_card(value)
    if label == "SWIFT_BIC":
        return bool(BIC_RE.match(value.upper()))
    if label == "POSTCODE":
        return bool(EIRCODE_RE.match(value)) and is_valid_eircode(value)
    return True
|
|
|
|
def label_ids_from_mapping(id2label, label: str):
    """Ids in *id2label* whose BIO-stripped label equals *label* (case-insensitive)."""
    target = label.upper()
    return [
        int(raw_id)
        for raw_id, raw_label in id2label.items()
        if normalize_label(str(raw_label)) == target
    ]
|
|
|
|
def label_ids(model, label: str):
    """Ids in *model*'s id2label config that correspond to *label*."""
    mapping = model.config.id2label
    return label_ids_from_mapping(mapping, label)
|
|
|
|
def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Per-word probabilities for *label* from a torch token classifier.

    Returns (pieces, scores): pieces is the (word, start, end) list from
    tokenize_with_spans(text); scores[i] is the maximum softmax probability
    over all subword tokens of word i and all label ids matching *label*.
    Words with no tokens (e.g. lost to truncation) score 0.0.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    # Tokenize pre-split words so word_ids() can map tokens back to words.
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    # Run inference on whatever device the model parameters live on.
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue
            # Best probability across this word's tokens and label ids.
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores
|
|
|
|
def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX counterpart of word_scores_for_label.

    Same contract — returns (pieces, scores) with per-word maximum
    probabilities for *label* — but runs inference through an ONNX
    *session* and reads the label mapping from *config*.
    """
    # Imported lazily so the torch-only path has no ONNX dependency.
    from onnx_token_classifier import _run_onnx, _softmax


    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    # NumPy tensors for the ONNX runtime; word_ids() maps tokens to words.
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index, wid in enumerate(word_ids):
            if wid != word_index:
                continue
            # Best probability across this word's tokens and label ids.
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores
|
|
|
|
def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Merge per-word scores into contiguous spans for *label*.

    A word opens or extends a span when its score reaches *threshold*.
    For the phone/routing/card labels, a "-" or "/" separator can keep an
    already-open span alive at half the threshold so hyphenated numbers
    are not split. Consecutive kept words merge when separated by at most
    one character. Finished spans must pass plausible_label_text before
    being returned as {label, start, end, text} dicts.
    """
    spans = []
    active = None  # span currently being grown, or None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        if label in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"} and word in {"-", "/"}:
            # Separators only bridge an open span, at a relaxed threshold.
            keep = active is not None and score >= threshold / 2.0
        if keep:
            if active is None:
                active = {"start": start, "end": end, "label": label}
            else:
                if start - active["end"] <= 1:
                    # Gap of at most one character: extend the open span.
                    active["end"] = end
                else:
                    spans.append(active)
                    active = {"start": start, "end": end, "label": label}
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    out = []
    for span in spans:
        value = text[span["start"] : span["end"]]
        # Final format check drops spans that merely scored well.
        if plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": span["start"],
                    "end": span["end"],
                    "text": value,
                }
            )
    return out
|
|
|
|
def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* scored with a torch model."""
    return _word_aligned_label_spans_from_scores(
        text,
        label,
        threshold,
        *word_scores_for_label(text, model, tokenizer, label),
    )
|
|
|
|
def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* scored through an ONNX session."""
    return _word_aligned_label_spans_from_scores(
        text,
        label,
        threshold,
        *word_scores_for_label_onnx(text, session, tokenizer, config, label),
    )
|
|
|
|
def regex_guided_label_spans(text: str, label: str, threshold: float, pieces, scores):
    """Promote regex candidates for *label* into spans.

    A candidate survives when the strongest model score among words
    overlapping its (whitespace-trimmed) window reaches *threshold* and
    its text passes plausible_label_text. Returns dicts with label,
    start, end, text and the supporting score.
    """
    if not pieces:
        return []
    out = []
    for candidate in regex_candidates_for_label(text, label):
        start, end = int(candidate["start"]), int(candidate["end"])
        # Trim surrounding whitespace from the candidate window.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        # Strongest word score overlapping [start, end).
        overlapping = [
            float(score)
            for (_, piece_start, piece_end), score in zip(pieces, scores)
            if piece_end > start and piece_start < end
        ]
        support = max(overlapping, default=0.0)
        value = text[start:end]
        if support >= threshold and plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": start,
                    "end": end,
                    "text": value,
                    "score": support,
                }
            )
    return out
|
|
|
|
def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert HF-pipeline entity dicts into this module's span format.

    Entries without a recognizable label, or scoring below *min_score*,
    are dropped; the span text is re-sliced from *text*.
    """
    spans = []
    for output in outputs:
        raw = output.get("entity_group") or output.get("entity") or ""
        label = normalize_label(raw)
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        start, end = int(output["start"]), int(output["end"])
        spans.append(
            {
                "label": label,
                "start": start,
                "end": end,
                "score": score,
                "text": text[start:end],
            }
        )
    return spans
|
|
|
|
def overlaps(a: dict, b: dict) -> bool:
    """True when span dicts *a* and *b* share at least one character."""
    return a["start"] < b["end"] and b["start"] < a["end"]
|
|
|
|
def span_length(span: dict) -> int:
    """Character length of a span dict (end minus start)."""
    start, end = int(span["start"]), int(span["end"])
    return end - start
|
|
|
|
def normalize_simple_span(span: dict):
    """Normalize one pipeline span, or return None to drop it.

    The label is BIO-stripped/uppercased; a PHONE_NUMBER whose text looks
    like a card number is relabelled CREDIT_DEBIT_CARD; format labels and
    POSTCODE must pass plausible_label_text to survive.
    """
    label = normalize_label(span["label"])
    value = span["text"]
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    needs_check = label in FORMAT_LABELS or label == "POSTCODE"
    if needs_check and not plausible_label_text(label, value):
        return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
    }
|
|
|
|
def dedupe_and_sort(spans: list[dict]):
    """Resolve overlaps greedily and return the surviving spans.

    Spans are visited ordered by start position, then longest first, then
    OUTPUT_PRIORITY rank; each span is kept only when it overlaps nothing
    already kept.
    """
    def sort_key(span):
        return (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        )

    kept = []
    for span in sorted(spans, key=sort_key):
        if all(not overlaps(span, existing) for existing in kept):
            kept.append(span)
    return kept
|
|
|
|
def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Combine general NER output with targeted Irish-PII repairs.

    Pipeline spans scoring at least *other_min_score* are normalized
    (PPSN ones are discarded and re-derived below); PPSN spans are
    rebuilt word-aligned at *ppsn_min_score*; and every format label is
    re-scored with both the word-aligned and the regex-guided strategy.
    Repair candidates may displace shorter overlapping spans before the
    final overlap resolution in dedupe_and_sort.

    *label_thresholds* overrides entries of DEFAULT_LABEL_THRESHOLDS
    (keys are upper-cased).
    """
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})


    # General model spans, minus PPSN (handled by the dedicated pass below).
    spans = []
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)


    # Word-aligned PPSN detection with its own threshold + format check.
    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                }
            )


    # Targeted repairs: regex-guided PPSN, then per format label both the
    # word-aligned and the regex-guided extraction on shared word scores.
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label(text, model, tokenizer, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label(text, model, tokenizer, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))


    # A candidate replaces overlapping spans it improves on: a longer span
    # of the same label, or a longer span within the format-label family.
    # Otherwise it is only added when it overlaps nothing that remains.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)


    return dedupe_and_sort(spans)
|
|
|
|
def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans.

    Same strategy — normalized general spans, word-aligned PPSN, then
    word-aligned and regex-guided repairs per format label, merged and
    de-overlapped — but every model call goes through the ONNX *session*.
    When *general_outputs* is None they are computed here via
    simple_aggregate_spans_onnx.
    """
    # Imported lazily so the torch-only path has no ONNX dependency.
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx


    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})


    # General model spans, minus PPSN (handled by the dedicated pass below).
    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)


    # Word-aligned PPSN detection with its own threshold + format check.
    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                }
            )


    # Targeted repairs: regex-guided PPSN, then per format label both the
    # word-aligned and the regex-guided extraction on shared word scores.
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label_onnx(text, session, tokenizer, config, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))


    # A candidate replaces overlapping spans it improves on: a longer span
    # of the same label, or a longer span within the format-label family.
    # Otherwise it is only added when it overlaps nothing that remains.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)


    return dedupe_and_sort(spans)
|
|