temsa committed on
Commit
b5b10c7
·
verified ·
1 Parent(s): bcff460

Fix standalone ContextPII helper imports

Browse files
Files changed (2) hide show
  1. base_common.py +628 -0
  2. common.py +9 -20
base_common.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import tempfile
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ from huggingface_hub import HfApi, hf_hub_download
12
+ from transformers import AutoConfig, AutoTokenizer
13
+
14
+ TOKENIZER_FILES = [
15
+ "tokenizer_config.json",
16
+ "tokenizer.json",
17
+ "special_tokens_map.json",
18
+ "vocab.txt",
19
+ "vocab.json",
20
+ "merges.txt",
21
+ "added_tokens.json",
22
+ "sentencepiece.bpe.model",
23
+ "spiece.model",
24
+ ]
25
+ DEFAULT_LABEL_MAX_SPAN_TOKENS = {
26
+ # Token-piece limits, not word limits. These need to reflect how the
27
+ # underlying tokenizer actually fragments compact identifiers.
28
+ "PPSN": 9,
29
+ "POSTCODE": 7,
30
+ "PHONE_NUMBER": 10,
31
+ "PASSPORT_NUMBER": 8,
32
+ "BANK_ROUTING_NUMBER": 5,
33
+ "ACCOUNT_NUMBER": 19,
34
+ "CREDIT_DEBIT_CARD": 12,
35
+ "SWIFT_BIC": 8,
36
+ "EMAIL": 15,
37
+ "FIRST_NAME": 5,
38
+ "LAST_NAME": 8,
39
+ }
40
+ DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
41
+ "PPSN": 8,
42
+ "POSTCODE": 6,
43
+ "PHONE_NUMBER": 7,
44
+ "PASSPORT_NUMBER": 7,
45
+ "BANK_ROUTING_NUMBER": 6,
46
+ "ACCOUNT_NUMBER": 6,
47
+ "CREDIT_DEBIT_CARD": 12,
48
+ "SWIFT_BIC": 8,
49
+ "EMAIL": 6,
50
+ "FIRST_NAME": 2,
51
+ "LAST_NAME": 2,
52
+ }
53
+ WHITESPACE_BRIDGE_LABELS = {
54
+ "PPSN",
55
+ "POSTCODE",
56
+ "PHONE_NUMBER",
57
+ "PASSPORT_NUMBER",
58
+ "BANK_ROUTING_NUMBER",
59
+ "ACCOUNT_NUMBER",
60
+ "CREDIT_DEBIT_CARD",
61
+ "SWIFT_BIC",
62
+ }
63
+ SIMPLE_PUNCT_BRIDGE_LABELS = {
64
+ "PHONE_NUMBER",
65
+ "BANK_ROUTING_NUMBER",
66
+ "ACCOUNT_NUMBER",
67
+ "CREDIT_DEBIT_CARD",
68
+ }
69
+ MIN_CHAR_FALLBACK_LABELS = {
70
+ "PHONE_NUMBER",
71
+ "BANK_ROUTING_NUMBER",
72
+ "ACCOUNT_NUMBER",
73
+ "CREDIT_DEBIT_CARD",
74
+ "EMAIL",
75
+ }
76
+ CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS = {
77
+ "PPSN",
78
+ "POSTCODE",
79
+ "PHONE_NUMBER",
80
+ "PASSPORT_NUMBER",
81
+ "BANK_ROUTING_NUMBER",
82
+ "ACCOUNT_NUMBER",
83
+ "CREDIT_DEBIT_CARD",
84
+ "SWIFT_BIC",
85
+ "EMAIL",
86
+ }
87
+ OUTPUT_PRIORITY = {
88
+ "PPSN": 0,
89
+ "PASSPORT_NUMBER": 1,
90
+ "ACCOUNT_NUMBER": 2,
91
+ "BANK_ROUTING_NUMBER": 3,
92
+ "CREDIT_DEBIT_CARD": 4,
93
+ "PHONE_NUMBER": 5,
94
+ "SWIFT_BIC": 6,
95
+ "POSTCODE": 7,
96
+ "EMAIL": 8,
97
+ "FIRST_NAME": 9,
98
+ "LAST_NAME": 10,
99
+ }
100
+
101
+
102
def normalize_entity_name(label: str) -> str:
    """Return the canonical upper-case entity name for a raw label.

    Surrounding whitespace is removed, a single leading ``B-`` or ``I-``
    tag prefix (matched case-sensitively) is dropped, and the remainder is
    upper-cased. A falsy input such as ``None`` or ``""`` yields ``""``.
    """
    cleaned = label.strip() if label else ""
    if cleaned.startswith(("B-", "I-")):
        cleaned = cleaned[2:]
    return cleaned.upper()
