temsa's picture
Add files using upload-large-folder tool
ed10267 verified
#!/usr/bin/env python3
from __future__ import annotations
import json
import tempfile
from pathlib import Path
from typing import Any
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig, AutoTokenizer
# Tokenizer-related file names. _sanitize_tokenizer_dir copies those present in
# a local directory into a temporary one, and safe_auto_tokenizer downloads
# those present in a Hub repo when the reference is not a local path.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]
# Per-label cap on how many tokens a decoded span may cover. The span decoders
# start from this mapping and apply caller overrides on top; labels missing
# from it fall back to 8 in those decoders.
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    # Token-piece limits, not word limits. These need to reflect how the
    # underlying tokenizer actually fragments compact identifiers.
    "PPSN": 9,
    "POSTCODE": 7,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 5,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 15,
    "FIRST_NAME": 5,
    "LAST_NAME": 8,
}
# Per-label minimum number of non-whitespace characters a token-presence
# segment must contain to be emitted; shorter segments are dropped. Labels
# missing from this mapping fall back to 1.
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
}
# Labels whose token-presence segments may be extended to the right across a
# whitespace-only gap between adjacent tokens (_has_skippable_bridge); for
# other labels only tokens with no gap at all are joined by that extension.
WHITESPACE_BRIDGE_LABELS = {
    "PPSN",
    "POSTCODE",
    "PHONE_NUMBER",
    "PASSPORT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "ACCOUNT_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
    "EMAIL",
}
# Labels for which start/end boundary refinement in
# decode_token_presence_segments is reverted to the outer segment boundaries
# unless the refined boundary score beats the outer one by at least 0.08.
CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS = {
    "PPSN",
    "POSTCODE",
    "PHONE_NUMBER",
    "PASSPORT_NUMBER",
    "BANK_ROUTING_NUMBER",
    "ACCOUNT_NUMBER",
    "CREDIT_DEBIT_CARD",
    "SWIFT_BIC",
    "EMAIL",
}
# Final tie-break component of the sort keys in dedupe_spans (lower sorts
# first); labels missing from this mapping rank as 99.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}
def normalize_entity_name(label: str) -> str:
    """Canonicalize an entity label: trim, drop a ``B-``/``I-`` tag, upper-case.

    ``None`` (or any falsy value) is treated as the empty string. The BIO prefix
    is only recognised in upper case (``"b-x"`` keeps its prefix).
    """
    cleaned = (label or "").strip()
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json"
if not tokenizer_cfg_path.exists():
return str(tokenizer_path)
data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8"))
if "fix_mistral_regex" not in data:
return str(tokenizer_path)
tmpdir = Path(tempfile.mkdtemp(prefix="openmed_span_tokenizer_"))
keep = set(TOKENIZER_FILES)
for child in tokenizer_path.iterdir():
if child.is_file() and child.name in keep:
(tmpdir / child.name).write_bytes(child.read_bytes())
data.pop("fix_mistral_regex", None)
(tmpdir / "tokenizer_config.json").write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
return str(tmpdir)
def safe_auto_tokenizer(tokenizer_ref: str):
    """Load a tokenizer for ``tokenizer_ref`` while tolerating ``fix_mistral_regex`` quirks.

    ``tokenizer_ref`` is either an existing local path or a Hugging Face Hub
    model id:

    * Existing local path: passed through ``_sanitize_tokenizer_dir``, which may
      substitute a temporary copy whose ``tokenizer_config.json`` lacks the
      ``fix_mistral_regex`` key.
    * Otherwise it is treated as a Hub repo id. The files from
      ``TOKENIZER_FILES`` that exist in the repo are downloaded into a temporary
      directory, which is then sanitized the same way. If none exist, the repo
      id itself is handed to ``AutoTokenizer``.

    Loading is then attempted in order, until one succeeds:

    1. fast tokenizer with ``fix_mistral_regex=True``;
    2. fast tokenizer with ``fix_mistral_regex=False``;
    3. fast tokenizer with no extra option;
    4. slow tokenizer (``use_fast=False``).

    Failures of attempts 1-3 are swallowed. An exception from attempt 4
    propagates to the caller.

    NOTE(review): temporary directories created here are never cleaned up.
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        repo_files = set(HfApi().list_repo_files(repo_id=tokenizer_ref, repo_type="model"))
        present = [name for name in TOKENIZER_FILES if name in repo_files]
        # Only create a staging directory when there is something to put in it,
        # so no empty temp dir is left behind.
        if present:
            staging = Path(tempfile.mkdtemp(prefix="openmed_remote_span_tokenizer_"))
            for name in present:
                src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type="model")
                (staging / Path(name).name).write_bytes(Path(src).read_bytes())
            tokenizer_ref = _sanitize_tokenizer_dir(staging)

    # Every attempt before the slow-tokenizer fallback is best-effort. Any
    # failure (not just TypeError) falls through to the next option, so the
    # use_fast=False fallback is always reachable.
    best_effort_attempts = (
        {"use_fast": True, "fix_mistral_regex": True},
        {"use_fast": True, "fix_mistral_regex": False},
        {"use_fast": True},
    )
    for kwargs in best_effort_attempts:
        try:
            return AutoTokenizer.from_pretrained(tokenizer_ref, **kwargs)
        except Exception:
            continue
    return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
def label_names_from_config(config) -> list[str]:
    """Return the normalized span label names declared on ``config``.

    Reads ``config.span_label_names`` and passes each entry through
    ``normalize_entity_name``.

    Raises:
        ValueError: if ``span_label_names`` is missing, ``None`` or empty.
    """
    raw = getattr(config, "span_label_names", None)
    # An explicit None must report the same ValueError as a missing attribute,
    # rather than failing with a TypeError inside list(None).
    names = [] if raw is None else list(raw)
    if not names:
        raise ValueError("Missing span_label_names in config")
    return [normalize_entity_name(name) for name in names]
def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Return per-label span decoding thresholds.

    Entries in ``config.span_label_thresholds`` (keys normalized, values cast to
    float) come first. Every label from ``span_label_names`` that has no entry
    there gets ``float(default_threshold)``.
    """
    configured = getattr(config, "span_label_thresholds", None) or {}
    result: dict[str, float] = {}
    for name, value in configured.items():
        result[normalize_entity_name(name)] = float(value)
    names = label_names_from_config(config)
    fallback = float(default_threshold)
    for name in names:
        if name not in result:
            result[name] = fallback
    return result
