#!/usr/bin/env python3
# Release note (moved below the shebang so the file parses):
# Publish rc7 with spec-driven scanner release (commit 32bcb86, verified)
import re
import torch
from eircode import iter_eircode_candidates, is_valid_eircode
from irish_core_generated_scanner_spec import SCANNER_SPEC
from ppsn import is_plausible_ppsn, iter_ppsn_candidates
from raw_word_aligned import word_aligned_ppsn_spans
# Word-level tokenizer: runs of ASCII alphanumerics, or any single
# non-word, non-space character (punctuation) as its own token.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# Characters trimmed from the tail of scanner candidates before
# validation (ASCII whitespace, NBSP, and the hyphen separator).
TRAILING_TRIM_CHARS = set(" \t\r\n\u00A0-")
# Four-letter bank codes seen in Irish IBANs.  Used by is_plausible_ie_iban
# as a fallback acceptance signal when the IBAN mod-97 checksum fails.
KNOWN_IE_IBAN_BANK_CODES = {
    "AIBK",
    "BOFI",
    "IPBS",
    "IRCE",
    "ULSB",
    "PTSB",
    "EBSI",
    "DABA",
    "CITI",
    "TRWI",
    "REVO",
}
# Per-label score thresholds for the word-aligned / scanner-guided repair
# passes; callers may override via the label_thresholds argument.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}
# Labels that carry a structural (format-based) validator.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)
# Tie-break priority when overlapping spans compete in dedupe_and_sort
# (lower wins; labels not listed rank last at 99).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE, returning (token, start, end) triples."""
    pieces = []
    for match in TOKEN_RE.finditer(text):
        pieces.append((match.group(0), match.start(), match.end()))
    return pieces
def is_ascii_digit(ch: str) -> bool:
    """True when *ch* is one of the ASCII digits 0-9."""
    return ch >= "0" and ch <= "9"
def is_ascii_letter(ch: str) -> bool:
    """True when *ch* upper-cases into the ASCII range A-Z."""
    folded = ch.upper()
    return folded >= "A" and folded <= "Z"
def is_ascii_alnum(ch: str) -> bool:
    """True when *ch* is an ASCII digit or an ASCII letter."""
    if "0" <= ch <= "9":
        return True
    return "A" <= ch.upper() <= "Z"
def is_word_boundary(text: str, index: int) -> bool:
    """True when *index* falls outside *text* or points at a non-alphanumeric char."""
    inside = 0 <= index < len(text)
    if not inside:
        return True
    return not text[index].isalnum()
def normalize_compact(value: str, uppercase: bool = True) -> str:
    """Strip *value*, drop every non-alphanumeric char, and (by default) upper-case."""
    kept = (ch for ch in value.strip() if ch.isalnum())
    if uppercase:
        return "".join(ch.upper() for ch in kept)
    return "".join(kept)
def normalize_label(label: str) -> str:
    """Strip any BIO prefix ('B-' / 'I-') and return the upper-cased label name."""
    cleaned = (label or "").strip()
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
def luhn_ok(value: str) -> bool:
    """Luhn checksum over the digits of *value* (13-19 digits required)."""
    digits = [int(ch) for ch in value if ch.isdigit()]
    if len(digits) < 13 or len(digits) > 19:
        return False
    checksum = 0
    for position, digit in enumerate(reversed(digits)):
        # Every second digit (from the right) is doubled, with digit-sum fold.
        if position % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit
    return checksum % 10 == 0
def iban_mod97_ok(value: str) -> bool:
    """Return True when *value* is a well-formed Irish IBAN with a correct
    ISO 7064 mod-97 checksum.

    Accepted shape (after removing separators and upper-casing):
    "IE" + 2 check digits + 4-letter bank code + 14-digit BBAN = 22 chars.
    """
    # Inline of normalize_compact(value): strip, drop non-alphanumerics, upper-case.
    compact = "".join(ch.upper() for ch in value.strip() if ch.isalnum())
    if len(compact) != 22 or not compact.startswith("IE"):
        return False
    if not compact[2:4].isdigit():
        return False
    # Bank code must be four ASCII letters (compact is already upper-cased).
    if not all("A" <= ch <= "Z" for ch in compact[4:8]):
        return False
    if not compact[8:].isdigit():
        return False
    # ISO 13616: move country code + check digits to the end, map each letter
    # to its base-36 value (A=10 .. Z=35), and test the number mod 97 == 1.
    # Python's arbitrary-precision int replaces the manual digit-by-digit loop.
    rearranged = compact[4:] + compact[:4]
    numeric = "".join(str(int(ch, 36)) for ch in rearranged)
    return int(numeric) % 97 == 1