107
+
108
+
109
def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
    """Return a tokenizer directory whose config has no ``fix_mistral_regex`` key.

    When ``tokenizer_config.json`` is missing, or does not contain the key,
    the original directory path is returned unchanged. Otherwise the known
    tokenizer files are copied into a fresh temporary directory and a copy of
    the config without the key is written there; that directory's path is
    returned. The temporary directory is not removed by this function.
    """
    config_file = tokenizer_path / "tokenizer_config.json"
    if not config_file.exists():
        return str(tokenizer_path)

    config_data = json.loads(config_file.read_text(encoding="utf-8"))
    if "fix_mistral_regex" not in config_data:
        return str(tokenizer_path)

    scratch_dir = Path(tempfile.mkdtemp(prefix="openmed_span_tokenizer_"))
    wanted = set(TOKENIZER_FILES)
    for entry in tokenizer_path.iterdir():
        if not (entry.is_file() and entry.name in wanted):
            continue
        (scratch_dir / entry.name).write_bytes(entry.read_bytes())

    # The config copied above is overwritten with the cleaned version.
    config_data.pop("fix_mistral_regex", None)
    serialized = json.dumps(config_data, ensure_ascii=False, indent=2) + "\n"
    (scratch_dir / "tokenizer_config.json").write_text(serialized, encoding="utf-8")
    return str(scratch_dir)
124
+
125
+
126
def safe_auto_tokenizer(tokenizer_ref: str):
    """Load a tokenizer for a local directory or a Hub repo id, with fallbacks.

    If ``tokenizer_ref`` exists on disk it is passed through
    ``_sanitize_tokenizer_dir``. Otherwise it is treated as a Hub model repo:
    every file from ``TOKENIZER_FILES`` that the repo contains is downloaded
    and copied into a temporary directory, which is then sanitized the same
    way. If the repo contains none of those files, ``tokenizer_ref`` is left
    untouched and no temporary directory is created.

    Loading is then attempted in this order, returning the first success:

    1. fast tokenizer with ``fix_mistral_regex=True`` (any exception ignored),
    2. fast tokenizer with ``fix_mistral_regex=False`` (only ``TypeError``
       ignored; other exceptions propagate),
    3. fast tokenizer without the keyword,
    4. slow tokenizer (exceptions propagate).
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type="model"))
        # Created lazily so that a repo without tokenizer files does not leave
        # an empty orphan directory behind.
        tmpdir = None
        for name in TOKENIZER_FILES:
            if name not in files:
                continue
            src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type="model")
            if tmpdir is None:
                tmpdir = Path(tempfile.mkdtemp(prefix="openmed_remote_span_tokenizer_"))
            (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
        if tmpdir is not None:
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)

    # Deliberate best-effort chain: each attempt relaxes the previous one.
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
    except Exception:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
156
+
157
+
158
def label_names_from_config(config) -> list[str]:
    """Return the normalized span label names stored on ``config``.

    Reads ``config.span_label_names`` and normalizes every entry with
    ``normalize_entity_name``.

    Raises:
        ValueError: if the attribute is missing, ``None``, or empty.
    """
    # ``or []`` also covers an attribute that is present but set to None, so
    # the documented ValueError is raised instead of a TypeError from list().
    names = list(getattr(config, "span_label_names", None) or [])
    if not names:
        raise ValueError("Missing span_label_names in config")
    return [normalize_entity_name(name) for name in names]
163
+
164
+
165
def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Build the per-label span threshold table for ``config``.

    Entries from ``config.span_label_thresholds`` (keys normalized, values
    converted to ``float``) take precedence; every label name of the config
    that has no explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "span_label_thresholds", None) or {}
    thresholds = {normalize_entity_name(name): float(value) for name, value in configured.items()}
    for name in label_names_from_config(config):
        fallback = float(default_threshold)
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
171
+
172
+
173
def token_label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Build the per-label token-presence threshold table for ``config``.

    Entries from ``config.token_label_thresholds`` (keys normalized, values
    converted to ``float``) take precedence; every label name of the config
    that has no explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "token_label_thresholds", None) or {}
    thresholds = {normalize_entity_name(name): float(value) for name, value in configured.items()}
    for name in label_names_from_config(config):
        fallback = float(default_threshold)
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
179
+
180
+
181
def token_extend_thresholds_from_config(config, default_fraction: float = 0.6) -> dict[str, float]:
    """Build the per-label threshold used when extending a token segment.

    Explicit entries from ``config.token_extend_thresholds`` (keys
    normalized, values converted to ``float``) are used as-is. For every
    other label of the config the value is that label's token threshold
    (looked up with a default of 0.5) multiplied by ``default_fraction`` and
    clamped to the range ``[0.0, 1.0]``.
    """
    raw = getattr(config, "token_extend_thresholds", None) or {}
    out = {normalize_entity_name(key): float(value) for key, value in raw.items()}
    label_names = label_names_from_config(config)
    # Loop-invariant: build the token threshold table once, not per label.
    token_thresholds = token_label_thresholds_from_config(config, 0.5)
    for label in label_names:
        derived = float(token_thresholds.get(label, 0.5)) * default_fraction
        out.setdefault(label, max(0.0, min(1.0, derived)))
    return out
187
+
188
+
189
def boundary_label_thresholds_from_config(config, default_threshold: float = 0.0) -> dict[str, float]:
    """Build the per-label boundary threshold table for ``config``.

    Entries from ``config.boundary_label_thresholds`` (keys normalized,
    values converted to ``float``) take precedence; every label name of the
    config that has no explicit entry receives ``default_threshold``.
    """
    configured = getattr(config, "boundary_label_thresholds", None) or {}
    thresholds = {normalize_entity_name(name): float(value) for name, value in configured.items()}
    for name in label_names_from_config(config):
        fallback = float(default_threshold)
        if name not in thresholds:
            thresholds[name] = fallback
    return thresholds
195
+
196
+
197
def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Build the per-label maximum span length (in tokens) for ``config``.

    Precedence, highest first: ``config.span_label_max_span_tokens``, then
    ``DEFAULT_LABEL_MAX_SPAN_TOKENS``, then 8 for any remaining label name of
    the config. The result also contains every default label, whether or not
    the config lists it.
    """
    configured = getattr(config, "span_label_max_span_tokens", None) or {}
    limits = {normalize_entity_name(name): int(value) for name, value in configured.items()}
    for name, limit in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        if name not in limits:
            limits[name] = limit
    for name in label_names_from_config(config):
        if name not in limits:
            limits[name] = 8
    return limits