def token_label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Return per-label token-presence thresholds.

    Values from ``config.token_label_thresholds`` (normalized keys, float values)
    take precedence. Any label from ``span_label_names`` without a configured
    value falls back to ``float(default_threshold)``.
    """
    overrides = getattr(config, "token_label_thresholds", None) or {}
    thresholds = {normalize_entity_name(k): float(v) for k, v in overrides.items()}
    names = label_names_from_config(config)
    fallback = float(default_threshold)
    thresholds.update({name: fallback for name in names if name not in thresholds})
    return thresholds
def token_extend_thresholds_from_config(
    config,
    default_fraction: float = 0.6,
    default_token_threshold: float = 0.5,
) -> dict[str, float]:
    """Return per-label thresholds used to extend token-presence segments.

    Values from ``config.token_extend_thresholds`` (normalized keys, float
    values) take precedence. Any other label from ``span_label_names`` gets
    ``default_fraction`` times its token threshold, clamped to ``[0, 1]``. That
    token threshold comes from ``token_label_thresholds_from_config(config,
    default_token_threshold)``.

    Args:
        config: object exposing ``span_label_names`` and optionally
            ``token_extend_thresholds`` / ``token_label_thresholds``.
        default_fraction: multiplier applied to the token threshold for labels
            without an explicit extend threshold.
        default_token_threshold: token threshold assumed for labels that have no
            ``token_label_thresholds`` entry (0.5 preserves the historical
            behaviour).
    """
    raw = getattr(config, "token_extend_thresholds", None) or {}
    out = {normalize_entity_name(key): float(value) for key, value in raw.items()}
    # Loop-invariant: build the token-threshold mapping once, instead of once per label.
    token_thresholds = token_label_thresholds_from_config(config, default_token_threshold)
    for label in label_names_from_config(config):
        base = float(token_thresholds.get(label, default_token_threshold))
        out.setdefault(label, max(0.0, min(1.0, base * default_fraction)))
    return out
def boundary_label_thresholds_from_config(config, default_threshold: float = 0.0) -> dict[str, float]:
    """Return per-label minimum start/end boundary scores.

    ``config.boundary_label_thresholds`` entries are normalized and cast to
    float. Labels from ``span_label_names`` that are missing from it receive
    ``float(default_threshold)`` (0.0 unless overridden).
    """
    table = {
        normalize_entity_name(key): float(value)
        for key, value in (getattr(config, "boundary_label_thresholds", None) or {}).items()
    }
    missing = [name for name in label_names_from_config(config) if name not in table]
    if missing:
        fallback = float(default_threshold)
        for name in missing:
            table.setdefault(name, fallback)
    return table
def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Return the maximum span length, in tokens, for each label.

    Priority: ``config.span_label_max_span_tokens`` (normalized keys, int
    values), then ``DEFAULT_LABEL_MAX_SPAN_TOKENS``, then 8 for any remaining
    label listed in ``span_label_names``.
    """
    configured = getattr(config, "span_label_max_span_tokens", None) or {}
    limits = {normalize_entity_name(name): int(value) for name, value in configured.items()}
    limits.update({name: value for name, value in DEFAULT_LABEL_MAX_SPAN_TOKENS.items() if name not in limits})
    limits.update({name: 8 for name in label_names_from_config(config) if name not in limits})
    return limits
