# Upstream release note (scrape residue from the model-hub file page,
# commented out so the module parses):
#   temsa — Add rc6 release with decoder repair improvements — c487a4b verified
#!/usr/bin/env python3
import regex as re
import torch
from eircode import iter_eircode_candidates, is_valid_eircode
from ppsn import is_plausible_ppsn, iter_ppsn_candidates
from raw_word_aligned import word_aligned_ppsn_spans
# Word tokenizer: runs of ASCII alphanumerics, or single non-space symbols.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# Full-string Irish phone number: +353 international form (optional "(0)")
# or 0-prefixed national form, with flexible space/hyphen separators.
PHONE_RE = re.compile(r"^(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})$")
# Same shape, bounded by non-alphanumerics, for scanning inside free text.
PHONE_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])(?:\+353(?:\s*\(0\))?[\s-]*\(?\d{1,3}\)?(?:[\s-]*\d){6,8}|0\(?\d{1,3}\)?(?:[\s-]*\d){6,8})(?![A-Za-z0-9])")
# Passport: two letters then seven digits, optional single space between.
PASSPORT_RE = re.compile(r"^[A-Z]{2}\s?\d{7}$")
PASSPORT_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])[A-Z]{2}\s?\d{7}(?![A-Za-z0-9])", re.IGNORECASE)
# Sort code: six digits, plain or grouped 2-2-2 with space/hyphen.
SORT_RE = re.compile(r"^(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})$")
BANK_ROUTING_CANDIDATE_RE = re.compile(r"(?<!\d)(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})(?!\d)")
# Irish IBAN: IE + 2 check digits + 4-letter bank code + 14 digits (22 chars).
IBAN_IE_RE = re.compile(r"^IE\d{2}[A-Z]{4}\d{14}$")
# Candidate form tolerates separators between the 18 chars after "IEnn".
IBAN_IE_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])IE\d{2}(?:[\s-]?[A-Z0-9]){18}(?![A-Za-z0-9])", re.IGNORECASE)
# SWIFT/BIC: 8 or 11 characters (bank + country + location [+ branch]).
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
BIC_CANDIDATE_RE = re.compile(r"(?<![A-Za-z0-9])[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?(?![A-Za-z0-9])", re.IGNORECASE)
# Eircode: routing key (letter + 2 digits, or the special D6W) + 4-char unique id.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)
# Card numbers grouped 4-4-4-4(-4), or 4-6-5 (Amex/Diners-style grouping).
CARD_GROUPED_RE = re.compile(r"^(?:\d{4}(?:[ -]\d{4}){3,4}|\d{4}[ -]\d{6}[ -]\d{5})$")
# Loose 13-19 digit run (with optional separators) scanned as a card candidate.
CARD_CANDIDATE_RE = re.compile(r"(?<!\d)(?:\d[ -]?){13,19}(?!\d)")
# 4-letter IBAN bank codes seen in Irish IBANs. Used by is_plausible_ie_iban
# to accept an IBAN whose mod-97 checksum fails but whose bank code is known.
KNOWN_IE_IBAN_BANK_CODES = {
    "AIBK",
    "BOFI",
    "IPBS",
    "IRCE",
    "ULSB",
    "PTSB",
    "EBSI",
    "DABA",
    "CITI",
    "TRWI",
    "REVO",
}
# Default minimum word-score threshold per format label; callers may override
# via the label_thresholds argument of the repair functions.
DEFAULT_LABEL_THRESHOLDS = {
    "PHONE_NUMBER": 0.35,
    "PASSPORT_NUMBER": 0.11,
    "BANK_ROUTING_NUMBER": 0.35,
    "ACCOUNT_NUMBER": 0.40,
    "CREDIT_DEBIT_CARD": 0.08,
    "SWIFT_BIC": 0.50,
}
# Labels whose spans must additionally pass plausible_label_text checks.
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS)
# Label priority used as a sort tie-breaker in dedupe_and_sort (lower wins;
# labels missing from this map fall back to 99).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE, returning (token, start, end) triples."""
    pieces = []
    for match in TOKEN_RE.finditer(text):
        pieces.append((match.group(0), match.start(), match.end()))
    return pieces