205
+
206
+
207
def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Build the per-label minimum count of non-whitespace characters.

    Precedence, highest first: ``config.span_label_min_nonspace_chars``, then
    ``DEFAULT_LABEL_MIN_NONSPACE_CHARS``, then 1 for any remaining label name
    of the config. The result also contains every default label, whether or
    not the config lists it.
    """
    configured = getattr(config, "span_label_min_nonspace_chars", None) or {}
    minimums = {normalize_entity_name(name): int(value) for name, value in configured.items()}
    for name, minimum in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        if name not in minimums:
            minimums[name] = minimum
    for name in label_names_from_config(config):
        if name not in minimums:
            minimums[name] = 1
    return minimums
215
+
216
+
217
def overlaps(a: dict, b: dict) -> bool:
    """Return True when the half-open ranges ``[start, end)`` of ``a`` and ``b`` intersect.

    Spans that merely touch (one ends where the other starts) do not overlap.
    """
    a_ends_after_b_starts = a["end"] > b["start"]
    return a_ends_after_b_starts and b["end"] > a["start"]
219
+
220
+
221
def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedily drop overlapping spans, keeping the best-ranked ones.

    Spans are ranked by descending ``score`` (missing score counts as 0.0),
    then by ``start``, ``end`` and ``OUTPUT_PRIORITY`` of the label (unknown
    labels rank 99). A span is kept only if it overlaps none of the spans
    already kept. The survivors are returned ordered by ``start``, ``end``
    and label priority.
    """

    def _label_rank(item: dict) -> int:
        return OUTPUT_PRIORITY.get(item["label"], 99)

    def _selection_key(item: dict):
        return (-float(item.get("score", 0.0)), item["start"], item["end"], _label_rank(item))

    def _output_key(item: dict):
        return (item["start"], item["end"], _label_rank(item))

    survivors: list[dict] = []
    for candidate in sorted(spans, key=_selection_key):
        if not any(overlaps(candidate, taken) for taken in survivors):
            survivors.append(candidate)
    return sorted(survivors, key=_output_key)