def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Return the minimum number of non-whitespace characters per label.

    Priority: ``config.span_label_min_nonspace_chars`` (normalized keys, int
    values), then ``DEFAULT_LABEL_MIN_NONSPACE_CHARS``, then 1 for any remaining
    label listed in ``span_label_names``.
    """
    configured = getattr(config, "span_label_min_nonspace_chars", None) or {}
    minimums = {normalize_entity_name(name): int(value) for name, value in configured.items()}
    minimums.update({name: value for name, value in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items() if name not in minimums})
    minimums.update({name: 1 for name in label_names_from_config(config) if name not in minimums})
    return minimums
def overlaps(a: dict, b: dict) -> bool:
    """Return True when the half-open ranges ``[start, end)`` of ``a`` and ``b`` intersect.

    Ranges that merely touch (one ends where the other starts) do not overlap.
    """
    return a["start"] < b["end"] and b["start"] < a["end"]
def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedily keep the best-scoring spans that do not overlap each other.

    Candidates are visited by descending ``score`` (missing score counts as 0.0).
    Ties are broken by start, end, then ``OUTPUT_PRIORITY`` (unknown labels rank
    99). A candidate is kept only if it overlaps none of the spans already kept.
    The survivors are returned ordered by (start, end, priority).
    """

    def selection_rank(span: dict):
        return (
            -float(span.get("score", 0.0)),
            span["start"],
            span["end"],
            OUTPUT_PRIORITY.get(span["label"], 99),
        )

    accepted: list[dict] = []
    for candidate in sorted(spans, key=selection_rank):
        if all(not overlaps(candidate, chosen) for chosen in accepted):
            accepted.append(candidate)
    return sorted(
        accepted,
        key=lambda span: (span["start"], span["end"], OUTPUT_PRIORITY.get(span["label"], 99)),
    )
def _valid_offset(offset: tuple[int, int]) -> bool:
return bool(offset) and offset[1] > offset[0]
def _has_skippable_bridge(text: str, left: tuple[int, int], right: tuple[int, int], label: str) -> bool:
bridge = text[int(left[1]) : int(right[0])]
if bridge == "":
return True
return label in WHITESPACE_BRIDGE_LABELS and bridge.isspace()
def _has_left_extension_bridge(text: str, left: tuple[int, int], right: tuple[int, int]) -> bool:
bridge = text[int(left[1]) : int(right[0])]
return bridge == ""
def _nonspace_length(text: str, start: int, end: int) -> int:
return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)])
def decode_span_logits(
    text: str,
    offsets: list[tuple[int, int]],
    start_scores: np.ndarray,
    end_scores: np.ndarray,
    label_names: list[str],
    default_threshold: float,
    label_thresholds: dict[str, float] | None = None,
    label_max_span_tokens: dict[str, int] | None = None,
) -> list[dict]:
    """Decode per-token start/end scores into labelled character spans.

    ``start_scores`` and ``end_scores`` are indexed ``[token, label]``: rows
    follow ``offsets`` and columns follow ``label_names``.

    For each label, every token with a valid offset whose start score reaches the
    label threshold opens a candidate. It is paired with the end token that
    maximizes ``min(start_score, end_score)``. Only end tokens with a valid
    offset, no earlier than the start, and at most ``max_span`` tokens from it
    (start included) are considered; end tokens scoring below the threshold are
    ignored, and on ties the earliest end wins.

    Thresholds default to ``default_threshold``, overridden per label by
    ``label_thresholds``. Span limits come from ``DEFAULT_LABEL_MAX_SPAN_TOKENS``
    overridden by ``label_max_span_tokens``, and fall back to 8. Keys of both
    override mappings are passed through ``normalize_entity_name``.

    Returns:
        ``dedupe_spans`` over dicts with ``label``, ``start``, ``end``,
        ``score`` and ``text``.
    """
    per_label_threshold = {name: float(default_threshold) for name in label_names}
    for key, value in (label_thresholds or {}).items():
        per_label_threshold[normalize_entity_name(key)] = float(value)
    span_limits = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
    for key, value in (label_max_span_tokens or {}).items():
        span_limits[normalize_entity_name(key)] = int(value)

    token_count = len(offsets)
    candidates: list[dict] = []
    for column, label in enumerate(label_names):
        cutoff = per_label_threshold.get(label, float(default_threshold))
        window = span_limits.get(label, 8)
        openings = [
            position
            for position, offset in enumerate(offsets)
            if _valid_offset(offset) and float(start_scores[position, column]) >= cutoff
        ]
        for first in openings:
            opening = float(start_scores[first, column])
            best = None
            for last in range(first, min(token_count, first + window)):
                if not _valid_offset(offsets[last]):
                    continue
                closing = float(end_scores[last, column])
                if closing < cutoff:
                    continue
                combined = min(opening, closing)
                if best is not None and combined <= best["score"]:
                    continue
                best = {
                    "label": label,
                    "start": int(offsets[first][0]),
                    "end": int(offsets[last][1]),
                    "score": combined,
                }
            if best is None or best["start"] >= best["end"]:
                continue
            best["text"] = text[best["start"] : best["end"]]
            candidates.append(best)
    return dedupe_spans(candidates)