def normalize_label(label: str) -> str:
    """Strip surrounding whitespace and a BIO prefix, then uppercase."""
    cleaned = (label or "").strip()
    # Drop the BIO scheme marker ("B-" / "I-") if present.
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
def luhn_ok(value: str) -> bool:
    """Return True if the digits in *value* (13-19 of them) pass the Luhn check."""
    digits = [int(ch) for ch in value if ch.isdigit()]
    if not 13 <= len(digits) <= 19:
        return False
    checksum = 0
    # Walk from the rightmost digit; every second digit is doubled.
    for offset, digit in enumerate(reversed(digits)):
        if offset % 2:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit
    return checksum % 10 == 0
def iban_mod97_ok(value: str) -> bool:
    """True when *value* is a well-formed Irish IBAN with a valid mod-97 checksum."""
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    # IE + 2 check digits + 4-letter bank code + 14 digits.
    if re.match(r"^IE\d{2}[A-Z]{4}\d{14}$", compact) is None:
        return False
    # ISO 13616: rotate the first four chars to the end, map A-Z to 10-35
    # (int(ch, 36) gives exactly that mapping), then check mod 97 == 1.
    rearranged = compact[4:] + compact[:4]
    numeric = "".join(str(int(ch, 36)) for ch in rearranged)
    return int(numeric) % 97 == 1
def is_plausible_ie_iban(value: str) -> bool:
    """Accept an IE IBAN when the mod-97 checksum passes or the bank code is known."""
    compact = re.sub(r"[\s-]+", "", value.strip().upper())
    if IBAN_IE_RE.match(compact) is None:
        return False
    # A recognised bank code rescues IBANs whose checksum fails —
    # presumably to tolerate corrupted check digits; confirm with callers.
    return iban_mod97_ok(compact) or compact[4:8] in KNOWN_IE_IBAN_BANK_CODES
def normalize_irish_phone(value: str) -> str:
    """Collapse an Irish phone number to a compact +353/0-prefixed form."""
    cleaned = value.strip().replace("(0)", "0")
    cleaned = re.sub(r"[\s\-\(\)]", "", cleaned)
    # Rewrite the 00353 international-access prefix as +353.
    return ("+" + cleaned[2:]) if cleaned.startswith("00353") else cleaned
def is_valid_irish_phone(value: str) -> bool:
    """Validate the normalized form of *value* by prefix and digit count."""
    compact = normalize_irish_phone(value)
    if compact.startswith("+353"):
        national = compact[4:].removeprefix("0")
        if not national.isdigit():
            return False
        # 8x (mobile) numbers carry exactly 9 national digits; others 8 or 9.
        if national.startswith("8"):
            return len(national) == 9
        return len(national) in {8, 9}
    if not (compact.startswith("0") and compact.isdigit()):
        return False
    # National format includes the leading 0, so lengths shift up by one.
    if compact.startswith("08"):
        return len(compact) == 10
    return len(compact) in {9, 10}
def is_plausible_card(value: str) -> bool:
    """Card-like: 13-19 digits, and either Luhn-valid or conventionally grouped."""
    digit_count = sum(ch.isdigit() for ch in value)
    if digit_count < 13 or digit_count > 19:
        return False
    # The grouped-format fallback accepts well-formed numbers that fail Luhn.
    return luhn_ok(value) or CARD_GROUPED_RE.match(value.strip()) is not None
def normalize_passport(value: str) -> str:
    """Uppercase *value* and drop all whitespace."""
    return "".join(value.split()).upper()
def regex_candidates_for_label(text: str, label: str):
    """Yield candidate span dicts (start/end/text/normalized) for *label*."""
    label = label.upper()
    # PPSN and Eircode use dedicated generators from their own modules.
    if label == "PPSN":
        yield from iter_ppsn_candidates(text)
        return
    if label == "POSTCODE":
        yield from iter_eircode_candidates(text)
        return
    patterns = {
        "PHONE_NUMBER": PHONE_CANDIDATE_RE,
        "PASSPORT_NUMBER": PASSPORT_CANDIDATE_RE,
        "BANK_ROUTING_NUMBER": BANK_ROUTING_CANDIDATE_RE,
        "ACCOUNT_NUMBER": IBAN_IE_CANDIDATE_RE,
        "CREDIT_DEBIT_CARD": CARD_CANDIDATE_RE,
        "SWIFT_BIC": BIC_CANDIDATE_RE,
    }
    pattern = patterns.get(label)
    if pattern is None:
        return
    for match in pattern.finditer(text):
        snippet = match.group(0)
        yield {
            "start": match.start(),
            "end": match.end(),
            "text": snippet,
            "normalized": snippet,
        }
def plausible_label_text(label: str, value: str) -> bool:
    """Format-level sanity check for *value* under *label*; unknown labels pass."""
    value = value.strip()

    def account_ok(v: str) -> bool:
        # Either an IE IBAN or a bare 8-digit account number.
        compact = re.sub(r"[\s-]+", "", v)
        return is_plausible_ie_iban(v) or (compact.isdigit() and len(compact) == 8)

    checks = {
        "PPSN": is_plausible_ppsn,
        "PHONE_NUMBER": lambda v: PHONE_RE.match(v) is not None and is_valid_irish_phone(v),
        "PASSPORT_NUMBER": lambda v: PASSPORT_RE.match(normalize_passport(v)) is not None,
        "BANK_ROUTING_NUMBER": lambda v: SORT_RE.match(v) is not None,
        "ACCOUNT_NUMBER": account_ok,
        "CREDIT_DEBIT_CARD": is_plausible_card,
        "SWIFT_BIC": lambda v: BIC_RE.match(v.upper()) is not None,
        "POSTCODE": lambda v: EIRCODE_RE.match(v) is not None and is_valid_eircode(v),
    }
    check = checks.get(label)
    return check(value) if check is not None else True
def label_ids_from_mapping(id2label, label: str):
    """Return ids whose mapped label matches *label*, ignoring BIO prefix and case."""
    wanted = label.upper()
    matches = []
    for key, raw in id2label.items():
        name = str(raw).strip()
        # Drop a BIO scheme marker before comparing.
        if name[:2] in ("B-", "I-"):
            name = name[2:]
        if name.upper() == wanted:
            matches.append(int(key))
    return matches
def label_ids(model, label: str):
    """Label ids for *label* read from the model config's id2label mapping."""
    mapping = model.config.id2label
    return label_ids_from_mapping(mapping, label)
def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Score each tokenized word of *text* for *label* with a torch model.

    Returns:
        (pieces, scores): pieces are (word, start, end) triples from
        tokenize_with_spans; scores[i] is the max probability, over all
        sub-tokens of word i and all ids mapped to *label*, of that label.
        Words lost to truncation score 0.0.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    # Group token indices by word once, instead of rescanning word_ids for
    # every word (the previous version was O(words * tokens)).
    tokens_by_word: dict[int, list[int]] = {}
    for token_index, wid in enumerate(word_ids):
        if wid is not None:  # None marks special tokens ([CLS]/[SEP]/pad)
            tokens_by_word.setdefault(wid, []).append(token_index)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index in tokens_by_word.get(word_index, ()):
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores
def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX counterpart of word_scores_for_label.

    Returns:
        (pieces, scores) with the same semantics as word_scores_for_label:
        per-word max probability over the token ids mapped to *label*.
    """
    from onnx_token_classifier import _run_onnx, _softmax
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    # Group token indices by word once, instead of rescanning word_ids for
    # every word (the previous version was O(words * tokens)).
    tokens_by_word: dict[int, list[int]] = {}
    for token_index, wid in enumerate(word_ids):
        if wid is not None:  # None marks special tokens ([CLS]/[SEP]/pad)
            tokens_by_word.setdefault(wid, []).append(token_index)
    scores = []
    for word_index in range(len(pieces)):
        score = 0.0
        for token_index in tokens_by_word.get(word_index, ()):
            for label_id in ids:
                score = max(score, float(probs[token_index, label_id]))
        scores.append(score)
    return pieces, scores
def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Merge above-threshold adjacent words into spans, then format-filter them."""
    connector_labels = {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"}
    merged = []
    current = None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        # Separator characters inside numeric labels may continue a span at
        # half threshold, but can never start one.
        if label in connector_labels and word in {"-", "/"}:
            keep = current is not None and score >= threshold / 2.0
        if not keep:
            if current is not None:
                merged.append(current)
                current = None
            continue
        # Extend the open span when at most one character separates the words.
        if current is not None and start - current["end"] <= 1:
            current["end"] = end
        else:
            if current is not None:
                merged.append(current)
            current = {"start": start, "end": end, "label": label}
    if current is not None:
        merged.append(current)
    results = []
    for span in merged:
        snippet = text[span["start"] : span["end"]]
        if plausible_label_text(label, snippet):
            results.append(
                {
                    "label": label,
                    "start": span["start"],
                    "end": span["end"],
                    "text": snippet,
                }
            )
    return results
def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Torch-scored, word-aligned spans for *label*, thresholded and filtered."""
    pieces, scores = word_scores_for_label(text, model, tokenizer, label)
    spans = _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)
    return spans
def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """ONNX-scored, word-aligned spans for *label*, thresholded and filtered."""
    pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    spans = _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores)
    return spans
def regex_guided_label_spans(text: str, label: str, threshold: float, pieces, scores):
    """Keep regex candidates whose best overlapping word score clears *threshold*."""
    if not pieces:
        return []
    results = []
    for candidate in regex_candidates_for_label(text, label):
        begin, finish = int(candidate["start"]), int(candidate["end"])
        # Trim leading/trailing whitespace off the candidate span.
        while begin < finish and text[begin].isspace():
            begin += 1
        while finish > begin and text[finish - 1].isspace():
            finish -= 1
        # Model support = highest score among words overlapping the span.
        overlapping = [
            float(score)
            for (_, piece_start, piece_end), score in zip(pieces, scores)
            if piece_start < finish and piece_end > begin
        ]
        support = max(overlapping, default=0.0)
        snippet = text[begin:finish]
        if support >= threshold and plausible_label_text(label, snippet):
            results.append(
                {
                    "label": label,
                    "start": begin,
                    "end": finish,
                    "text": snippet,
                    "score": support,
                }
            )
    return results
def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert pipeline entity dicts (entity_group/entity, score, start, end)
    into this module's span dicts, dropping empty labels and low scores."""
    spans = []
    for output in outputs:
        raw = (output.get("entity_group") or output.get("entity") or "").strip()
        # Normalize the label: strip a BIO prefix, then uppercase.
        if raw[:2] in ("B-", "I-"):
            raw = raw[2:]
        label = raw.upper()
        if not label:
            continue
        score = float(output.get("score", 0.0))
        if score < min_score:
            continue
        begin = int(output["start"])
        finish = int(output["end"])
        spans.append(
            {
                "label": label,
                "start": begin,
                "end": finish,
                "score": score,
                "text": text[begin:finish],
            }
        )
    return spans
def overlaps(a: dict, b: dict) -> bool:
    """True when the two half-open [start, end) spans share any position."""
    return a["start"] < b["end"] and b["start"] < a["end"]
def span_length(span: dict) -> int:
    """Number of characters covered by *span*."""
    begin, finish = int(span["start"]), int(span["end"])
    return finish - begin
def normalize_simple_span(span: dict):
    """Normalize a span's label and re-validate format labels; None if rejected."""
    label = normalize_label(span["label"])
    value = span["text"]
    # A "phone number" that is card-plausible is reclassified as a card.
    if label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value):
        label = "CREDIT_DEBIT_CARD"
    needs_format_check = label in FORMAT_LABELS or label == "POSTCODE"
    if needs_format_check and not plausible_label_text(label, value):
        return None
    return {
        "label": label,
        "start": int(span["start"]),
        "end": int(span["end"]),
        "score": float(span.get("score", 0.0)),
        "text": value,
    }
def dedupe_and_sort(spans: list[dict]):
    """Greedily drop overlapping spans after sorting by start, length desc,
    then label priority (unknown labels sort last)."""

    def sort_key(span):
        return (
            int(span["start"]),
            -span_length(span),
            OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99),
        )

    kept: list[dict] = []
    for span in sorted(spans, key=sort_key):
        if not any(overlaps(span, accepted) for accepted in kept):
            kept.append(span)
    return kept
def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Merge general NER output with PPSN and format-label repair spans.

    Seeds the result with the pipeline's spans (minus PPSN), adds word-aligned
    PPSN spans, then lets longer word-aligned/regex-guided repair candidates
    replace overlapping shorter spans, and finally returns a non-overlapping,
    sorted span list.

    Args:
        text: Raw input text all offsets refer to.
        model: Torch token-classification model for per-label word scores.
        tokenizer: Matching tokenizer (used with is_split_into_words).
        general_outputs: Pipeline entity dicts for *text*.
        other_min_score: Minimum score for non-PPSN pipeline spans.
        ppsn_min_score: Threshold for PPSN word/regex spans.
        label_thresholds: Optional per-label overrides of
            DEFAULT_LABEL_THRESHOLDS (keys upper-cased).

    Returns:
        Non-overlapping span dicts sorted by start offset.
    """
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        # Normalize caller overrides to upper-case label keys.
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    # Seed with pipeline output; PPSN is handled by the dedicated path below.
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                }
            )
    repairs = []
    # Regex-guided PPSN repairs reuse the model's word-level PPSN scores.
    ppsn_pieces, ppsn_scores = word_scores_for_label(text, model, tokenizer, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label(text, model, tokenizer, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            # A longer candidate of the same label supersedes the existing span.
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            # Longer format-label candidates may also supersede other format labels.
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        # Keep the candidate when it displaced something or fills an open gap.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)
def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans.

    Same merge/replace strategy, but scores come from an ONNX session; when
    *general_outputs* is not supplied, the base spans are computed with
    simple_aggregate_spans_onnx.

    Args:
        text: Raw input text all offsets refer to.
        session: ONNX runtime session producing token logits.
        tokenizer: Matching tokenizer (used with is_split_into_words).
        config: Model config providing id2label.
        other_min_score: Minimum score for non-PPSN base spans.
        ppsn_min_score: Threshold for PPSN word/regex spans.
        label_thresholds: Optional per-label overrides of
            DEFAULT_LABEL_THRESHOLDS (keys upper-cased).
        general_outputs: Optional precomputed pipeline entity dicts.

    Returns:
        Non-overlapping span dicts sorted by start offset.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        # Normalize caller overrides to upper-case label keys.
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})
    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    # Seed with pipeline output; PPSN is handled by the dedicated path below.
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)
    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        value = text[int(span["start"]) : int(span["end"])]
        if plausible_label_text("PPSN", value):
            spans.append(
                {
                    "label": "PPSN",
                    "start": int(span["start"]),
                    "end": int(span["end"]),
                    "score": float(span.get("score", 0.0)),
                    "text": value,
                }
            )
    repairs = []
    # Regex-guided PPSN repairs reuse the model's word-level PPSN scores.
    ppsn_pieces, ppsn_scores = word_scores_for_label_onnx(text, session, tokenizer, config, "PPSN")
    repairs.extend(regex_guided_label_spans(text, "PPSN", ppsn_min_score, ppsn_pieces, ppsn_scores))
    for label, threshold in thresholds.items():
        pieces, scores = word_scores_for_label_onnx(text, session, tokenizer, config, label)
        repairs.extend(_word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores))
        repairs.extend(regex_guided_label_spans(text, label, threshold, pieces, scores))
    for candidate in repairs:
        updated = []
        replaced = False
        for span in spans:
            if not overlaps(candidate, span):
                updated.append(span)
                continue
            # A longer candidate of the same label supersedes the existing span.
            if candidate["label"] == span["label"] and span_length(candidate) > span_length(span):
                replaced = True
                continue
            # Longer format-label candidates may also supersede other format labels.
            if candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS and span_length(candidate) > span_length(span):
                replaced = True
                continue
            updated.append(span)
        spans = updated
        # Keep the candidate when it displaced something or fills an open gap.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)