233
+
234
+
235
+ def _valid_offset(offset: tuple[int, int]) -> bool:
236
+ return bool(offset) and offset[1] > offset[0]
237
+
238
+
239
def _has_skippable_bridge(text: str, left: tuple[int, int], right: tuple[int, int], label: str) -> bool:
    """Decide whether the text between two tokens may be absorbed into a span.

    The gap is the slice of ``text`` from the end of ``left`` to the start of
    ``right``. It is skippable when:

    * it is empty; or
    * it is whitespace only and ``label`` is ``"PPSN"`` and the right-hand
      token (stripped) is 1-2 alphanumeric characters; or
    * it is whitespace only and ``label`` is in ``WHITESPACE_BRIDGE_LABELS``
      (other than the PPSN case above, which is decided on its own); or
    * ``label`` is in ``SIMPLE_PUNCT_BRIDGE_LABELS`` and the gap, with
      no-break spaces treated as spaces and then stripped, is exactly ``"-"``.
    """
    gap = text[int(left[1]) : int(right[0])]
    if gap == "":
        return True

    if gap.isspace():
        if label == "PPSN":
            # PPSN only bridges whitespace towards a short alphanumeric token.
            following = text[int(right[0]) : int(right[1])].strip()
            return 0 < len(following) <= 2 and following.isalnum()
        if label in WHITESPACE_BRIDGE_LABELS:
            return True

    if label not in SIMPLE_PUNCT_BRIDGE_LABELS:
        return False
    collapsed = gap.replace("\u00A0", " ").replace("\u202F", " ").strip()
    return collapsed == "-"
253
+
254
+
255
+ def _has_left_extension_bridge(text: str, left: tuple[int, int], right: tuple[int, int]) -> bool:
256
+ bridge = text[int(left[1]) : int(right[0])]
257
+ return bridge == ""
258
+
259
+
260
+ def _nonspace_length(text: str, start: int, end: int) -> int:
261
+ return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)])
262
+
263
+
264
def _is_simple_punct_token(text: str, offset: tuple[int, int], label: str) -> bool:
    """Return True when the token at ``offset`` is a lone hyphen for a punct-bridging label.

    Only labels in ``SIMPLE_PUNCT_BRIDGE_LABELS`` with a valid offset
    qualify. No-break spaces (U+00A0, U+202F) are treated as ordinary spaces
    before the token text is stripped and compared with ``"-"``.
    """
    if label not in SIMPLE_PUNCT_BRIDGE_LABELS:
        return False
    if not _valid_offset(offset):
        return False
    piece = text[int(offset[0]) : int(offset[1])]
    for odd_space in ("\u00A0", "\u202F"):
        piece = piece.replace(odd_space, " ")
    return piece.strip() == "-"