def is_plausible_ie_iban(value: str) -> bool:
    """Accept an Irish IBAN when its mod-97 checksum passes, or — failing the
    checksum — when its bank code is a known Irish institution (presumably to
    tolerate transcription errors; confirm intent with the spec owners)."""
    compact = normalize_compact(value)
    structurally_ok = (
        len(compact) == 22
        and compact.startswith("IE")
        and compact[2:4].isdigit()
        and all(is_ascii_letter(ch) for ch in compact[4:8])
        and compact[8:].isdigit()
    )
    if not structurally_ok:
        return False
    return iban_mod97_ok(compact) or compact[4:8] in KNOWN_IE_IBAN_BANK_CODES
def normalize_irish_phone(value: str) -> str:
    """Collapse an Irish phone string: drop spaces/dashes/parens, turn the
    optional '(0)' trunk marker into '0', and rewrite a 00353 prefix as +353."""
    cleaned = value.strip().replace("(0)", "0")
    cleaned = "".join(ch for ch in cleaned if ch not in " -()")
    if cleaned.startswith("00353"):
        return "+" + cleaned[2:]
    return cleaned
def is_valid_irish_phone(value: str) -> bool:
    """Validate Irish numbers in national (0...) or international (+353/00353)
    form.  Mobile prefixes (08x) carry exactly 9 subscriber digits; other
    area codes allow 8 or 9."""
    # Inline normalization — same steps as normalize_irish_phone.
    number = value.strip().replace("(0)", "0")
    number = "".join(ch for ch in number if ch not in " -()")
    if number.startswith("00353"):
        number = "+" + number[2:]
    if number.startswith("+353"):
        subscriber = number[4:]
        if subscriber[:1] == "0":  # drop at most one trunk zero
            subscriber = subscriber[1:]
        if not subscriber.isdigit():
            return False
        allowed = {9} if subscriber.startswith("8") else {8, 9}
        return len(subscriber) in allowed
    if not number.isdigit() or not number.startswith("0"):
        return False
    allowed = {10} if number.startswith("08") else {9, 10}
    return len(number) in allowed
def is_plausible_card(value: str) -> bool:
    """Card-number plausibility: Luhn-valid 13-19 digits, or a failed-Luhn
    number whose digit grouping matches a common card layout
    (4-4-4-4, 4-4-4-4-3, or 4-6-5)."""
    digit_count = sum(1 for ch in value if ch.isdigit())
    if digit_count < 13 or digit_count > 19:
        return False
    if luhn_ok(value):
        return True
    stripped = value.strip()
    if not stripped:
        return False
    groups: list[str] = []
    run: list[str] = []
    separator_seen = False
    for ch in stripped:
        if ch.isdigit():
            run.append(ch)
        elif ch in " -":
            separator_seen = True
            if not run:
                return False  # leading or doubled separator
            groups.append("".join(run))
            run = []
        else:
            return False  # any other character disqualifies
    if run:
        groups.append("".join(run))
    if not separator_seen:
        return False
    shape = [len(group) for group in groups]
    return shape in ([4, 4, 4, 4], [4, 4, 4, 4, 3], [4, 6, 5])
def normalize_passport(value: str) -> str:
    """Upper-case *value* with every whitespace character removed."""
    return "".join(ch.upper() for ch in value.strip() if not ch.isspace())
def is_valid_passport(value: str) -> bool:
    """Passport shape: two ASCII letters followed by seven digits."""
    # Inline of normalize_passport: whitespace removed, upper-cased.
    compact = "".join(ch.upper() for ch in value.strip() if not ch.isspace())
    if len(compact) != 9:
        return False
    prefix, tail = compact[:2], compact[2:]
    return all("A" <= ch.upper() <= "Z" for ch in prefix) and tail.isdigit()
def is_valid_sort_code(value: str) -> bool:
    """Sort-code shape: six digits, unbroken ('123456') or as three two-digit
    groups split by single spaces or hyphens ('12-34-56')."""
    code = value.strip()
    if not code:
        return False
    if code.isdigit():
        return len(code) == 6
    groups = []
    run = []
    for ch in code:
        if ch.isdigit():
            run.append(ch)
        elif ch in " -":
            if not run:
                return False  # leading or doubled separator
            groups.append("".join(run))
            run = []
        else:
            return False
    if run:
        groups.append("".join(run))
    # Groups are digit-only by construction, so only the shape remains to check.
    return len(groups) == 3 and all(len(group) == 2 for group in groups)