def decode_token_presence_segments(
    text: str,
    offsets: list[tuple[int, int]],
    token_scores: np.ndarray,
    label_names: list[str],
    default_threshold: float,
    label_thresholds: dict[str, float] | None = None,
    label_extend_thresholds: dict[str, float] | None = None,
    label_max_span_tokens: dict[str, int] | None = None,
    label_min_nonspace_chars: dict[str, int] | None = None,
    boundary_label_thresholds: dict[str, float] | None = None,
    start_scores: np.ndarray | None = None,
    end_scores: np.ndarray | None = None,
) -> list[dict]:
    """Decode per-token presence scores into labelled character spans.

    ``token_scores``, and ``start_scores``/``end_scores`` when given, are indexed
    ``[token, label]``: rows follow ``offsets`` and columns follow ``label_names``.
    For each label, the input is scanned left to right:

    1. A token with a valid offset and presence score >= the label threshold
       seeds a segment. The segment grows right over consecutive valid tokens
       that also reach the threshold, up to ``max_span`` tokens.
    2. The segment is extended left over touching tokens (no characters in
       between). It is extended right over tokens separated by an empty gap, or
       by a whitespace-only gap for ``WHITESPACE_BRIDGE_LABELS``. Extended
       tokens only need the (lower) extend threshold, and the segment stays
       within ``max_span`` tokens.
    3. The segment score is the mean presence score over its tokens.
    4. If both ``start_scores`` and ``end_scores`` are supplied:
       - the start moves to the best start score among the first <= 3 tokens;
       - the end moves to the best end score among the last <= 3 tokens, never
         before the new start;
       - the segment is dropped if either chosen boundary score is below the
         label's boundary threshold;
       - for ``CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS`` the outer boundaries
         are kept unless the refined boundary score beats them by >= 0.08;
       - the score becomes ``0.65 * mean + 0.175 * start + 0.175 * end``,
         evaluated at the emitted boundaries.
    5. Segments with fewer non-whitespace characters than the label minimum are
       dropped.

    Per-label settings (override keys pass through ``normalize_entity_name``):

    - threshold: ``default_threshold``, overridden by ``label_thresholds``;
    - extend threshold: ``0.6 * threshold`` clamped to [0, 1], overridden by
      ``label_extend_thresholds``, and never above the threshold;
    - max span: ``DEFAULT_LABEL_MAX_SPAN_TOKENS`` overridden by
      ``label_max_span_tokens``, else 8;
    - minimum characters: ``DEFAULT_LABEL_MIN_NONSPACE_CHARS`` overridden by
      ``label_min_nonspace_chars``, else 1;
    - boundary threshold: 0.0, overridden by ``boundary_label_thresholds``.

    Returns:
        ``dedupe_spans`` over dicts with ``label``, ``start``, ``end``,
        ``score`` and ``text``.
    """
    # --- Resolve per-label settings (defaults first, then caller overrides). ---
    thresholds = {label: float(default_threshold) for label in label_names}
    if label_thresholds:
        thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()})
    extend_thresholds = {label: max(0.0, min(1.0, thresholds[label] * 0.6)) for label in label_names}
    if label_extend_thresholds:
        extend_thresholds.update({normalize_entity_name(key): float(value) for key, value in label_extend_thresholds.items()})
    max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
    if label_max_span_tokens:
        max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()})
    min_nonspace_chars = dict(DEFAULT_LABEL_MIN_NONSPACE_CHARS)
    if label_min_nonspace_chars:
        min_nonspace_chars.update({normalize_entity_name(key): int(value) for key, value in label_min_nonspace_chars.items()})
    boundary_thresholds = {label: 0.0 for label in label_names}
    if boundary_label_thresholds:
        boundary_thresholds.update({normalize_entity_name(key): float(value) for key, value in boundary_label_thresholds.items()})
    spans: list[dict] = []
    # Tokens with empty/zero-width offsets can never seed, join or extend a segment.
    valid = [_valid_offset(offset) for offset in offsets]
    num_tokens = len(offsets)
    for label_index, label in enumerate(label_names):
        threshold = thresholds.get(label, float(default_threshold))
        # The extension threshold is capped at the seed threshold.
        extend_threshold = min(threshold, extend_thresholds.get(label, threshold))
        max_span = max_tokens.get(label, 8)
        idx = 0
        while idx < num_tokens:
            if not valid[idx] or float(token_scores[idx, label_index]) < threshold:
                idx += 1
                continue
            start_idx = idx
            end_idx = idx
            # Core growth over consecutive above-threshold tokens.
            # NOTE(review): unlike the extension loops below, this does not look
            # at the characters between tokens; confirm that is intended.
            while end_idx + 1 < num_tokens and valid[end_idx + 1] and float(token_scores[end_idx + 1, label_index]) >= threshold and (end_idx + 1 - start_idx + 1) <= max_span:
                end_idx += 1
            # Left extension: only across tokens that touch (empty gap).
            # NOTE(review): this can reach back into a token already covered by
            # the previous segment of this label; overlaps are left to dedupe_spans.
            while (
                start_idx - 1 >= 0
                and valid[start_idx - 1]
                and _has_left_extension_bridge(text, offsets[start_idx - 1], offsets[start_idx])
                and float(token_scores[start_idx - 1, label_index]) >= extend_threshold
                and (end_idx - (start_idx - 1) + 1) <= max_span
            ):
                start_idx -= 1
            # Right extension: empty gaps, or whitespace gaps for WHITESPACE_BRIDGE_LABELS.
            while (
                end_idx + 1 < num_tokens
                and valid[end_idx + 1]
                and _has_skippable_bridge(text, offsets[end_idx], offsets[end_idx + 1], label)
                and float(token_scores[end_idx + 1, label_index]) >= extend_threshold
                and ((end_idx + 1) - start_idx + 1) <= max_span
            ):
                end_idx += 1
            presence_slice = token_scores[start_idx : end_idx + 1, label_index]
            score = float(presence_slice.mean())
            out_start_idx = start_idx
            out_end_idx = end_idx
            if start_scores is not None and end_scores is not None:
                # Boundary refinement: search at most 3 tokens in from each edge.
                refine_window = min(3, end_idx - start_idx + 1)
                start_window = start_scores[start_idx : start_idx + refine_window, label_index]
                best_start_rel = int(np.argmax(start_window))
                best_start_idx = start_idx + best_start_rel
                # The end window never begins before the chosen start, so it is non-empty.
                end_window_start = max(best_start_idx, end_idx - refine_window + 1)
                end_window = end_scores[end_window_start : end_idx + 1, label_index]
                best_end_rel = int(np.argmax(end_window))
                best_end_idx = end_window_start + best_end_rel
                if (
                    float(start_scores[best_start_idx, label_index]) < boundary_thresholds.get(label, 0.0)
                    or float(end_scores[best_end_idx, label_index]) < boundary_thresholds.get(label, 0.0)
                ):
                    # Weak boundaries: discard the whole segment and resume after it.
                    idx = end_idx + 1
                    continue
                out_start_idx = best_start_idx
                out_end_idx = best_end_idx
                if label in CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS and (
                    best_start_idx != start_idx or best_end_idx != end_idx
                ):
                    # Only accept a tighter span if its weaker boundary beats the
                    # outer span's weaker boundary by at least 0.08.
                    outer_boundary = min(float(start_scores[start_idx, label_index]), float(end_scores[end_idx, label_index]))
                    refined_boundary = min(
                        float(start_scores[best_start_idx, label_index]),
                        float(end_scores[best_end_idx, label_index]),
                    )
                    if refined_boundary < outer_boundary + 0.08:
                        out_start_idx = start_idx
                        out_end_idx = end_idx
                # Blend mean presence with the boundary scores at the emitted edges.
                # NOTE(review): nesting reconstructed from a copy whose indentation
                # was lost; this applies whenever start/end scores are supplied.
                score = (
                    0.65 * score
                    + 0.175 * float(start_scores[out_start_idx, label_index])
                    + 0.175 * float(end_scores[out_end_idx, label_index])
                )
            min_chars = int(min_nonspace_chars.get(label, 1))
            if _nonspace_length(text, offsets[out_start_idx][0], offsets[out_end_idx][1]) < min_chars:
                idx = end_idx + 1
                continue
            spans.append(
                {
                    "label": label,
                    "start": int(offsets[out_start_idx][0]),
                    "end": int(offsets[out_end_idx][1]),
                    "score": score,
                    "text": text[int(offsets[out_start_idx][0]) : int(offsets[out_end_idx][1])],
                }
            )
            # Resume after the unrefined segment end, so tokens trimmed by
            # refinement are not rescanned for this label.
            idx = end_idx + 1
    return dedupe_spans(spans)