269
+
270
+
271
+ def _token_text(text: str, offset: tuple[int, int]) -> str:
272
+ return text[int(offset[0]) : int(offset[1])]
273
+
274
+
275
+ def _is_short_alnum_token(text: str, offset: tuple[int, int], max_len: int = 4) -> bool:
276
+ token_text = _token_text(text, offset).strip()
277
+ return 0 < len(token_text) <= max_len and token_text.isalnum()
278
+
279
+
280
+ def _rescue_structured_start(
281
+ text: str,
282
+ offsets: list[tuple[int, int]],
283
+ valid: list[bool],
284
+ token_scores: np.ndarray,
285
+ start_scores: np.ndarray,
286
+ label: str,
287
+ label_index: int,
288
+ threshold: float,
289
+ boundary_threshold: float,
290
+ start_idx: int,
291
+ end_idx: int,
292
+ ) -> int | None:
293
+ if label not in {"ACCOUNT_NUMBER", "CREDIT_DEBIT_CARD"}:
294
+ return None
295
+ segment_text = text[int(offsets[start_idx][0]) : int(offsets[end_idx][1])]
296
+ if label == "ACCOUNT_NUMBER" and not any(ch.isspace() for ch in segment_text):
297
+ return None
298
+ best_idx = None
299
+ best_score = -1.0
300
+ for cand_idx in range(start_idx, end_idx + 1):
301
+ if not valid[cand_idx]:
302
+ continue
303
+ token_score = float(token_scores[cand_idx, label_index])
304
+ start_score = float(start_scores[cand_idx, label_index])
305
+ if token_score < threshold or start_score < boundary_threshold:
306
+ continue
307
+ token_text = _token_text(text, offsets[cand_idx]).strip()
308
+ score = start_score + 0.2 * token_score
309
+ if label == "ACCOUNT_NUMBER":
310
+ next_text = _token_text(text, offsets[cand_idx + 1]).strip() if cand_idx + 1 <= end_idx and valid[cand_idx + 1] else ""
311
+ if token_text.upper() == "I" and next_text.upper() == "E":
312
+ score += 1.0
313
+ elif token_text.upper().startswith("IE"):
314
+ score += 0.6
315
+ elif label == "CREDIT_DEBIT_CARD" and token_text.isdigit():
316
+ score += 0.3
317
+ if score > best_score:
318
+ best_idx = cand_idx
319
+ best_score = score
320
+ return best_idx
321
+
322
+
323
+ def _rescue_email_outer_span(span_text: str, outer_text: str) -> bool:
324
+ if "@" not in span_text or " " in outer_text:
325
+ return False
326
+ if "@" not in outer_text:
327
+ return False
328
+ _, _, span_domain = span_text.partition("@")
329
+ _, _, outer_domain = outer_text.partition("@")
330
+ if "." in span_domain and not span_text.endswith("@"):
331
+ return False
332
+ return "." in outer_domain and not outer_text.endswith("@")
333
+
334
+
335
def _rescue_iban_tail(text: str, offsets: list[tuple[int, int]], valid: list[bool], start_idx: int, end_idx: int) -> int:
    """Extend a whitespace-grouped ``IE...`` account span over trailing short groups.

    Nothing is extended unless the current span text contains whitespace and,
    with whitespace removed, starts with ``IE`` (case-insensitive). Following
    tokens are then absorbed while each one is valid, is reachable through a
    bridge that ``_has_skippable_bridge`` accepts for ``ACCOUNT_NUMBER``, and
    passes ``_is_short_alnum_token``. Extension stops once the compacted span
    reaches 22 characters (presumably the Irish IBAN length -- confirm).

    Returns the possibly advanced end index.
    """
    span_begin = int(offsets[start_idx][0])

    def _span_until(last_idx: int) -> str:
        return text[span_begin : int(offsets[last_idx][1])]

    def _squeeze(chunk: str) -> str:
        return "".join(ch for ch in chunk if not ch.isspace())

    current = _span_until(end_idx)
    if not any(ch.isspace() for ch in current):
        return end_idx
    if not _squeeze(current).upper().startswith("IE"):
        return end_idx

    candidate = end_idx + 1
    while candidate < len(offsets) and valid[candidate]:
        bridged = _has_skippable_bridge(text, offsets[end_idx], offsets[candidate], "ACCOUNT_NUMBER")
        if not bridged or not _is_short_alnum_token(text, offsets[candidate]):
            break
        end_idx = candidate
        if len(_squeeze(_span_until(end_idx))) >= 22:
            break
        candidate += 1
    return end_idx
355
+
356
+
357
def decode_span_logits(
    text: str,
    offsets: list[tuple[int, int]],
    start_scores: np.ndarray,
    end_scores: np.ndarray,
    label_names: list[str],
    default_threshold: float,
    label_thresholds: dict[str, float] | None = None,
    label_max_span_tokens: dict[str, int] | None = None,
) -> list[dict]:
    """Decode spans from per-token start and end scores.

    For every label, each valid token whose start score reaches the label
    threshold is paired with the best end token among the next
    ``max_span`` tokens (itself included) whose end score also reaches the
    threshold. A pair's score is the minimum of its start and end scores;
    on ties the earliest end wins.

    Args:
        text: Source text the offsets refer to.
        offsets: Per-token ``(start, end)`` character offsets; invalid
            offsets (see ``_valid_offset``) are never used as boundaries.
        start_scores: Indexed as ``[token, label]``.
        end_scores: Indexed as ``[token, label]``.
        label_names: Label for each column of the score arrays.
        default_threshold: Threshold for labels without an override.
        label_thresholds: Optional per-label overrides (keys are normalized).
        label_max_span_tokens: Optional per-label overrides layered on top of
            ``DEFAULT_LABEL_MAX_SPAN_TOKENS``; labels found in neither use 8.

    Returns:
        Non-overlapping span dicts (``label``, ``start``, ``end``, ``score``,
        ``text``) as selected by ``dedupe_spans``.
    """
    thresholds = {label: float(default_threshold) for label in label_names}
    if label_thresholds:
        thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()})
    max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
    if label_max_span_tokens:
        max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()})

    num_tokens = len(offsets)
    # Offset validity does not depend on the label, so compute it once
    # instead of once per label and per candidate end.
    valid = [_valid_offset(offset) for offset in offsets]

    spans: list[dict] = []
    for label_index, label in enumerate(label_names):
        threshold = thresholds.get(label, float(default_threshold))
        max_span = max_tokens.get(label, 8)
        for start_idx in range(num_tokens):
            if not valid[start_idx]:
                continue
            start_score = float(start_scores[start_idx, label_index])
            # Written as "not >=" so a NaN score is rejected, as before.
            if not start_score >= threshold:
                continue
            best = None
            for end_idx in range(start_idx, min(num_tokens, start_idx + max_span)):
                if not valid[end_idx]:
                    continue
                end_score = float(end_scores[end_idx, label_index])
                if end_score < threshold:
                    continue
                score = min(start_score, end_score)
                if best is None or score > best["score"]:
                    best = {
                        "label": label,
                        "start": int(offsets[start_idx][0]),
                        "end": int(offsets[end_idx][1]),
                        "score": score,
                    }
            if best is not None and best["start"] < best["end"]:
                best["text"] = text[best["start"] : best["end"]]
                spans.append(best)
    return dedupe_spans(spans)
