| | |
| | from __future__ import annotations |
| |
|
| | import json |
| | import tempfile |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import numpy as np |
| | from huggingface_hub import HfApi, hf_hub_download |
| | from transformers import AutoConfig, AutoTokenizer |
| |
|
# File names treated as tokenizer assets. `_sanitize_tokenizer_dir` copies only
# these names into its temporary directory, and `safe_auto_tokenizer` downloads
# only these names from a remote model repository.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]
# Per-label cap on how many tokens one decoded span may cover. The decoders and
# `label_max_span_tokens_from_config` fall back to 8 for labels not listed here.
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    "PPSN": 9,
    "POSTCODE": 7,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 5,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 15,
    "FIRST_NAME": 5,
    "LAST_NAME": 8,
}
# Per-label minimum number of non-whitespace characters a segment must contain
# to be emitted by `decode_token_presence_segments`. Unlisted labels fall back
# to 1.
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
}
# Labels for which `_has_skippable_bridge` accepts a whitespace-only gap between
# two tokens; every other label requires the tokens to be directly adjacent.
WHITESPACE_BRIDGE_LABELS = {
    "PPSN",
    "POSTCODE",
    "PHONE_NUMBER",
    "PASSPORT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "ACCOUNT_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
    "EMAIL",
}
# Labels for which `decode_token_presence_segments` keeps the unrefined segment
# boundaries unless the refined boundary score beats the outer one by a margin.
CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS = {
    "PPSN",
    "POSTCODE",
    "PHONE_NUMBER",
    "PASSPORT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "ACCOUNT_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
    "EMAIL",
}
# Tie-break rank used as the last sort key in `dedupe_spans` (lower sorts
# first). Labels not listed here are ranked 99.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
| |
|
| |
|
def normalize_entity_name(label: str) -> str:
    """Return the canonical upper-case entity name for a raw label.

    Surrounding whitespace is trimmed, one leading ``B-`` or ``I-`` tag prefix
    (matched case-sensitively) is removed, and the remainder is upper-cased.
    A falsy ``label`` (``None`` or ``""``) yields an empty string.
    """
    cleaned = label.strip() if label else ""
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
| |
|
| |
|
| | def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str: |
| | tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json" |
| | if not tokenizer_cfg_path.exists(): |
| | return str(tokenizer_path) |
| | data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8")) |
| | if "fix_mistral_regex" not in data: |
| | return str(tokenizer_path) |
| | tmpdir = Path(tempfile.mkdtemp(prefix="openmed_span_tokenizer_")) |
| | keep = set(TOKENIZER_FILES) |
| | for child in tokenizer_path.iterdir(): |
| | if child.is_file() and child.name in keep: |
| | (tmpdir / child.name).write_bytes(child.read_bytes()) |
| | data.pop("fix_mistral_regex", None) |
| | (tmpdir / "tokenizer_config.json").write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") |
| | return str(tmpdir) |
| |
|
| |
|
def safe_auto_tokenizer(tokenizer_ref: str):
    """Load a tokenizer for a local directory or a Hub model id, best effort.

    Local paths are passed through ``_sanitize_tokenizer_dir``. For anything
    else, the files named in ``TOKENIZER_FILES`` that exist in the remote repo
    are downloaded into a temporary directory, which is then sanitized the
    same way. If the repo contains none of those files, ``tokenizer_ref`` is
    left untouched and no temporary directory is created.

    Loading is attempted with progressively simpler arguments: fast tokenizer
    with ``fix_mistral_regex=True``, then ``fix_mistral_regex=False``, then a
    plain fast tokenizer, and finally a slow tokenizer.

    Args:
        tokenizer_ref: Filesystem path or Hub repository id.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for the first
        attempt that succeeds.

    Raises:
        Exception: Whatever the final slow-tokenizer attempt raises when every
            earlier attempt has failed as well.
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type="model"))
        available = [name for name in TOKENIZER_FILES if name in files]
        # Only create the temporary directory when there is something to put in
        # it, so a repo without tokenizer files does not leave an empty dir.
        if available:
            tmpdir = Path(tempfile.mkdtemp(prefix="openmed_remote_span_tokenizer_"))
            for name in available:
                src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type="model")
                (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)

    fallback_kwargs = (
        {"use_fast": True, "fix_mistral_regex": True},
        {"use_fast": True, "fix_mistral_regex": False},
        {"use_fast": True},
    )
    for kwargs in fallback_kwargs:
        try:
            return AutoTokenizer.from_pretrained(tokenizer_ref, **kwargs)
        except Exception:
            # Deliberate best effort: every failure moves on to the next,
            # simpler attempt. Catching only TypeError on one step would abort
            # the chain before the remaining fallbacks are tried.
            continue
    # Last resort; its exception (if any) is the one reported to the caller.
    return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
| |
|
| |
|
def label_names_from_config(config) -> list[str]:
    """Return the normalized span label names declared on ``config``.

    Args:
        config: Object exposing a ``span_label_names`` attribute.

    Returns:
        The label names, each passed through ``normalize_entity_name``, in
        their configured order.

    Raises:
        ValueError: If ``span_label_names`` is absent, ``None`` or empty.
    """
    raw_names = getattr(config, "span_label_names", None)
    # An explicit None (e.g. a null in a serialized config) is treated like a
    # missing attribute so callers get ValueError rather than a TypeError.
    names = [] if raw_names is None else list(raw_names)
    if not names:
        raise ValueError("Missing span_label_names in config")
    return [normalize_entity_name(name) for name in names]
| |
|
| |
|
def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Build the per-label span threshold table for ``config``.

    Values from ``config.span_label_thresholds`` (keys normalized with
    ``normalize_entity_name``) take precedence; every configured label name
    without an explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "span_label_thresholds", None)
    thresholds: dict[str, float] = {}
    if configured:
        for raw_label, raw_value in configured.items():
            thresholds[normalize_entity_name(raw_label)] = float(raw_value)

    names = label_names_from_config(config)
    fallback = float(default_threshold)
    for name in names:
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
| |
|
| |
|
def token_label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Build the per-label token-presence threshold table for ``config``.

    Values from ``config.token_label_thresholds`` (keys normalized with
    ``normalize_entity_name``) take precedence; every configured label name
    without an explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "token_label_thresholds", None)
    thresholds: dict[str, float] = {}
    if configured:
        for raw_label, raw_value in configured.items():
            thresholds[normalize_entity_name(raw_label)] = float(raw_value)

    names = label_names_from_config(config)
    fallback = float(default_threshold)
    for name in names:
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
| |
|
| |
|
def token_extend_thresholds_from_config(config, default_fraction: float = 0.6) -> dict[str, float]:
    """Build the per-label token extension threshold table for ``config``.

    Explicit values from ``config.token_extend_thresholds`` (keys normalized
    with ``normalize_entity_name``) take precedence. Every other configured
    label gets its token threshold (as computed by
    ``token_label_thresholds_from_config`` with a 0.5 default) multiplied by
    ``default_fraction`` and clamped to the range [0.0, 1.0].

    Args:
        config: Object exposing the label and threshold attributes.
        default_fraction: Multiplier applied to a label's token threshold when
            no explicit extension threshold is configured.

    Returns:
        Mapping from normalized label name to extension threshold.

    Raises:
        ValueError: Propagated from ``label_names_from_config`` when the
            config declares no label names.
    """
    raw = getattr(config, "token_extend_thresholds", None) or {}
    out = {normalize_entity_name(key): float(value) for key, value in raw.items()}
    labels = label_names_from_config(config)
    # The token threshold table does not depend on the label being processed,
    # so it is built once instead of once per label.
    token_thresholds = token_label_thresholds_from_config(config, 0.5)
    for label in labels:
        derived = float(token_thresholds.get(label, 0.5)) * default_fraction
        out.setdefault(label, max(0.0, min(1.0, derived)))
    return out
| |
|
| |
|
def boundary_label_thresholds_from_config(config, default_threshold: float = 0.0) -> dict[str, float]:
    """Build the per-label boundary threshold table for ``config``.

    Values from ``config.boundary_label_thresholds`` (keys normalized with
    ``normalize_entity_name``) take precedence; every configured label name
    without an explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "boundary_label_thresholds", None)
    thresholds: dict[str, float] = {}
    if configured:
        for raw_label, raw_value in configured.items():
            thresholds[normalize_entity_name(raw_label)] = float(raw_value)

    names = label_names_from_config(config)
    fallback = float(default_threshold)
    for name in names:
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
| |
|
| |
|
def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Build the per-label maximum span length (in tokens) for ``config``.

    Precedence, highest first: ``config.span_label_max_span_tokens`` (keys
    normalized with ``normalize_entity_name``), then
    ``DEFAULT_LABEL_MAX_SPAN_TOKENS``, then 8 for any remaining configured
    label name.
    """
    configured = getattr(config, "span_label_max_span_tokens", None)
    limits: dict[str, int] = {}
    if configured:
        for raw_label, raw_value in configured.items():
            limits[normalize_entity_name(raw_label)] = int(raw_value)

    for label, cap in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        if label not in limits:
            limits[label] = cap

    for label in label_names_from_config(config):
        if label not in limits:
            limits[label] = 8
    return limits
| |
|
| |
|
def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Build the per-label minimum non-whitespace character count for ``config``.

    Precedence, highest first: ``config.span_label_min_nonspace_chars`` (keys
    normalized with ``normalize_entity_name``), then
    ``DEFAULT_LABEL_MIN_NONSPACE_CHARS``, then 1 for any remaining configured
    label name.
    """
    configured = getattr(config, "span_label_min_nonspace_chars", None)
    minimums: dict[str, int] = {}
    if configured:
        for raw_label, raw_value in configured.items():
            minimums[normalize_entity_name(raw_label)] = int(raw_value)

    for label, floor in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        if label not in minimums:
            minimums[label] = floor

    for label in label_names_from_config(config):
        if label not in minimums:
            minimums[label] = 1
    return minimums