def is_valid_bic(value: str) -> bool:
    """BIC shape: 8 or 11 alphanumerics whose first six are ASCII letters."""
    # Inline of normalize_compact: strip, drop non-alphanumerics, upper-case.
    compact = "".join(ch.upper() for ch in value.strip() if ch.isalnum())
    if len(compact) not in (8, 11):
        return False
    head, tail = compact[:6], compact[6:]
    if any(not ("A" <= ch.upper() <= "Z") for ch in head):
        return False
    for ch in tail:
        if not ("0" <= ch <= "9" or "A" <= ch.upper() <= "Z"):
            return False
    return True
def scan_candidates(
    text: str,
    *,
    start_ok,
    allowed_chars: set[str],
    min_len: int,
    max_len: int,
    validator,
):
    """Yield validated substrings of *text* via a greedy left-to-right scan.

    From each position that satisfies *start_ok* and follows a word boundary,
    extend over *allowed_chars* (up to *max_len*), then shrink from the right
    — trimming trailing separators/whitespace — until *validator* accepts a
    candidate that ends on a word boundary, or the run drops below *min_len*.
    Yields dicts with start/end offsets, the raw text, and a case-preserving
    alphanumeric-only "normalized" form.  Accepted matches never overlap.
    """
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        # A candidate may only begin on a word boundary at an allowed start char.
        if not start_ok(ch) or not is_word_boundary(text, i - 1):
            i += 1
            continue
        # Extend greedily over the allowed character class.
        run_end = i
        while run_end < n and run_end - i < max_len and text[run_end] in allowed_chars:
            run_end += 1
        best_end = None
        end = run_end
        # Shrink from the right until the validator accepts or it gets too short.
        while end > i:
            while end > i and text[end - 1] in TRAILING_TRIM_CHARS:
                end -= 1
            if end - i < min_len:
                break
            if is_word_boundary(text, end):
                candidate = text[i:end]
                if validator(candidate):
                    best_end = end
                    break
            end -= 1
        if best_end is not None:
            value = text[i:best_end]
            yield {
                "start": i,
                "end": best_end,
                "text": value,
                "normalized": normalize_compact(value, uppercase=False),
            }
            # Resume scanning just past the accepted match.
            i = best_end
        else:
            i += 1
def spec_candidates_for_label(text: str, label: str):
    """Yield scanner candidates for *label* as described by SCANNER_SPEC.

    Delegate-kind specs forward to the dedicated PPSN / Eircode iterators.
    Otherwise the spec's start predicate, character class, length bounds and
    validator name are resolved locally and fed to scan_candidates.  Labels
    absent from the spec yield nothing.
    """
    label = label.upper()
    spec = SCANNER_SPEC["scanners"].get(label)
    if spec is None:
        return
    if spec["kind"] == "delegate":
        delegate_name = spec["function"]
        if delegate_name == "iter_ppsn_candidates":
            yield from iter_ppsn_candidates(text)
        elif delegate_name == "iter_eircode_candidates":
            yield from iter_eircode_candidates(text)
        return
    start_spec = SCANNER_SPEC["start_predicates"][spec["start_predicate"]]
    # Validator names in the spec map onto this module's validator functions.
    validators = {
        "is_valid_irish_phone": is_valid_irish_phone,
        "is_valid_passport": is_valid_passport,
        "is_valid_sort_code": is_valid_sort_code,
        "is_plausible_ie_iban": is_plausible_ie_iban,
        "is_plausible_card": is_plausible_card,
        "is_valid_bic": is_valid_bic,
    }
    if "builtin" in start_spec:
        builtin = start_spec["builtin"]
        if builtin == "ascii_letter":
            start_ok = is_ascii_letter
        elif builtin == "ascii_digit":
            start_ok = is_ascii_digit
        else:
            raise ValueError(f"Unknown builtin start predicate: {builtin}")
    else:
        # Bind the allowed set as a default argument so the closure is stable.
        allowed = set(start_spec["any_of"])
        start_ok = lambda ch, allowed=allowed: ch in allowed
    yield from scan_candidates(
        text,
        start_ok=start_ok,
        allowed_chars=set(SCANNER_SPEC["char_classes"][spec["allowed_chars"]]),
        min_len=int(spec["min_len"]),
        max_len=int(spec["max_len"]),
        validator=validators[spec["validator"]],
    )