399
+
400
+
401
+ def decode_token_presence_segments(
402
+ text: str,
403
+ offsets: list[tuple[int, int]],
404
+ token_scores: np.ndarray,
405
+ label_names: list[str],
406
+ default_threshold: float,
407
+ label_thresholds: dict[str, float] | None = None,
408
+ label_extend_thresholds: dict[str, float] | None = None,
409
+ label_max_span_tokens: dict[str, int] | None = None,
410
+ label_min_nonspace_chars: dict[str, int] | None = None,
411
+ boundary_label_thresholds: dict[str, float] | None = None,
412
+ start_scores: np.ndarray | None = None,
413
+ end_scores: np.ndarray | None = None,
414
+ ) -> list[dict]:
415
+ thresholds = {label: float(default_threshold) for label in label_names}
416
+ if label_thresholds:
417
+ thresholds.update({normalize_entity_name(key): float(value) for key, value in label_thresholds.items()})
418
+ extend_thresholds = {label: max(0.0, min(1.0, thresholds[label] * 0.6)) for label in label_names}
419
+ if label_extend_thresholds:
420
+ extend_thresholds.update({normalize_entity_name(key): float(value) for key, value in label_extend_thresholds.items()})
421
+ max_tokens = dict(DEFAULT_LABEL_MAX_SPAN_TOKENS)
422
+ if label_max_span_tokens:
423
+ max_tokens.update({normalize_entity_name(key): int(value) for key, value in label_max_span_tokens.items()})
424
+ min_nonspace_chars = dict(DEFAULT_LABEL_MIN_NONSPACE_CHARS)
425
+ if label_min_nonspace_chars:
426
+ min_nonspace_chars.update({normalize_entity_name(key): int(value) for key, value in label_min_nonspace_chars.items()})
427
+ boundary_thresholds = {label: 0.0 for label in label_names}
428
+ if boundary_label_thresholds:
429
+ boundary_thresholds.update({normalize_entity_name(key): float(value) for key, value in boundary_label_thresholds.items()})
430
+
431
+ spans: list[dict] = []
432
+ valid = [_valid_offset(offset) for offset in offsets]
433
+ num_tokens = len(offsets)
434
+ for label_index, label in enumerate(label_names):
435
+ threshold = thresholds.get(label, float(default_threshold))
436
+ extend_threshold = min(threshold, extend_thresholds.get(label, threshold))
437
+ max_span = max_tokens.get(label, 8)
438
+ idx = 0
439
+ while idx < num_tokens:
440
+ if not valid[idx] or float(token_scores[idx, label_index]) < threshold:
441
+ idx += 1
442
+ continue
443
+ start_idx = idx
444
+ end_idx = idx
445
+ outer_start_idx = start_idx
446
+ outer_end_idx = end_idx
447
+ while end_idx + 1 < num_tokens and valid[end_idx + 1] and float(token_scores[end_idx + 1, label_index]) >= threshold and (end_idx + 1 - start_idx + 1) <= max_span:
448
+ end_idx += 1
449
+ while (
450
+ start_idx - 1 >= 0
451
+ and valid[start_idx - 1]
452
+ and _has_left_extension_bridge(text, offsets[start_idx - 1], offsets[start_idx])
453
+ and float(token_scores[start_idx - 1, label_index]) >= extend_threshold
454
+ and (end_idx - (start_idx - 1) + 1) <= max_span
455
+ ):
456
+ start_idx -= 1
457
+ while end_idx + 1 < num_tokens:
458
+ next_idx = end_idx + 1
459
+ if not valid[next_idx]:
460
+ break
461
+ if (
462
+ _has_skippable_bridge(text, offsets[end_idx], offsets[next_idx], label)
463
+ and float(token_scores[next_idx, label_index]) >= extend_threshold
464
+ and (next_idx - start_idx + 1) <= max_span
465
+ ):
466
+ end_idx = next_idx
467
+ continue
468
+ if (
469
+ _is_simple_punct_token(text, offsets[next_idx], label)
470
+ and next_idx + 1 < num_tokens
471
+ and valid[next_idx + 1]
472
+ and _has_skippable_bridge(text, offsets[end_idx], offsets[next_idx], label)
473
+ and _has_skippable_bridge(text, offsets[next_idx], offsets[next_idx + 1], label)
474
+ and float(token_scores[next_idx + 1, label_index]) >= extend_threshold
475
+ and ((next_idx + 1) - start_idx + 1) <= max_span
476
+ ):
477
+ end_idx = next_idx + 1
478
+ continue
479
+ break
480
+ outer_start_idx = start_idx
481
+ outer_end_idx = end_idx
482
+ presence_slice = token_scores[start_idx : end_idx + 1, label_index]
483
+ score = float(presence_slice.mean())
484
+ out_start_idx = start_idx
485
+ out_end_idx = end_idx
486
+ if start_scores is not None and end_scores is not None:
487
+ refine_window = min(3, end_idx - start_idx + 1)
488
+ start_window = start_scores[start_idx : start_idx + refine_window, label_index]
489
+ best_start_rel = int(np.argmax(start_window))
490
+ best_start_idx = start_idx + best_start_rel
491
+ end_window_start = max(best_start_idx, end_idx - refine_window + 1)
492
+ end_window = end_scores[end_window_start : end_idx + 1, label_index]
493
+ best_end_rel = int(np.argmax(end_window))
494
+ best_end_idx = end_window_start + best_end_rel
495
+ if (
496
+ float(start_scores[best_start_idx, label_index]) < boundary_thresholds.get(label, 0.0)
497
+ or float(end_scores[best_end_idx, label_index]) < boundary_thresholds.get(label, 0.0)
498
+ ):
499
+ rescued_start_idx = _rescue_structured_start(
500
+ text,
501
+ offsets,
502
+ valid,
503
+ token_scores,
504
+ start_scores,
505
+ label,
506
+ label_index,
507
+ threshold,
508
+ boundary_thresholds.get(label, 0.0),
509
+ start_idx,
510
+ end_idx,
511
+ )
512
+ if rescued_start_idx is not None:
513
+ out_start_idx = rescued_start_idx
514
+ out_end_idx = end_idx
515
+ else:
516
+ idx = end_idx + 1
517
+ continue
518
+ else:
519
+ out_start_idx = best_start_idx
520
+ out_end_idx = best_end_idx
521
+ if label in CONSERVATIVE_BOUNDARY_REFINEMENT_LABELS and (
522
+ best_start_idx != start_idx or best_end_idx != end_idx
523
+ ):
524
+ outer_boundary = min(float(start_scores[start_idx, label_index]), float(end_scores[end_idx, label_index]))
525
+ refined_boundary = min(
526
+ float(start_scores[best_start_idx, label_index]),
527
+ float(end_scores[best_end_idx, label_index]),
528
+ )
529
+ if refined_boundary < outer_boundary + 0.08:
530
+ out_start_idx = start_idx
531
+ out_end_idx = end_idx
532
+ score = (
533
+ 0.65 * score
534
+ + 0.175 * float(start_scores[out_start_idx, label_index])
535
+ + 0.175 * float(end_scores[out_end_idx, label_index])
536
+ )
537
+ min_chars = int(min_nonspace_chars.get(label, 1))
538
+ if _nonspace_length(text, offsets[out_start_idx][0], offsets[out_end_idx][1]) < min_chars:
539
+ if (
540
+ label in MIN_CHAR_FALLBACK_LABELS
541
+ and (out_start_idx != start_idx or out_end_idx != end_idx)
542
+ and _nonspace_length(text, offsets[start_idx][0], offsets[end_idx][1]) >= min_chars
543
+ ):
544
+ out_start_idx = start_idx
545
+ out_end_idx = end_idx
546
+ else:
547
+ idx = end_idx + 1
548
+ continue
549
+ if label == "ACCOUNT_NUMBER":
550
+ out_end_idx = _rescue_iban_tail(text, offsets, valid, out_start_idx, out_end_idx)
551
+ span_text = text[int(offsets[out_start_idx][0]) : int(offsets[out_end_idx][1])]
552
+ outer_text = text[int(offsets[outer_start_idx][0]) : int(offsets[outer_end_idx][1])]
553
+ if label == "EMAIL" and _rescue_email_outer_span(span_text, outer_text):
554
+ out_start_idx = outer_start_idx
555
+ out_end_idx = outer_end_idx
556
+ span_text = outer_text
557
+ if label in {"FIRST_NAME", "LAST_NAME"} and any(ch.isdigit() for ch in span_text):
558
+ idx = end_idx + 1
559
+ continue
560
+ spans.append(
561
+ {
562
+ "label": label,
563
+ "start": int(offsets[out_start_idx][0]),
564
+ "end": int(offsets[out_end_idx][1]),
565
+ "score": score,
566
+ "text": span_text,
567
+ }
568
+ )
569
+ idx = end_idx + 1
570
+ return dedupe_spans(spans)
571
+
572
+
573
def load_onnx_session(model_ref: str, onnx_file: str = "model_quantized.onnx", onnx_subfolder: str = "onnx"):
    """Create a CPU ONNX Runtime session plus tokenizer and config for a model.

    ``model_ref`` is either an existing local directory or a Hub model repo
    id. Locally, the ONNX file is looked up in ``onnx_subfolder`` first (when
    given) and then in the model directory itself; if neither exists the
    first candidate path is still used. Remotely, the file is fetched with
    ``hf_hub_download``.

    The intra-op thread count comes from ``OPENMED_ORT_INTRA_OP_THREADS``
    (at least 1; 4 when the value is not an integer), or, when the variable
    is unset or blank, from ``min(4, cpu_count)`` with at least 1.

    Returns:
        ``(session, tokenizer, config)``.
    """
    import onnxruntime as ort

    local_root = Path(model_ref)
    if local_root.exists():
        search_order = [local_root / onnx_subfolder / onnx_file] if onnx_subfolder else []
        search_order.append(local_root / onnx_file)
        onnx_path = next((path for path in search_order if path.exists()), search_order[0])
    else:
        remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type="model"))
    # Both branches load config and tokenizer from the same reference.
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)

    requested_threads = os.environ.get("OPENMED_ORT_INTRA_OP_THREADS", "").strip()
    if not requested_threads:
        intra_threads = max(1, min(4, os.cpu_count() or 4))
    else:
        try:
            intra_threads = max(1, int(requested_threads))
        except ValueError:
            intra_threads = 4

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.intra_op_num_threads = intra_threads
    sess_options.inter_op_num_threads = 1
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    session = ort.InferenceSession(str(onnx_path), sess_options=sess_options, providers=["CPUExecutionProvider"])
    return session, tokenizer, config