| |
|
| |
|
def overlaps(a: dict, b: dict) -> bool:
    """Report whether two spans share at least one character position.

    Spans are treated as half-open ``[start, end)`` ranges, so spans that
    merely touch (one ends where the other starts) do not overlap.
    """
    if a["end"] <= b["start"]:
        # ``a`` finishes before ``b`` begins.
        return False
    return not b["end"] <= a["start"]
| |
|
| |
|
def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Drop overlapping spans greedily, keeping the highest-scoring ones.

    Candidates are visited from highest to lowest score (ties broken by start,
    end, then ``OUTPUT_PRIORITY``, with 99 for unlisted labels). A candidate is
    kept only if it overlaps none of the spans kept so far. The survivors are
    returned ordered by start, end and label priority.
    """

    def label_rank(span: dict) -> int:
        return OUTPUT_PRIORITY.get(span["label"], 99)

    def by_confidence(span: dict):
        return (-float(span.get("score", 0.0)), span["start"], span["end"], label_rank(span))

    def by_position(span: dict):
        return (span["start"], span["end"], label_rank(span))

    survivors: list[dict] = []
    for candidate in sorted(spans, key=by_confidence):
        if not any(overlaps(candidate, winner) for winner in survivors):
            survivors.append(candidate)
    return sorted(survivors, key=by_position)
| |
|
| |
|
| | def _valid_offset(offset: tuple[int, int]) -> bool: |
| | return bool(offset) and offset[1] > offset[0] |
| |
|
| |
|
| | def _has_skippable_bridge(text: str, left: tuple[int, int], right: tuple[int, int], label: str) -> bool: |
| | bridge = text[int(left[1]) : int(right[0])] |
| | if bridge == "": |
| | return True |
| | return label in WHITESPACE_BRIDGE_LABELS and bridge.isspace() |
| |
|
| |
|
| | def _has_left_extension_bridge(text: str, left: tuple[int, int], right: tuple[int, int]) -> bool: |
| | bridge = text[int(left[1]) : int(right[0])] |
| | return bridge == "" |
| |
|
| |
|
| | def _nonspace_length(text: str, start: int, end: int) -> int: |
| | return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)]) |
| |
|
| |
|
def decode_span_logits(
    text: str,
    offsets: list[tuple[int, int]],
    start_scores: np.ndarray,
    end_scores: np.ndarray,
    label_names: list[str],
    default_threshold: float,
    label_thresholds: dict[str, float] | None = None,
    label_max_span_tokens: dict[str, int] | None = None,
) -> list[dict]:
    """Decode entity spans from per-token start and end scores.

    For every label, each valid token whose start score reaches the label's
    threshold is paired with the best end token within the label's maximum
    span length. The pair score is the smaller of the start and end scores,
    and the first end token achieving the highest pair score wins. Candidates
    from all labels are then passed through ``dedupe_spans``.

    ``start_scores`` and ``end_scores`` are indexed as
    ``[token_index, label_index]``.

    Args:
        text: Source text the offsets refer to.
        offsets: Character ``(start, end)`` offsets per token; entries failing
            ``_valid_offset`` are ignored.
        start_scores: Per-token, per-label start scores.
        end_scores: Per-token, per-label end scores.
        label_names: Label for each score column.
        default_threshold: Threshold for labels without an override.
        label_thresholds: Optional per-label threshold overrides.
        label_max_span_tokens: Optional per-label span length overrides,
            applied on top of ``DEFAULT_LABEL_MAX_SPAN_TOKENS`` (8 otherwise).

    Returns:
        Non-overlapping span dicts with ``label``, ``start``, ``end``,
        ``score`` and ``text`` keys.
    """
    cutoffs = {name: float(default_threshold) for name in label_names}
    if label_thresholds:
        for raw_label, raw_value in label_thresholds.items():
            cutoffs[normalize_entity_name(raw_label)] = float(raw_value)

    span_caps = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
    if label_max_span_tokens:
        for raw_label, raw_value in label_max_span_tokens.items():
            span_caps[normalize_entity_name(raw_label)] = int(raw_value)

    token_count = len(offsets)
    found: list[dict] = []
    for column, label in enumerate(label_names):
        cutoff = cutoffs.get(label, float(default_threshold))
        cap = span_caps.get(label, 8)
        for first, first_offset in enumerate(offsets):
            if not _valid_offset(first_offset):
                continue
            first_score = float(start_scores[first, column])
            # Written as a negated ``>=`` on purpose so a NaN score is skipped.
            if not first_score >= cutoff:
                continue

            winner_last = None
            winner_score = None
            for last in range(first, min(token_count, first + cap)):
                if not _valid_offset(offsets[last]):
                    continue
                last_score = float(end_scores[last, column])
                if last_score < cutoff:
                    continue
                pair_score = min(first_score, last_score)
                # Strict ``>`` keeps the earliest end token on ties.
                if winner_last is None or pair_score > winner_score:
                    winner_last, winner_score = last, pair_score

            if winner_last is None:
                continue
            char_start = int(first_offset[0])
            char_end = int(offsets[winner_last][1])
            if char_start < char_end:
                found.append(
                    {
                        "label": label,
                        "start": char_start,
                        "end": char_end,
                        "score": winner_score,
                        "text": text[char_start:char_end],
                    }
                )
    return dedupe_spans(found)
| |
|
| |
|
def decode_token_presence_segments(
    text: str,
    offsets: list[tuple[int, int]],
    token_scores: np.ndarray,
    label_names: list[str],
    default_threshold: float,
    label_thresholds: dict[str, float] | None = None,
    label_extend_thresholds: dict[str, float] | None = None,
    label_max_span_tokens: dict[str, int] | None = None,
    label_min_nonspace_chars: dict[str, int] | None = None,
    boundary_label_thresholds: dict[str, float] | None = None,
    start_scores: np.ndarray | None = None,
    end_scores: np.ndarray | None = None,
) -> list[dict]:
    """Decode entity spans from per-token presence scores.

    For each label the tokens are scanned left to right. A segment is seeded
    at a valid token whose presence score reaches the label threshold, grown
    over following tokens that also reach it, then extended to the left and
    right over neighbouring tokens that reach the (lower or equal) extension
    threshold. The segment never exceeds the label's maximum token count.
    When both ``start_scores`` and ``end_scores`` are supplied, the segment
    boundaries are refined with them and the segment score is blended with the
    boundary scores. Segments with too few non-whitespace characters are
    dropped, and the remaining candidates go through ``dedupe_spans``.

    ``token_scores``, ``start_scores`` and ``end_scores`` are indexed as
    ``[token_index, label_index]``.

    Args:
        text: Source text the offsets refer to.
        offsets: Character ``(start, end)`` offsets per token; entries failing
            ``_valid_offset`` never join a segment.
        token_scores: Per-token, per-label presence scores.
        label_names: Label for each score column.
        default_threshold: Presence threshold for labels without an override.
        label_thresholds: Optional per-label presence threshold overrides.
        label_extend_thresholds: Optional per-label extension thresholds;
            the default is 0.6 times the presence threshold, clamped to [0, 1].
        label_max_span_tokens: Optional overrides on top of
            ``DEFAULT_LABEL_MAX_SPAN_TOKENS`` (8 for unlisted labels).
        label_min_nonspace_chars: Optional overrides on top of
            ``DEFAULT_LABEL_MIN_NONSPACE_CHARS`` (1 for unlisted labels).
        boundary_label_thresholds: Optional per-label minimum boundary scores
            (default 0.0); only used when boundary scores are supplied.
        start_scores: Optional per-token, per-label start scores.
        end_scores: Optional per-token, per-label end scores.

    Returns:
        Non-overlapping span dicts with ``label``, ``start``, ``end``,
        ``score`` and ``text`` keys.
    """
    # Resolve the per-label settings; override keys are normalized so that
    # tagged or lower-case label names still match.
    thresholds = {label: float(default_threshold) for label in label_names}
    if label_thresholds:
        thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()})
    extend_thresholds = {label: max(0.0, min(1.0, thresholds[label] * 0.6)) for label in label_names}
    if label_extend_thresholds:
        extend_thresholds.update({normalize_entity_name(key): float(value) for key, value in label_extend_thresholds.items()})
    max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
    if label_max_span_tokens:
        max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()})
    min_nonspace_chars = dict(DEFAULT_LABEL_MIN_NONSPACE_CHARS)
    if label_min_nonspace_chars:
        min_nonspace_chars.update({normalize_entity_name(key): int(value) for key, value in label_min_nonspace_chars.items()})
    boundary_thresholds = {label: 0.0 for label in label_names}
    if boundary_label_thresholds:
        boundary_thresholds.update({normalize_entity_name(key): float(value) for key, value in boundary_label_thresholds.items()})

    spans: list[dict] = []
    # Token validity is label-independent, so it is computed once up front.
    valid = [_valid_offset(offset) for offset in offsets]
    num_tokens = len(offsets)
    for label_index, label in enumerate(label_names):
        threshold = thresholds.get(label, float(default_threshold))
        # The extension threshold is never allowed to exceed the seed threshold.
        extend_threshold = min(threshold, extend_thresholds.get(label, threshold))
        max_span = max_tokens.get(label, 8)
        idx = 0
        while idx < num_tokens:
            # Skip tokens that cannot seed a segment.
            if not valid[idx] or float(token_scores[idx, label_index]) < threshold:
                idx += 1
                continue
            start_idx = idx
            end_idx = idx
            # Grow right over consecutive valid tokens that reach the seed
            # threshold, as long as the segment stays within max_span tokens.
            while end_idx + 1 < num_tokens and valid[end_idx + 1] and float(token_scores[end_idx + 1, label_index]) >= threshold and (end_idx + 1 - start_idx + 1) <= max_span:
                end_idx += 1
            # Extend left, but only over tokens with no characters in between
            # and a presence score of at least the extension threshold.
            while (
                start_idx - 1 >= 0
                and valid[start_idx - 1]
                and _has_left_extension_bridge(text, offsets[start_idx - 1], offsets[start_idx])
                and float(token_scores[start_idx - 1, label_index]) >= extend_threshold
                and (end_idx - (start_idx - 1) + 1) <= max_span
            ):
                start_idx -= 1
            # Extend right; the gap may be empty or, for labels in
            # WHITESPACE_BRIDGE_LABELS, whitespace only.
            while (
                end_idx + 1 < num_tokens
                and valid[end_idx + 1]
                and _has_skippable_bridge(text, offsets[end_idx], offsets[end_idx + 1], label)
                and float(token_scores[end_idx + 1, label_index]) >= extend_threshold
                and ((end_idx + 1) - start_idx + 1) <= max_span
            ):
                end_idx += 1
            # Base score: mean presence score over the whole segment.
            presence_slice = token_scores[start_idx : end_idx + 1, label_index]
            score = float(presence_slice.mean())
            out_start_idx = start_idx
            out_end_idx = end_idx
            if start_scores is not None and end_scores is not None:
                # Look for a better start within the first (up to) 3 tokens and
                # a better end within the last (up to) 3 tokens of the segment;
                # the end window never begins before the chosen start.
                refine_window = min(3, end_idx - start_idx + 1)
                start_window = start_scores[start_idx : start_idx + refine_window, label_index]
                best_start_rel = int(np.argmax(start_window))
                best_start_idx = start_idx + best_start_rel
                end_window_start = max(best_start_idx, end_idx - refine_window + 1)
                end_window = end_scores[end_window_start : end_idx + 1, label_index]
                best_end_rel = int(np.argmax(end_window))
                best_end_idx = end_window_start + best_end_rel
                # Discard the segment when even the best boundary scores fall
                # below the label's boundary threshold.
                if (
                    float(start_scores[best_start_idx, label_index]) < boundary_thresholds.get(label, 0.0)
                    or float(end_scores[best_end_idx, label_index]) < boundary_thresholds.get(label, 0.0)
                ):
                    idx = end_idx + 1
                    continue
                out_start_idx = best_start_idx
                out_end_idx = best_end_idx
                # For conservative labels, fall back to the unrefined
                # boundaries unless the weaker refined boundary score beats the
                # weaker outer boundary score by at least 0.08.
                if label in CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS and (
                    best_start_idx != start_idx or best_end_idx != end_idx
                ):
                    outer_boundary = min(float(start_scores[start_idx, label_index]), float(end_scores[end_idx, label_index]))
                    refined_boundary = min(
                        float(start_scores[best_start_idx, label_index]),
                        float(end_scores[best_end_idx, label_index]),
                    )
                    if refined_boundary < outer_boundary + 0.08:
                        out_start_idx = start_idx
                        out_end_idx = end_idx
                # Blend the presence mean with the chosen boundary scores.
                score = (
                    0.65 * score
                    + 0.175 * float(start_scores[out_start_idx, label_index])
                    + 0.175 * float(end_scores[out_end_idx, label_index])
                )
            # Drop segments with too few non-whitespace characters.
            min_chars = int(min_nonspace_chars.get(label, 1))
            if _nonspace_length(text, offsets[out_start_idx][0], offsets[out_end_idx][1]) < min_chars:
                idx = end_idx + 1
                continue
            spans.append(
                {
                    "label": label,
                    "start": int(offsets[out_start_idx][0]),
                    "end": int(offsets[out_end_idx][1]),
                    "score": score,
                    "text": text[int(offsets[out_start_idx][0]) : int(offsets[out_end_idx][1])],
                }
            )
            # Resume scanning after the unrefined segment end, even when the
            # emitted span was narrowed by refinement.
            idx = end_idx + 1
    return dedupe_spans(spans)
| |
|
| |
|
def load_onnx_session(model_ref: str, onnx_file: str = "model_quantized.onnx", onnx_subfolder: str = "onnx"):
    """Create an ONNX Runtime session plus the matching tokenizer and config.

    If ``model_ref`` exists on disk, the model file is looked up first under
    ``onnx_subfolder`` (when given) and then directly in ``model_ref``; the
    first lookup location is used when neither file exists. Otherwise
    ``model_ref`` is treated as a Hub repository id and the model file is
    downloaded. The session is created with the CPU execution provider.

    Returns:
        A ``(session, tokenizer, config)`` tuple.
    """
    import onnxruntime as ort

    local_root = Path(model_ref)
    if local_root.exists():
        search_order = [local_root / onnx_subfolder / onnx_file] if onnx_subfolder else []
        search_order.append(local_root / onnx_file)
        # Default to the preferred location so a missing file is reported
        # against that path when the session is created.
        onnx_path = search_order[0]
        for candidate in search_order:
            if candidate.exists():
                onnx_path = candidate
                break
    else:
        if onnx_subfolder:
            remote_name = f"{onnx_subfolder}/{onnx_file}"
        else:
            remote_name = onnx_file
        downloaded = hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type="model")
        onnx_path = Path(downloaded)

    # Config and tokenizer are loaded the same way for local and remote refs.
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    return session, tokenizer, config
| |
|
| |
|
def run_onnx(session, encoded: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
    """Run ``session`` on ``encoded`` and return its first two outputs.

    Only entries whose key names an input of the session are fed to it, and
    ``offset_mapping`` is always left out.
    """
    accepted = {node.name for node in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    results = session.run(None, feed)
    return results[0], results[1]
| |
|
| |
|
def run_onnx_all(session, encoded: dict[str, Any]) -> list[np.ndarray]:
    """Run ``session`` on ``encoded`` and return every output it produces.

    Only entries whose key names an input of the session are fed to it, and
    ``offset_mapping`` is always left out.
    """
    accepted = {node.name for node in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    return session.run(None, feed)
| |
|