#!/usr/bin/env python3 from __future__ import annotations import json import tempfile from pathlib import Path from typing import Any import numpy as np from huggingface_hub import HfApi, hf_hub_download from transformers import AutoConfig, AutoTokenizer TOKENIZER_FILES = [ "tokenizer_config.json", "tokenizer.json", "special_tokens_map.json", "vocab.txt", "vocab.json", "merges.txt", "added_tokens.json", "sentencepiece.bpe.model", "spiece.model", ] DEFAULT_LABEL_MAX_SPAN_TOKENS = { # Token-piece limits, not word limits. These need to reflect how the # underlying tokenizer actually fragments compact identifiers. "PPSN": 9, "POSTCODE": 7, "PHONE_NUMBER": 10, "PASSPORT_NUMBER": 8, "BANK_ROUTING_NUMBER": 5, "ACCOUNT_NUMBER": 19, "CREDIT_DEBIT_CARD": 12, "SWIFT_BIC": 8, "EMAIL": 15, "FIRST_NAME": 5, "LAST_NAME": 8, } DEFAULT_LABEL_MIN_NONSPACE_CHARS = { "PPSN": 8, "POSTCODE": 6, "PHONE_NUMBER": 7, "PASSPORT_NUMBER": 7, "BANK_ROUTING_NUMBER": 6, "ACCOUNT_NUMBER": 6, "CREDIT_DEBIT_CARD": 12, "SWIFT_BIC": 8, "EMAIL": 6, "FIRST_NAME": 2, "LAST_NAME": 2, } WHITESPACE_BRIDGE_LABELS = { "PPSN", "POSTCODE", "PHONE_NUMBER", "PASSPORT_NUMBER", "BANK_ROUTING_NUMBER", "ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD", "SWIFT_BIC", } SIMPLE_PUNCT_BRIDGE_LABELS = { "PHONE_NUMBER", "BANK_ROUTING_NUMBER", "ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD", } MIN_CHAR_FALLBACK_LABELS = { "PHONE_NUMBER", "BANK_ROUTING_NUMBER", "ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD", "EMAIL", } CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS = { "PPSN", "POSTCODE", "PHONE_NUMBER", "PASSPORT_NUMBER", "BANK_ROUTING_NUMBER", "ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD", "SWIFT_BIC", "EMAIL", } OUTPUT_PRIORITY = { "PPSN": 0, "PASSPORT_NUMBER": 1, "ACCOUNT_NUMBER": 2, "BANK_ROUTING_NUMBER": 3, "CREDIT_DEBIT_CARD": 4, "PHONE_NUMBER": 5, "SWIFT_BIC": 6, "POSTCODE": 7, "EMAIL": 8, "FIRST_NAME": 9, "LAST_NAME": 10, } def normalize_entity_name(label: str) -> str: label = (label or "").strip() if label.startswith("B-") or label.startswith("I-"): label = label[2:] return label.upper() def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str: tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json" if not tokenizer_cfg_path.exists(): return str(tokenizer_path) data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8")) if "fix_mistral_regex" not in data: return str(tokenizer_path) tmpdir = Path(tempfile.mkdtemp(prefix="openmed_span_tokenizer_")) keep = set(TOKENIZER_FILES) for child in tokenizer_path.iterdir(): if child.is_file() and child.name in keep: (tmpdir / child.name).write_bytes(child.read_bytes()) data.pop("fix_mistral_regex", None) (tmpdir / "tokenizer_config.json").write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") return str(tmpdir) def safe_auto_tokenizer(tokenizer_ref: str): tokenizer_path = Path(tokenizer_ref) if tokenizer_path.exists(): tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path) else: api = HfApi() files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type="model")) tmpdir = Path(tempfile.mkdtemp(prefix="openmed_remote_span_tokenizer_")) copied = False for name in TOKENIZER_FILES: if name not in files: continue src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type="model") (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes()) copied = True if copied: tokenizer_ref = _sanitize_tokenizer_dir(tmpdir) try: return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True) except Exception: pass try: return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False) except TypeError: pass try: return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True) except Exception: return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False) def label_names_from_config(config) -> list[str]: names = list(getattr(config, "span_label_names", [])) if not names: raise ValueError("Missing span_label_names in config") return [normalize_entity_name(name) for name in names] def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]: raw = getattr(config, "span_label_thresholds", None) or {} out = {normalize_entity_name(key): float(value) for key, value in raw.items()} for label in label_names_from_config(config): out.setdefault(label, float(default_threshold)) return out def token_label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]: raw = getattr(config, "token_label_thresholds", None) or {} out = {normalize_entity_name(key): float(value) for key, value in raw.items()} for label in label_names_from_config(config): out.setdefault(label, float(default_threshold)) return out def token_extend_thresholds_from_config(config, default_fraction: float = 0.6) -> dict[str, float]: raw = getattr(config, "token_extend_thresholds", None) or {} out = {normalize_entity_name(key): float(value) for key, value in raw.items()} for label in label_names_from_config(config): out.setdefault(label, max(0.0, min(1.0, float(token_label_thresholds_from_config(config, 0.5).get(label, 0.5)) * default_fraction))) return out def boundary_label_thresholds_from_config(config, default_threshold: float = 0.0) -> dict[str, float]: raw = getattr(config, "boundary_label_thresholds", None) or {} out = {normalize_entity_name(key): float(value) for key, value in raw.items()} for label in label_names_from_config(config): out.setdefault(label, float(default_threshold)) return out def label_max_span_tokens_from_config(config) -> dict[str, int]: raw = getattr(config, "span_label_max_span_tokens", None) or {} out = {normalize_entity_name(key): int(value) for key, value in raw.items()} for label, value in DEFAULT_LABEL_MAX_SPAN_TOKENS.items(): out.setdefault(label, value) for label in label_names_from_config(config): out.setdefault(label, 8) return out def label_min_nonspace_chars_from_config(config) -> dict[str, int]: raw = getattr(config, "span_label_min_nonspace_chars", None) or {} out = {normalize_entity_name(key): int(value) for key, value in raw.items()} for label, value in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items(): out.setdefault(label, value) for label in label_names_from_config(config): out.setdefault(label, 1) return out def overlaps(a: dict, b: dict) -> bool: return not (a["end"] <= b["start"] or b["end"] <= a["start"]) def dedupe_spans(spans: list[dict]) -> list[dict]: ordered = sorted( spans, key=lambda item: (-float(item.get("score", 0.0)), item["start"], item["end"], OUTPUT_PRIORITY.get(item["label"], 99)), ) kept = [] for span in ordered: if any(overlaps(span, other) for other in kept): continue kept.append(span) kept.sort(key=lambda item: (item["start"], item["end"], OUTPUT_PRIORITY.get(item["label"], 99))) return kept def _valid_offset(offset: tuple[int, int]) -> bool: return bool(offset) and offset[1] > offset[0] def _has_skippable_bridge(text: str, left: tuple[int, int], right: tuple[int, int], label: str) -> bool: bridge = text[int(left[1]) : int(right[0])] if bridge == "": return True if label == "PPSN" and bridge.isspace(): next_token = _token_text(text, right).strip() return 0 < len(next_token) <= 2 and next_token.isalnum() if label in WHITESPACE_BRIDGE_LABELS and bridge.isspace(): return True if label in SIMPLE_PUNCT_BRIDGE_LABELS: normalized = bridge.replace("\u00A0", " ").replace("\u202F", " ").strip() if normalized == "-": return True return False def _has_left_extension_bridge(text: str, left: tuple[int, int], right: tuple[int, int]) -> bool: bridge = text[int(left[1]) : int(right[0])] return bridge == "" def _nonspace_length(text: str, start: int, end: int) -> int: return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)]) def _is_simple_punct_token(text: str, offset: tuple[int, int], label: str) -> bool: if label not in SIMPLE_PUNCT_BRIDGE_LABELS or not _valid_offset(offset): return False token_text = text[int(offset[0]) : int(offset[1])].replace("\u00A0", " ").replace("\u202F", " ").strip() return token_text == "-" def _token_text(text: str, offset: tuple[int, int]) -> str: return text[int(offset[0]) : int(offset[1])] def _is_short_alnum_token(text: str, offset: tuple[int, int], max_len: int = 4) -> bool: token_text = _token_text(text, offset).strip() return 0 < len(token_text) <= max_len and token_text.isalnum() def _rescue_structured_start( text: str, offsets: list[tuple[int, int]], valid: list[bool], token_scores: np.ndarray, start_scores: np.ndarray, label: str, label_index: int, threshold: float, boundary_threshold: float, start_idx: int, end_idx: int, ) -> int | None: if label not in {"ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD"}: return None segment_text = text[int(offsets[start_idx][0]) : int(offsets[end_idx][1])] if label == "ACCOUNT_NUMBER" and not any(ch.isspace() for ch in segment_text): return None best_idx = None best_score = -1.0 for cand_idx in range(start_idx, end_idx + 1): if not valid[cand_idx]: continue token_score = float(token_scores[cand_idx, label_index]) start_score = float(start_scores[cand_idx, label_index]) if token_score < threshold or start_score < boundary_threshold: continue token_text = _token_text(text, offsets[cand_idx]).strip() score = start_score + 0.2 * token_score if label == "ACCOUNT_NUMBER": next_text = _token_text(text, offsets[cand_idx + 1]).strip() if cand_idx + 1 <= end_idx and valid[cand_idx + 1] else "" if token_text.upper() == "I" and next_text.upper() == "E": score += 1.0 elif token_text.upper().startswith("IE"): score += 0.6 elif label == "CREDIT_DEBIT_CARD" and token_text.isdigit(): score += 0.3 if score > best_score: best_idx = cand_idx best_score = score return best_idx def _rescue_email_outer_span(span_text: str, outer_text: str) -> bool: if "@" not in span_text or " " in outer_text: return False if "@" not in outer_text: return False _, _, span_domain = span_text.partition("@") _, _, outer_domain = outer_text.partition("@") if "." in span_domain and not span_text.endswith("@"): return False return "." in outer_domain and not outer_text.endswith("@") def _rescue_iban_tail(text: str, offsets: list[tuple[int, int]], valid: list[bool], start_idx: int, end_idx: int) -> int: next_idx = end_idx + 1 span_text = text[int(offsets[start_idx][0]) : int(offsets[end_idx][1])] if not any(ch.isspace() for ch in span_text): return end_idx compact = "".join(ch for ch in span_text if not ch.isspace()) if not compact.upper().startswith("IE"): return end_idx while next_idx < len(offsets) and valid[next_idx]: if not _has_skippable_bridge(text, offsets[end_idx], offsets[next_idx], "ACCOUNT_NUMBER"): break if not _is_short_alnum_token(text, offsets[next_idx]): break end_idx = next_idx span_text = text[int(offsets[start_idx][0]) : int(offsets[end_idx][1])] compact = "".join(ch for ch in span_text if not ch.isspace()) if len(compact) >= 22: break next_idx += 1 return end_idx def decode_span_logits( text: str, offsets: list[tuple[int, int]], start_scores: np.ndarray, end_scores: np.ndarray, label_names: list[str], default_threshold: float, label_thresholds: dict[str, float] | None = None, label_max_span_tokens: dict[str, int] | None = None, ) -> list[dict]: thresholds = {label: float(default_threshold) for label in label_names} if label_thresholds: thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()}) max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS) if label_max_span_tokens: max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()}) spans: list[dict] = [] for label_index, label in enumerate(label_names): threshold = thresholds.get(label, float(default_threshold)) max_span = max_tokens.get(label, 8) start_candidates = [idx for idx in range(len(offsets)) if _valid_offset(offsets[idx]) and float(start_scores[idx, label_index]) >= threshold] for start_idx in start_candidates: best = None for end_idx in range(start_idx, min(len(offsets), start_idx + max_span)): if not _valid_offset(offsets[end_idx]): continue end_score = float(end_scores[end_idx, label_index]) if end_score < threshold: continue score = min(float(start_scores[start_idx, label_index]), end_score) if best is None or score > best["score"]: best = { "label": label, "start": int(offsets[start_idx][0]), "end": int(offsets[end_idx][1]), "score": score, } if best is not None and best["start"] < best["end"]: best["text"] = text[best["start"]:best["end"]] spans.append(best) return dedupe_spans(spans) def decode_token_presence_segments( text: str, offsets: list[tuple[int, int]], token_scores: np.ndarray, label_names: list[str], default_threshold: float, label_thresholds: dict[str, float] | None = None, label_extend_thresholds: dict[str, float] | None = None, label_max_span_tokens: dict[str, int] | None = None, label_min_nonspace_chars: dict[str, int] | None = None, boundary_label_thresholds: dict[str, float] | None = None, start_scores: np.ndarray | None = None, end_scores: np.ndarray | None = None, ) -> list[dict]: thresholds = {label: float(default_threshold) for label in label_names} if label_thresholds: thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()}) extend_thresholds = {label: max(0.0, min(1.0, thresholds[label] * 0.6)) for label in label_names} if label_extend_thresholds: extend_thresholds.update({normalize_entity_name(key): float(value) for key, value in label_extend_thresholds.items()}) max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS) if label_max_span_tokens: max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()}) min_nonspace_chars = dict(DEFAULT_LABEL_MIN_NONSPACE_CHARS) if label_min_nonspace_chars: min_nonspace_chars.update({normalize_entity_name(key): int(value) for key, value in label_min_nonspace_chars.items()}) boundary_thresholds = {label: 0.0 for label in label_names} if boundary_label_thresholds: boundary_thresholds.update({normalize_entity_name(key): float(value) for key, value in boundary_label_thresholds.items()}) spans: list[dict] = [] valid = [_valid_offset(offset) for offset in offsets] num_tokens = len(offsets) for label_index, label in enumerate(label_names): threshold = thresholds.get(label, float(default_threshold)) extend_threshold = min(threshold, extend_thresholds.get(label, threshold)) max_span = max_tokens.get(label, 8) idx = 0 while idx < num_tokens: if not valid[idx] or float(token_scores[idx, label_index]) < threshold: idx += 1 continue start_idx = idx end_idx = idx outer_start_idx = start_idx outer_end_idx = end_idx while end_idx + 1 < num_tokens and valid[end_idx + 1] and float(token_scores[end_idx + 1, label_index]) >= threshold and (end_idx + 1 - start_idx + 1) <= max_span: end_idx += 1 while ( start_idx - 1 >= 0 and valid[start_idx - 1] and _has_left_extension_bridge(text, offsets[start_idx - 1], offsets[start_idx]) and float(token_scores[start_idx - 1, label_index]) >= extend_threshold and (end_idx - (start_idx - 1) + 1) <= max_span ): start_idx -= 1 while end_idx + 1 < num_tokens: next_idx = end_idx + 1 if not valid[next_idx]: break if ( _has_skippable_bridge(text, offsets[end_idx], offsets[next_idx], label) and float(token_scores[next_idx, label_index]) >= extend_threshold and (next_idx - start_idx + 1) <= max_span ): end_idx = next_idx continue if ( _is_simple_punct_token(text, offsets[next_idx], label) and next_idx + 1 < num_tokens and valid[next_idx + 1] and _has_skippable_bridge(text, offsets[end_idx], offsets[next_idx], label) and _has_skippable_bridge(text, offsets[next_idx], offsets[next_idx + 1], label) and float(token_scores[next_idx + 1, label_index]) >= extend_threshold and ((next_idx + 1) - start_idx + 1) <= max_span ): end_idx = next_idx + 1 continue break outer_start_idx = start_idx outer_end_idx = end_idx presence_slice = token_scores[start_idx : end_idx + 1, label_index] score = float(presence_slice.mean()) out_start_idx = start_idx out_end_idx = end_idx if start_scores is not None and end_scores is not None: refine_window = min(3, end_idx - start_idx + 1) start_window = start_scores[start_idx : start_idx + refine_window, label_index] best_start_rel = int(np.argmax(start_window)) best_start_idx = start_idx + best_start_rel end_window_start = max(best_start_idx, end_idx - refine_window + 1) end_window = end_scores[end_window_start : end_idx + 1, label_index] best_end_rel = int(np.argmax(end_window)) best_end_idx = end_window_start + best_end_rel if ( float(start_scores[best_start_idx, label_index]) < boundary_thresholds.get(label, 0.0) or float(end_scores[best_end_idx, label_index]) < boundary_thresholds.get(label, 0.0) ): rescued_start_idx = _rescue_structured_start( text, offsets, valid, token_scores, start_scores, label, label_index, threshold, boundary_thresholds.get(label, 0.0), start_idx, end_idx, ) if rescued_start_idx is not None: out_start_idx = rescued_start_idx out_end_idx = end_idx else: idx = end_idx + 1 continue else: out_start_idx = best_start_idx out_end_idx = best_end_idx if label in CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS and ( best_start_idx != start_idx or best_end_idx != end_idx ): outer_boundary = min(float(start_scores[start_idx, label_index]), float(end_scores[end_idx, label_index])) refined_boundary = min( float(start_scores[best_start_idx, label_index]), float(end_scores[best_end_idx, label_index]), ) if refined_boundary < outer_boundary + 0.08: out_start_idx = start_idx out_end_idx = end_idx score = ( 0.65 * score + 0.175 * float(start_scores[out_start_idx, label_index]) + 0.175 * float(end_scores[out_end_idx, label_index]) ) min_chars = int(min_nonspace_chars.get(label, 1)) if _nonspace_length(text, offsets[out_start_idx][0], offsets[out_end_idx][1]) < min_chars: if ( label in MIN_CHAR_FALLBACK_LABELS and (out_start_idx != start_idx or out_end_idx != end_idx) and _nonspace_length(text, offsets[start_idx][0], offsets[end_idx][1]) >= min_chars ): out_start_idx = start_idx out_end_idx = end_idx else: idx = end_idx + 1 continue if label == "ACCOUNT_NUMBER": out_end_idx = _rescue_iban_tail(text, offsets, valid, out_start_idx, out_end_idx) span_text = text[int(offsets[out_start_idx][0]) : int(offsets[out_end_idx][1])] outer_text = text[int(offsets[outer_start_idx][0]) : int(offsets[outer_end_idx][1])] if label == "EMAIL" and _rescue_email_outer_span(span_text, outer_text): out_start_idx = outer_start_idx out_end_idx = outer_end_idx span_text = outer_text if label in {"FIRST_NAME", "LAST_NAME"} and any(ch.isdigit() for ch in span_text): idx = end_idx + 1 continue spans.append( { "label": label, "start": int(offsets[out_start_idx][0]), "end": int(offsets[out_end_idx][1]), "score": score, "text": span_text, } ) idx = end_idx + 1 return dedupe_spans(spans) def load_onnx_session(model_ref: str, onnx_file: str = "model_quantized.onnx", onnx_subfolder: str = "onnx"): import onnxruntime as ort model_path = Path(model_ref) if model_path.exists(): candidates = [] if onnx_subfolder: candidates.append(model_path / onnx_subfolder / onnx_file) candidates.append(model_path / onnx_file) onnx_path = next((path for path in candidates if path.exists()), candidates[0]) config = AutoConfig.from_pretrained(model_ref) tokenizer = safe_auto_tokenizer(model_ref) else: remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type="model")) config = AutoConfig.from_pretrained(model_ref) tokenizer = safe_auto_tokenizer(model_ref) session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) return session, tokenizer, config def run_onnx(session, encoded: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]: feed = {} input_names = {item.name for item in session.get_inputs()} for key, value in encoded.items(): if key == "offset_mapping": continue if key in input_names: feed[key] = value outputs = session.run(None, feed) return outputs[0], outputs[1] def run_onnx_all(session, encoded: dict[str, Any]) -> list[np.ndarray]: feed = {} input_names = {item.name for item in session.get_inputs()} for key, value in encoded.items(): if key == "offset_mapping": continue if key in input_names: feed[key] = value return session.run(None, feed)