606
+
607
+
608
def run_onnx(session, encoded: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
    """Run ``session`` on the encoded inputs and return its first two outputs.

    Only entries of ``encoded`` whose key is one of the session's input names
    are fed; ``offset_mapping`` is always left out.
    """
    accepted = {item.name for item in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    results = session.run(None, feed)
    return results[0], results[1]
618
+
619
+
620
def run_onnx_all(session, encoded: dict[str, Any]) -> list[np.ndarray]:
    """Run ``session`` on the encoded inputs and return every output.

    Only entries of ``encoded`` whose key is one of the session's input names
    are fed; ``offset_mapping`` is always left out.
    """
    accepted = {item.name for item in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    return session.run(None, feed)
common.py CHANGED
@@ -14,26 +14,15 @@ ROOT_DIR = Path(__file__).resolve().parents[2]
14
  if str(ROOT_DIR) not in sys.path:
15
  sys.path.insert(0, str(ROOT_DIR))
16
 
17
- try:
18
- from ..irish_core_span_raw_only.common import (
19
- dedupe_spans,
20
- label_max_span_tokens_from_config,
21
- label_min_nonspace_chars_from_config,
22
- label_names_from_config,
23
- load_onnx_session,
24
- normalize_entity_name,
25
- safe_auto_tokenizer,
26
- )
27
- except ImportError:
28
- from experiments.irish_core_span_raw_only.common import (
29
- dedupe_spans,
30
- label_max_span_tokens_from_config,
31
- label_min_nonspace_chars_from_config,
32
- label_names_from_config,
33
- load_onnx_session,
34
- normalize_entity_name,
35
- safe_auto_tokenizer,
36
- )
37
 
38
 
39
  def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
 
14
  if str(ROOT_DIR) not in sys.path:
15
  sys.path.insert(0, str(ROOT_DIR))
16
 
17
+ from base_common import (
18
+ dedupe_spans,
19
+ label_max_span_tokens_from_config,
20
+ label_min_nonspace_chars_from_config,
21
+ label_names_from_config,
22
+ load_onnx_session,
23
+ normalize_entity_name,
24
+ safe_auto_tokenizer,
25
+ )
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]: