temsa's picture
Add files using upload-large-folder tool
d1ddb26 verified
#!/usr/bin/env python3
import regex as re
import torch
from raw_word_aligned import word_aligned_ppsn_spans
# Word splitter used before feeding text to the token classifier: a run of ASCII
# letters/digits, or a single character that is neither a word char nor whitespace.
# NOTE(review): non-ASCII letters (e.g. "á") and "_" match neither alternative, so
# they are skipped entirely ("Seán" -> "Se", "n").
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)

# Whole-string shape validators used by plausible_label_text.
# Phone: "+353" (optionally followed by "(0)") or a leading "0", then 8-9 digits
# that may be separated by single whitespace/hyphen characters.
PHONE_RE = re.compile(r"^(?:\+353\s?(?:\(0\))?\s?\d(?:[\s-]?\d){7,8}|0\d(?:[\s-]?\d){7,8})$")
# Passport: two upper-case letters, optional whitespace, seven digits.
PASSPORT_RE = re.compile(r"^[A-Z]{2}\s?\d{7}$")
# Sort code: six digits, contiguous or as three pairs split by a space or hyphen.
SORT_RE = re.compile(r"^(?:\d{6}|\d{2}[ -]\d{2}[ -]\d{2})$")
# Irish IBAN: "IE", two digits, four letters, then 14 digits grouped 4/4/4/2 with
# optional whitespace between groups.
IBAN_IE_RE = re.compile(r"^IE\d{2}(?:\s?[A-Z]{4})(?:\s?\d{4}){3}\s?\d{2}$")
# SWIFT/BIC: 8 characters (6 letters + 2 alphanumerics) plus an optional 3 more.
BIC_RE = re.compile(r"^[A-Z]{4}[A-Z]{2}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
# Eircode: 3-character routing key (restricted letter + 2 digits, or "D6W"),
# optional whitespace, 4-character identifier; matched case-insensitively.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)

# Default per-word score threshold for each label handled by the word-aligned
# repair pass; callers can override or extend these per call.
DEFAULT_LABEL_THRESHOLDS = dict(
    PHONE_NUMBER=0.35,
    PASSPORT_NUMBER=0.11,
    BANK_ROUTING_NUMBER=0.35,
    ACCOUNT_NUMBER=0.40,
    CREDIT_DEBIT_CARD=0.08,
    SWIFT_BIC=0.50,
)
# Labels whose text must match a strict format (the repair labels above).
FORMAT_LABELS = set(DEFAULT_LABEL_THRESHOLDS.keys())
# Tie-breaker used by dedupe_and_sort: lower rank wins between spans that start at
# the same offset with the same length.
OUTPUT_PRIORITY = {
    name: rank
    for rank, name in enumerate(
        (
            "PPSN",
            "PASSPORT_NUMBER",
            "ACCOUNT_NUMBER",
            "BANK_ROUTING_NUMBER",
            "CREDIT_DEBIT_CARD",
            "PHONE_NUMBER",
            "SWIFT_BIC",
            "POSTCODE",
            "EMAIL",
            "FIRST_NAME",
            "LAST_NAME",
        )
    )
}
def tokenize_with_spans(text: str):
    """Split *text* into TOKEN_RE tokens with their character offsets.

    Returns:
        A list of ``(token, start, end)`` tuples in order of appearance, where
        ``text[start:end] == token``. Characters matched by no TOKEN_RE
        alternative (whitespace among them) are simply absent.
    """
    tokens = []
    for match in TOKEN_RE.finditer(text):
        start, end = match.span()
        tokens.append((match.group(0), start, end))
    return tokens
def normalize_label(label: str) -> str:
    """Canonicalize a label: trim it, drop one leading "B-"/"I-" tag, upper-case it.

    ``None`` or empty input yields ``""``. The tag check is case-sensitive, so
    "b-email" becomes "B-EMAIL" rather than "EMAIL".
    """
    cleaned = "" if not label else label.strip()
    head, tail = cleaned[:2], cleaned[2:]
    return (tail if head in {"B-", "I-"} else cleaned).upper()
def luhn_ok(value: str) -> bool:
    """Return True if *value* holds 13-19 decimal digits that pass the Luhn checksum.

    Non-digit characters (spaces, hyphens, ...) are ignored. Digits are selected with
    ``str.isdecimal()`` rather than ``str.isdigit()``: the latter also accepts
    characters such as superscript "²" that ``int()`` cannot convert, which used to
    make this function raise ``ValueError`` instead of returning False/True.
    """
    digits = [int(ch) for ch in value if ch.isdecimal()]
    if not 13 <= len(digits) <= 19:
        return False
    total = 0
    for position, digit in enumerate(reversed(digits)):
        if position % 2:  # every second digit, counting from the rightmost, is doubled
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0
def plausible_label_text(label: str, value: str) -> bool:
    """Return True if *value* (whitespace-stripped) has the shape expected for *label*.

    PHONE_NUMBER, PASSPORT_NUMBER, BANK_ROUTING_NUMBER, SWIFT_BIC and POSTCODE are
    checked against their module-level regexes. ACCOUNT_NUMBER accepts an Irish IBAN
    or exactly 8 decimal digits once spaces are removed. CREDIT_DEBIT_CARD must pass
    luhn_ok. Any other label is accepted unconditionally.
    """
    value = value.strip()
    if label == "ACCOUNT_NUMBER":
        compact = value.replace(" ", "")
        # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
        # superscript "²", which the \d-based patterns used for other labels reject.
        return IBAN_IE_RE.match(value) is not None or (compact.isdecimal() and len(compact) == 8)
    if label == "CREDIT_DEBIT_CARD":
        return luhn_ok(value)
    pattern = {
        "PHONE_NUMBER": PHONE_RE,
        "PASSPORT_NUMBER": PASSPORT_RE,
        "BANK_ROUTING_NUMBER": SORT_RE,
        "SWIFT_BIC": BIC_RE,
        "POSTCODE": EIRCODE_RE,
    }.get(label)
    if pattern is None:
        return True
    return pattern.match(value) is not None
def label_ids_from_mapping(id2label, label: str):
    """Return the integer class ids in *id2label* whose normalized name equals *label*.

    Each mapped name is passed through normalize_label (so "B-X" and "I-X" both match
    "X"); *label* itself is only upper-cased. Ids are returned as ints in the mapping's
    iteration order.
    """
    wanted = label.upper()
    return [
        int(class_id)
        for class_id, class_name in id2label.items()
        if normalize_label(str(class_name)) == wanted
    ]
def label_ids(model, label: str):
    """Return the class ids for *label* from ``model.config.id2label``."""
    mapping = model.config.id2label
    return label_ids_from_mapping(mapping, label)
def word_scores_for_label(text: str, model, tokenizer, label: str):
    """Score every word of *text* for *label* with a PyTorch token classifier.

    *text* is split with tokenize_with_spans, the words are tokenized as a single
    pre-split sequence (truncated to the tokenizer's limit) and run through *model*.

    Returns:
        ``(pieces, scores)``. *pieces* is the tokenize_with_spans output and
        ``scores[i]`` is the highest softmax probability any sub-token of word ``i``
        gives to any class id that normalizes to *label*. A word gets 0.0 when it has
        no sub-token (e.g. it was cut off by truncation) or when *label* has no class
        id. Empty text returns ``([], [])`` without running the model.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    # Read the sub-token -> word mapping before the BatchEncoding becomes a plain dict.
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]  # the only sequence in the batch
    probs = torch.softmax(logits, dim=-1)
    ids = label_ids(model, label)
    # One pass over the sub-tokens instead of one pass per word (was O(words * tokens)).
    scores = [0.0] * len(pieces)
    for token_index, word_index in enumerate(word_ids):
        if word_index is None:  # sub-tokens that map to no word (special tokens)
            continue
        for label_id in ids:
            scores[word_index] = max(scores[word_index], float(probs[token_index, label_id]))
    return pieces, scores
def word_scores_for_label_onnx(text: str, session, tokenizer, config, label: str):
    """ONNX counterpart of word_scores_for_label.

    Tokenizes the tokenize_with_spans words as one pre-split sequence (NumPy tensors,
    truncated), runs *session* through ``onnx_token_classifier._run_onnx`` and applies
    ``_softmax`` over the last axis. Class ids come from ``config.id2label``.

    Returns:
        ``(pieces, scores)`` where ``scores[i]`` is the highest probability any
        sub-token of word ``i`` gives to a class id of *label* (0.0 if none).
        Empty text returns ``([], [])`` without running the session.
    """
    from onnx_token_classifier import _run_onnx, _softmax
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    # NOTE(review): [0] presumably selects the single sequence's logits; probs is
    # indexed as [token, class] below — confirm against _run_onnx.
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    ids = label_ids_from_mapping(config.id2label, label)
    # One pass over the sub-tokens instead of one pass per word (was O(words * tokens)).
    scores = [0.0] * len(pieces)
    for token_index, word_index in enumerate(word_ids):
        if word_index is None:  # sub-tokens that map to no word (special tokens)
            continue
        for label_id in ids:
            scores[word_index] = max(scores[word_index], float(probs[token_index, label_id]))
    return pieces, scores
def _word_aligned_label_spans_from_scores(text: str, label: str, threshold: float, pieces, scores):
    """Group consecutive high-scoring words into *label* spans and keep well-formed ones.

    Args:
        text: Source text that the *pieces* offsets index into.
        label: Label being extracted (e.g. "PHONE_NUMBER").
        threshold: Minimum word score for a word to be kept.
        pieces: ``(word, start, end)`` tuples as produced by tokenize_with_spans.
        scores: Per-word scores aligned with *pieces*.

    Returns:
        ``{"label", "start", "end", "text"}`` dicts, in text order, for the grouped
        spans whose text passes plausible_label_text for *label*.
    """
    uses_separators = label in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "CREDIT_DEBIT_CARD"}
    spans = []
    active = None
    for (word, start, end), score in zip(pieces, scores):
        keep = score >= threshold
        if uses_separators and word in {"-", "/"}:
            # Separators inside these numbers tend to score low: accept them at half
            # the threshold, but only to continue an already-open span.
            keep = active is not None and score >= threshold / 2.0
        if not keep:
            if active is not None:
                spans.append(active)
                active = None
            continue
        if active is not None and start - active["end"] <= 1:
            # Adjacent, or separated by a single character (e.g. a space): extend.
            active["end"] = end
            continue
        if active is not None:
            spans.append(active)
        active = {"start": start, "end": end}
    if active is not None:
        spans.append(active)

    out = []
    for span in spans:
        span_start, span_end = span["start"], span["end"]
        if uses_separators:
            # A separator accepted at a span edge (e.g. "087 123 4567 -") would make an
            # otherwise valid number fail validation, or leave a stray "-" in a card
            # span; trim edge separators and the spaces next to them.
            while span_start < span_end and text[span_start] in " -/":
                span_start += 1
            while span_end > span_start and text[span_end - 1] in " -/":
                span_end -= 1
        value = text[span_start:span_end]
        if value and plausible_label_text(label, value):
            out.append(
                {
                    "label": label,
                    "start": span_start,
                    "end": span_end,
                    "text": value,
                }
            )
    return out
def word_aligned_label_spans(
    text: str,
    model,
    tokenizer,
    label: str,
    threshold: float,
):
    """Find *label* spans in *text* from a PyTorch token classifier's per-word scores.

    Scores words with word_scores_for_label, then groups and validates them against
    *threshold* via _word_aligned_label_spans_from_scores, returning its span dicts.
    """
    scored_words = word_scores_for_label(text, model, tokenizer, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, *scored_words)
def word_aligned_label_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    label: str,
    threshold: float,
):
    """ONNX counterpart of word_aligned_label_spans.

    Scores words with word_scores_for_label_onnx, then groups and validates them
    against *threshold* via _word_aligned_label_spans_from_scores.
    """
    scored_words = word_scores_for_label_onnx(text, session, tokenizer, config, label)
    return _word_aligned_label_spans_from_scores(text, label, threshold, *scored_words)
def pipeline_to_spans(text: str, outputs: list[dict], min_score: float):
    """Convert token-classification entity dicts into span dicts.

    The label is read from "entity_group" (falling back to "entity") and passed
    through normalize_label. Entries whose label normalizes to "" or whose "score"
    (default 0.0) is below *min_score* are skipped.

    Returns:
        ``{"label", "start", "end", "score", "text"}`` dicts, where "text" is the
        slice of *text* between the integer start/end offsets.
    """
    converted = []
    for entity in outputs:
        raw_label = entity.get("entity_group") or entity.get("entity") or ""
        name = normalize_label(raw_label)
        if name == "":
            continue
        confidence = float(entity.get("score", 0.0))
        if confidence < min_score:
            continue
        begin = int(entity["start"])
        finish = int(entity["end"])
        converted.append(
            {
                "label": name,
                "start": begin,
                "end": finish,
                "score": confidence,
                "text": text[begin:finish],
            }
        )
    return converted
def overlaps(a: dict, b: dict) -> bool:
    """Return True unless one span ends at or before the other starts.

    Spans that merely touch (``a["end"] == b["start"]``) do not overlap.
    """
    return a["end"] > b["start"] and b["end"] > a["start"]
def span_length(span: dict) -> int:
    """Return ``end - start`` of *span*, with both offsets coerced to int."""
    finish = int(span["end"])
    begin = int(span["start"])
    return finish - begin
def normalize_simple_span(span: dict):
    """Normalize one pipeline span, or return None if it fails its format check.

    The label goes through normalize_label. A PHONE_NUMBER whose text passes the
    CREDIT_DEBIT_CARD check of plausible_label_text is relabelled CREDIT_DEBIT_CARD.
    Spans with a FORMAT_LABELS label or POSTCODE are kept only if
    plausible_label_text accepts their text.

    Returns:
        A new ``{"label", "start", "end", "score", "text"}`` dict (score defaults to
        0.0), or None when the format check fails.
    """
    label = normalize_label(span["label"])
    value = span["text"]
    looks_like_card = label == "PHONE_NUMBER" and plausible_label_text("CREDIT_DEBIT_CARD", value)
    if looks_like_card:
        label = "CREDIT_DEBIT_CARD"
    needs_check = label in FORMAT_LABELS or label == "POSTCODE"
    if needs_check and not plausible_label_text(label, value):
        return None
    return dict(
        label=label,
        start=int(span["start"]),
        end=int(span["end"]),
        score=float(span.get("score", 0.0)),
        text=value,
    )
def dedupe_and_sort(spans: list[dict]):
    """Return a non-overlapping subset of *spans*, ordered by start offset.

    Spans are sorted by start, then longer first, then OUTPUT_PRIORITY of the
    upper-cased label (unknown labels rank 99). Each span is then kept greedily
    unless it overlaps one already kept.
    """
    def rank(span):
        priority = OUTPUT_PRIORITY.get(str(span["label"]).upper(), 99)
        return int(span["start"]), -span_length(span), priority

    kept = []
    for candidate in sorted(spans, key=rank):
        clashes = any(overlaps(candidate, accepted) for accepted in kept)
        if not clashes:
            kept.append(candidate)
    return kept
def _word_scores_for_labels(text: str, model, tokenizer, labels):
    """Score every word of *text* for each label in *labels* from ONE forward pass.

    Produces the same per-label scores as calling word_scores_for_label once per
    label, but tokenizes and runs the model only once.

    Returns:
        ``(pieces, scores_by_label)``: *pieces* from tokenize_with_spans, and for each
        label a list whose i-th entry is the highest probability any sub-token of
        word i gives to a class id of that label (0.0 if none).
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, {label: [] for label in labels}
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    # Read the sub-token -> word mapping before the BatchEncoding becomes a plain dict.
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    scores_by_label = {}
    for label in labels:
        ids = label_ids(model, label)
        scores = [0.0] * len(pieces)
        for token_index, word_index in enumerate(word_ids):
            if word_index is None:  # sub-tokens that map to no word
                continue
            for label_id in ids:
                scores[word_index] = max(scores[word_index], float(probs[token_index, label_id]))
        scores_by_label[label] = scores
    return pieces, scores_by_label


def repair_irish_core_spans(
    text: str,
    model,
    tokenizer,
    general_outputs: list[dict],
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
):
    """Merge general pipeline spans, PPSN spans and word-aligned format-label repairs.

    1. *general_outputs* at or above *other_min_score* are normalized with
       normalize_simple_span; invalid ones and any PPSN are dropped.
    2. PPSN spans come from word_aligned_ppsn_spans at *ppsn_min_score*.
    3. For each label in DEFAULT_LABEL_THRESHOLDS (overridden/extended by
       *label_thresholds*, keys upper-cased) word-aligned candidates are extracted.
       A candidate removes overlapping, shorter spans that share its label or are
       also format labels, and is added if it removed something or overlaps nothing
       that remains.
    4. Remaining overlaps are resolved by dedupe_and_sort.

    The model is run once for all repair labels rather than once per label.
    """
    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})

    spans = []
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)

    ppsn_spans = word_aligned_ppsn_spans(text, model, tokenizer, threshold=ppsn_min_score)
    for span in ppsn_spans:
        start, end = int(span["start"]), int(span["end"])
        spans.append(
            {
                "label": "PPSN",
                "start": start,
                "end": end,
                "score": float(span.get("score", 0.0)),
                "text": text[start:end],
            }
        )

    # Single forward pass shared by every repair label (previously one per label).
    pieces, scores_by_label = _word_scores_for_labels(text, model, tokenizer, list(thresholds))
    repairs = []
    for label, threshold in thresholds.items():
        repairs.extend(
            _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores_by_label[label])
        )

    for candidate in repairs:
        survivors = []
        replaced = False
        for span in spans:
            supersedes = (
                overlaps(candidate, span)
                and span_length(candidate) > span_length(span)
                and (
                    candidate["label"] == span["label"]
                    or (candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS)
                )
            )
            if supersedes:
                replaced = True
            else:
                survivors.append(span)
        spans = survivors
        # NOTE(review): when replaced is True the candidate is added even if it still
        # overlaps a surviving span; dedupe_and_sort settles that conflict.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)
def _word_scores_for_labels_onnx(text: str, session, tokenizer, config, labels):
    """Score every word of *text* for each label in *labels* from ONE ONNX run.

    Produces the same per-label scores as calling word_scores_for_label_onnx once
    per label, but tokenizes and runs *session* only once.

    Returns:
        ``(pieces, scores_by_label)``: *pieces* from tokenize_with_spans, and for each
        label a list whose i-th entry is the highest probability any sub-token of
        word i gives to a class id of that label (0.0 if none).
    """
    from onnx_token_classifier import _run_onnx, _softmax

    pieces = tokenize_with_spans(text)
    if not pieces:
        return pieces, {label: [] for label in labels}
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    # NOTE(review): [0] presumably selects the single sequence's logits; probs is
    # indexed as [token, class] below — confirm against _run_onnx.
    logits = _run_onnx(session, encoded)[0]
    probs = _softmax(logits, axis=-1)
    scores_by_label = {}
    for label in labels:
        ids = label_ids_from_mapping(config.id2label, label)
        scores = [0.0] * len(pieces)
        for token_index, word_index in enumerate(word_ids):
            if word_index is None:  # sub-tokens that map to no word
                continue
            for label_id in ids:
                scores[word_index] = max(scores[word_index], float(probs[token_index, label_id]))
        scores_by_label[label] = scores
    return pieces, scores_by_label


def repair_irish_core_spans_onnx(
    text: str,
    session,
    tokenizer,
    config,
    other_min_score: float,
    ppsn_min_score: float,
    label_thresholds: dict[str, float] | None = None,
    general_outputs: list[dict] | None = None,
):
    """ONNX counterpart of repair_irish_core_spans.

    When *general_outputs* is None they are computed with
    ``onnx_token_classifier.simple_aggregate_spans_onnx`` at *other_min_score*.
    Then:

    1. General spans are normalized with normalize_simple_span; invalid ones and
       any PPSN are dropped.
    2. PPSN spans come from word_aligned_ppsn_spans_onnx at *ppsn_min_score*.
    3. For each label in DEFAULT_LABEL_THRESHOLDS (overridden/extended by
       *label_thresholds*, keys upper-cased) word-aligned candidates are extracted.
       A candidate removes overlapping, shorter spans that share its label or are
       also format labels, and is added if it removed something or overlaps nothing
       that remains.
    4. Remaining overlaps are resolved by dedupe_and_sort.

    The session is run once for all repair labels rather than once per label.
    """
    from onnx_token_classifier import simple_aggregate_spans_onnx, word_aligned_ppsn_spans_onnx

    thresholds = dict(DEFAULT_LABEL_THRESHOLDS)
    if label_thresholds:
        thresholds.update({key.upper(): value for key, value in label_thresholds.items()})

    spans = []
    if general_outputs is None:
        general_outputs = simple_aggregate_spans_onnx(
            text,
            session,
            tokenizer,
            config,
            min_score=other_min_score,
        )
    for span in pipeline_to_spans(text, general_outputs, min_score=other_min_score):
        normalized = normalize_simple_span(span)
        if normalized is not None and normalized["label"] != "PPSN":
            spans.append(normalized)

    ppsn_spans = word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=ppsn_min_score)
    for span in ppsn_spans:
        start, end = int(span["start"]), int(span["end"])
        spans.append(
            {
                "label": "PPSN",
                "start": start,
                "end": end,
                "score": float(span.get("score", 0.0)),
                "text": text[start:end],
            }
        )

    # Single inference shared by every repair label (previously one per label).
    pieces, scores_by_label = _word_scores_for_labels_onnx(text, session, tokenizer, config, list(thresholds))
    repairs = []
    for label, threshold in thresholds.items():
        repairs.extend(
            _word_aligned_label_spans_from_scores(text, label, threshold, pieces, scores_by_label[label])
        )

    for candidate in repairs:
        survivors = []
        replaced = False
        for span in spans:
            supersedes = (
                overlaps(candidate, span)
                and span_length(candidate) > span_length(span)
                and (
                    candidate["label"] == span["label"]
                    or (candidate["label"] in FORMAT_LABELS and span["label"] in FORMAT_LABELS)
                )
            )
            if supersedes:
                replaced = True
            else:
                survivors.append(span)
        spans = survivors
        # NOTE(review): when replaced is True the candidate is added even if it still
        # overlaps a surviving span; dedupe_and_sort settles that conflict.
        if replaced or not any(overlaps(candidate, span) for span in spans):
            spans.append(candidate)
    return dedupe_and_sort(spans)