def load_onnx_session(
    model_ref: str,
    onnx_file: str = "model_quantized.onnx",
    onnx_subfolder: str = "onnx",
    providers: list[str] | None = None,
):
    """Create an ONNX Runtime session plus the matching tokenizer and config.

    ``model_ref`` is either an existing local path or a Hugging Face Hub model id.

    * Local path: ``<model_ref>/<onnx_subfolder>/<onnx_file>`` is preferred
      (when ``onnx_subfolder`` is non-empty), then ``<model_ref>/<onnx_file>``.
    * Hub id: ``<onnx_subfolder>/<onnx_file>`` (or just ``<onnx_file>``) is
      downloaded with ``hf_hub_download``.

    The config is loaded with ``AutoConfig.from_pretrained`` and the tokenizer
    with ``safe_auto_tokenizer``, both from ``model_ref``.

    Args:
        providers: ONNX Runtime execution providers. Defaults to
            ``["CPUExecutionProvider"]`` when ``None`` or empty.

    Returns:
        ``(session, tokenizer, config)``.

    Raises:
        FileNotFoundError: if ``model_ref`` is a local path containing none of
            the candidate ONNX files.
    """
    import onnxruntime as ort

    model_path = Path(model_ref)
    if model_path.exists():
        candidates = []
        if onnx_subfolder:
            candidates.append(model_path / onnx_subfolder / onnx_file)
        candidates.append(model_path / onnx_file)
        onnx_path = next((path for path in candidates if path.exists()), None)
        if onnx_path is None:
            # Fail early with a clear message instead of handing a nonexistent
            # path to onnxruntime.
            searched = ", ".join(str(path) for path in candidates)
            raise FileNotFoundError(f"No ONNX model found for {model_ref!r}; looked for: {searched}")
    else:
        remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type="model"))
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)
    session_providers = list(providers) if providers else ["CPUExecutionProvider"]
    session = ort.InferenceSession(str(onnx_path), providers=session_providers)
    return session, tokenizer, config
def run_onnx(session, encoded: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
    """Run ``session`` on ``encoded`` and return its first two outputs.

    Only keys that are declared inputs of the session are fed; ``offset_mapping``
    is always left out. All outputs are requested (``output_names=None``).
    """
    accepted = {node.name for node in session.get_inputs()}
    inputs = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    results = session.run(None, inputs)
    first = results[0]
    second = results[1]
    return first, second
def run_onnx_all(session, encoded: dict[str, Any]) -> list[np.ndarray]:
    """Run ``session`` on ``encoded`` and return whatever ``session.run`` produces.

    Keys that are not declared inputs of the session are skipped, and
    ``offset_mapping`` is skipped unconditionally.
    """
    declared = {node.name for node in session.get_inputs()}
    usable = [name for name in encoded if name != "offset_mapping" and name in declared]
    feed = {name: encoded[name] for name in usable}
    return session.run(None, feed)