def plausible_label_text(label: str, value: str) -> bool:
    """Structural check of *value* under *label*; labels without a format
    validator are always accepted."""
    value = value.strip()
    if label == "ACCOUNT_NUMBER":
        # Either an Irish IBAN or a bare 8-digit account number.
        compact = normalize_compact(value)
        return is_plausible_ie_iban(value) or (compact.isdigit() and len(compact) == 8)
    validators = {
        "PPSN": is_plausible_ppsn,
        "PHONE_NUMBER": is_valid_irish_phone,
        "PASSPORT_NUMBER": is_valid_passport,
        "BANK_ROUTING_NUMBER": is_valid_sort_code,
        "CREDIT_DEBIT_CARD": is_plausible_card,
        "SWIFT_BIC": is_valid_bic,
        "POSTCODE": is_valid_eircode,
    }
    validator = validators.get(label)
    return True if validator is None else validator(value)
def label_ids_from_mapping(id2label, label: str):
    """Collect the integer ids whose mapped label name — BIO prefix stripped,
    case-insensitive — equals *label*."""
    target = label.upper()
    matches = []
    for raw_id, raw_label in id2label.items():
        # Inline of normalize_label: strip, drop B-/I- prefix, upper-case.
        name = str(raw_label).strip()
        if name[:2] in ("B-", "I-"):
            name = name[2:]
        if name.upper() == target:
            matches.append(int(raw_id))
    return matches
def label_ids(model, label: str):
    """Ids in *model*'s id2label mapping that correspond to *label*."""
    return label_ids_from_mapping(model.config.id2label, label)
def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Score every TOKEN_RE word of *text* for *label* with a torch model.

    Returns (pieces, scores): pieces are (word, start, end) triples from
    tokenize_with_spans; scores[i] is the maximum probability any sub-token
    of word i assigns to any id mapped to *label*.  Words with no surviving
    sub-token (e.g. lost to truncation) keep a score of 0.0.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    # Single pass over the tokens (previously O(words x tokens)): fold each
    # token's probability into its source word's running maximum.
    scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        if wid is None:
            continue  # special tokens have no source word
        for label_id in ids:
            probability = float(probs[token_index, label_id])
            if probability > scores[wid]:
                scores[wid] = probability
    return pieces, scores
def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX twin of word_scores_for_label: same (pieces, scores) contract,
    with logits produced by an onnxruntime session instead of a torch model."""
    from onnx_token_classifier import _run_onnx, _softmax
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    # Single pass over the tokens (previously O(words x tokens)): fold each
    # token's probability into its source word's running maximum.
    scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        if wid is None:
            continue  # special tokens have no source word
        for label_id in ids:
            probability = float(probs[token_index, label_id])
            if probability > scores[wid]:
                scores[wid] = probability
    return pieces, scores
def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Merge above-threshold words into contiguous spans for *label*.

    Adjacent kept words (character gap of at most one) merge into one span.
    For phone / sort-code / card labels, a bare '-' or '/' token may extend an
    active span at half the threshold, so formatted numbers are not split.
    Finished spans failing plausible_label_text are dropped.
    """
    spans = []
    active = None  # span currently being extended, or None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        # Separator tokens bridge an active span at a relaxed threshold.
        if label in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"} and word in {"-", "/"}:
            keep = active is not None and score >= threshold / 2.0
        if keep:
            if active is None:
                active = {"start": start, "end": end, "label": label}
            else:
                if start - active["end"] <= 1:
                    active["end"] = end
                else:
                    spans.append(active)
                    active = {"start": start, "end": end, "label": label}
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    out = []
    for span in spans:
        value = text[span["start"] : span["end"]]
        if plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": span["start"],
                    "end": span["end"],
                    "text": value,
                    "source": "word_aligned",
                }
            )
    return out
def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* scored by a torch token-classification model."""
    scored = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, *scored)
def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """Word-aligned spans for *label* scored by an ONNX session."""
    scored = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, *scored)
def scanner_guided_label_spans(text: str, label: str, threshold: float, pieces, scores):
    """Spans from the spec-driven scanner, kept only when some overlapping
    word's score reaches *threshold* and the trimmed text passes the format check."""
    if not pieces:
        return []
    accepted = []
    for candidate in spec_candidates_for_label(text, label):
        start, end = int(candidate["start"]), int(candidate["end"])
        # Trim surrounding whitespace from the candidate span.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        # Support = best word score among pieces overlapping [start, end).
        overlapping = (
            float(score)
            for (_, piece_start, piece_end), score in zip(pieces, scores)
            if piece_end > start and piece_start < end
        )
        support = max(overlapping, default=0.0)
        value = text[start:end]
        if support >= threshold and plausible_label_text(label, value):
            accepted.append(
                {
                    "label": label,
                    "start": start,
                    "end": end,
                    "text": value,
                    "score": support,
                    "source": "scanner_guided",
                }
            )
    return accepted
def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert HF-pipeline-style output dicts into this module's span dicts,
    dropping unlabeled entries and entries scoring below *min_score*."""
    spans = []
    for output in outputs:
        # Inline label normalization: strip, drop B-/I- prefix, upper-case.
        raw = (output.get("entity_group") or output.get("entity") or "").strip()
        if raw[:2] in ("B-", "I-"):
            raw = raw[2:]
        label = raw.upper()
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        start, end = int(output["start"]), int(output["end"])
        spans.append(
            {
                "label": label,
                "start": start,
                "end": end,
                "score": score,
                "text": text[start:end],
            }
        )
    return spans
def overlaps(a: dict, b: dict) -> bool:
    """True when the half-open ranges [start, end) of *a* and *b* intersect."""
    return a["start"] < b["end"] and b["start"] < a["end"]
def span_length(span: dict) -> int:
    """Length in characters of the half-open span."""
    start, end = int(span["start"]), int(span["end"])
    return end - start
def normalize_simple_span(span: dict):
    """Normalize a raw span dict: canonical label, int offsets, float score.

    Phone spans that also look like card numbers are relabeled as
    CREDIT_DEBIT_CARD; format-checked labels failing their validator
    return None (span dropped).
    """
    label = normalize_label(span["label"])
    value = span["text"]
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    needs_format_check = label in FORMAT_LABELS or label == "POSTCODE"
    if needs_format_check and not plausible_label_text(label, value):
        return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
        "source": span.get("source", "model"),
    }
def dedupe_and_sort(spans: list[dict]):
    """Greedy overlap resolution: earliest start first, then longest span,
    then label priority; any later candidate overlapping a kept span is dropped."""
    def sort_key(span):
        return (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        )
    kept: list[dict] = []
    for span in sorted(spans, key=sort_key):
        if not any(overlaps(span, chosen) for chosen in kept):
            kept.append(span)
    return kept
def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Combine pipeline output with word-aligned and scanner-guided repairs.

    Steps: (1) keep non-PPSN pipeline spans that pass format checks;
    (2) add word-aligned PPSN spans; (3) build repair candidates for PPSN and
    each thresholded format label; (4) let each repair displace overlapping
    spans it beats; (5) resolve remaining overlaps with dedupe_and_sort.
    """
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    # (1) Pipeline spans, PPSN excluded — PPSN is handled separately below.
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    # (2) Dedicated word-aligned PPSN detection.
    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                    "source": span.get("source", "model"),
                }
            )
    # (3) Repair candidates: scanner-guided PPSN plus word-aligned and
    # scanner-guided spans for every thresholded label.
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label(text, model, tokenizer, "PPSN")
    repairs.extend(scanner_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label(text, model, tokenizer, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(scanner_guided_label_spans(text, label, threshold, pieces, scores))
    # (4) Each repair candidate may displace overlapping existing spans.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            # Same label, strictly longer candidate: replace.
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            # Scanner-guided evidence beats model-only evidence for the same label.
            if (
                candidate["label"] == span["label"]
                and candidate.get("source") == "scanner_guided"
                and span.get("source") != "scanner_guided"
            ):
                replaced = True
                continue
            # A longer format-label span beats a shorter span of another format label.
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        # Add the candidate if it displaced something or overlaps nothing kept.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)
def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans.

    Same pipeline-plus-repairs merge, but scores come from an onnxruntime
    session; when *general_outputs* is not supplied, the base pipeline spans
    are computed here via simple_aggregate_spans_onnx.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    # (1) Pipeline spans, PPSN excluded — PPSN is handled separately below.
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    # (2) Dedicated word-aligned PPSN detection.
    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                    "source": span.get("source", "model"),
                }
            )
    # (3) Repair candidates: scanner-guided PPSN plus word-aligned and
    # scanner-guided spans for every thresholded label.
    repairs = []
    ppsn_pieces, ppsn_scores = word_scores_for_label_onnx(text, session, tokenizer, config, "PPSN")
    repairs.extend(scanner_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(scanner_guided_label_spans(text, label, threshold, pieces, scores))
    # (4) Each repair candidate may displace overlapping existing spans.
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            # Same label, strictly longer candidate: replace.
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            # Scanner-guided evidence beats model-only evidence for the same label.
            if (
                candidate["label"] == span["label"]
                and candidate.get("source") == "scanner_guided"
                and span.get("source") != "scanner_guided"
            ):
                replaced = True
                continue
            # A longer format-label span beats a shorter span of another format label.
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        # Add the candidate if it displaced something or overlaps nothing kept.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)