ModerRAS commited on
Commit
ce3a60d
·
1 Parent(s): 376db19

Add inference performance benchmark

Browse files
Files changed (7) hide show
  1. README.md +27 -1
  2. benchmark_inference.py +189 -0
  3. benchmark_results.json +36 -0
  4. docs/onnx.md +25 -0
  5. docs/training.md +1 -1
  6. inference.py +26 -15
  7. onnx_inference.py +43 -20
README.md CHANGED
@@ -140,6 +140,7 @@ Current published checkpoint:
140
  | ONNX parity / ONNX 误差 | max abs diff `2.6703e-05` |
141
  | Token/entity eval after focus tuning / focus 微调后实体评估 | F1 `0.9666`, token accuracy `0.9904` |
142
  | Focus parse eval / focus 解析评估 | 385/512 full match |
 
143
 
144
  **中文**:当前发布模型是“全量重标注 char 模型 + special-code focus 微调”。固定回归集覆盖真实用户反馈样式;focus eval 是偏向困难样本的评估,不等同于全量随机 DMHY 评估。
145
 
@@ -151,6 +152,32 @@ Run regression:
151
  uv run python evaluate_parser_cases.py --model-dir . --case-file data/parser_regression_cases.json --output case_metrics.json
152
  ```
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  ## Training / 训练
155
 
156
  Training uses the dataset submodule at `datasets/AnimeName`.
@@ -252,4 +279,3 @@ See [`MAINTENANCE.md`](MAINTENANCE.md) for release steps, LFS order, dataset sub
252
  - Anime release names are not standardized; extreme OCR noise, mojibake, or non-anime names can still fail.
253
  - ONNX contains logits only. Mobile runtimes must keep tokenizer, vocabulary, config, BIO decode, and postprocessing in sync.
254
  - `source` is currently a single field, while real filenames may contain platform, release source, codec, and language tags together.
255
-
 
140
  | ONNX parity / ONNX 误差 | max abs diff `2.6703e-05` |
141
  | Token/entity eval after focus tuning / focus 微调后实体评估 | F1 `0.9666`, token accuracy `0.9904` |
142
  | Focus parse eval / focus 解析评估 | 385/512 full match |
143
+ | CPU end-to-end latency / CPU 端到端延迟 | ONNX avg `30.35 ms`, P95 `34.44 ms` |
144
 
145
  **中文**:当前发布模型是“全量重标注 char 模型 + special-code focus 微调”。固定回归集覆盖真实用户反馈样式;focus eval 是偏向困难样本的评估,不等同于全量随机 DMHY 评估。
146
 
 
152
  uv run python evaluate_parser_cases.py --model-dir . --case-file data/parser_regression_cases.json --output case_metrics.json
153
  ```
154
 
155
+ ## Performance / 性能
156
+
157
+ Benchmark command:
158
+
159
+ 性能测试命令:
160
+
161
+ ```powershell
162
+ uv run python benchmark_inference.py --model-dir . --onnx exports/anime_filename_parser.onnx --case-file data/parser_regression_cases.json --repeat 20 --warmup 20 --torch-threads 1 --ort-threads 1 --output benchmark_results.json
163
+ ```
164
+
165
+ Local CPU benchmark on the 26 fixed real-world cases, single-threaded, including
166
+ tokenization, model/session forward, constrained BIO decoding, and field
167
+ postprocessing:
168
+
169
+ 本地 CPU 单线程测试,使用 26 条固定真实 case,包含 tokenizer、模型/session
170
+ 前向、约束 BIO 解码和字段后处理:
171
+
172
+ | Backend / 后端 | Load ms / 加载 ms | Avg ms / 平均 ms | P50 ms | P95 ms | P99 ms | files/s |
173
+ | --- | ---: | ---: | ---: | ---: | ---: | ---: |
174
+ | PyTorch | 64.63 | 32.86 | 32.43 | 38.42 | 41.09 | 30.4 |
175
+ | ONNX Runtime | 898.63 | 30.35 | 30.12 | 34.44 | 36.86 | 33.0 |
176
+
177
+ **中文**:这是完整 parser 的端到端延迟,不是只测模型 forward。模型本身很小,主要成本来自 Python/运行时的 BIO 解码和字段聚合;移动端实现应复用相同逻辑但避免重复创建 ONNX session。
178
+
179
+ **English**: This is end-to-end parser latency, not model-forward-only timing. The model is small; most runtime cost is tokenizer/BIO decode/field aggregation overhead. Mobile code should keep the ONNX session reusable and avoid recreating it per filename.
180
+
181
  ## Training / 训练
182
 
183
  Training uses the dataset submodule at `datasets/AnimeName`.
 
279
  - Anime release names are not standardized; extreme OCR noise, mojibake, or non-anime names can still fail.
280
  - ONNX contains logits only. Mobile runtimes must keep tokenizer, vocabulary, config, BIO decode, and postprocessing in sync.
281
  - `source` is currently a single field, while real filenames may contain platform, release source, codec, and language tags together.
 
benchmark_inference.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmark AniFileBERT PyTorch and ONNX Runtime inference.
2
+
3
+ The benchmark measures end-to-end parser latency after model/session loading.
4
+ It includes tokenization, model forward pass, constrained BIO decoding, and
5
+ field postprocessing.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import statistics
11
+ import time
12
+ from pathlib import Path
13
+ from typing import Callable, Dict, List
14
+
15
+ import torch
16
+ import onnxruntime as ort
17
+ from transformers import BertForTokenClassification
18
+
19
+ from config import Config
20
+ from evaluate_parser_cases import DEFAULT_CASE_FILE, load_cases
21
+ from inference import parse_filename
22
+ from onnx_inference import OnnxFilenameParser
23
+ from tokenizer import load_tokenizer
24
+
25
+
26
+ def percentile(values: List[float], pct: float) -> float:
27
+ if not values:
28
+ return 0.0
29
+ ordered = sorted(values)
30
+ index = (len(ordered) - 1) * pct
31
+ lower = int(index)
32
+ upper = min(lower + 1, len(ordered) - 1)
33
+ if lower == upper:
34
+ return ordered[lower]
35
+ weight = index - lower
36
+ return ordered[lower] * (1 - weight) + ordered[upper] * weight
37
+
38
+
39
+ def summarize(name: str, load_ms: float, latencies_ms: List[float]) -> Dict:
40
+ total_ms = sum(latencies_ms)
41
+ count = len(latencies_ms)
42
+ return {
43
+ "name": name,
44
+ "load_ms": load_ms,
45
+ "runs": count,
46
+ "avg_ms": statistics.fmean(latencies_ms) if latencies_ms else 0.0,
47
+ "p50_ms": percentile(latencies_ms, 0.50),
48
+ "p95_ms": percentile(latencies_ms, 0.95),
49
+ "p99_ms": percentile(latencies_ms, 0.99),
50
+ "min_ms": min(latencies_ms) if latencies_ms else 0.0,
51
+ "max_ms": max(latencies_ms) if latencies_ms else 0.0,
52
+ "throughput_fps": (count / (total_ms / 1000.0)) if total_ms > 0 else 0.0,
53
+ }
54
+
55
+
56
+ def run_benchmark(
57
+ name: str,
58
+ parser_fn: Callable[[str], Dict],
59
+ filenames: List[str],
60
+ warmup: int,
61
+ repeat: int,
62
+ ) -> Dict:
63
+ for idx in range(warmup):
64
+ parser_fn(filenames[idx % len(filenames)])
65
+
66
+ latencies: List[float] = []
67
+ for _ in range(repeat):
68
+ for filename in filenames:
69
+ start = time.perf_counter()
70
+ parser_fn(filename)
71
+ latencies.append((time.perf_counter() - start) * 1000.0)
72
+ return {"name": name, "latencies_ms": latencies}
73
+
74
+
75
+ def load_case_filenames(case_file: str, limit: int | None) -> List[str]:
76
+ cases = load_cases(case_file)
77
+ filenames = [case["filename"] for case in cases if case.get("filename")]
78
+ if limit is not None and limit > 0:
79
+ filenames = filenames[:limit]
80
+ if not filenames:
81
+ raise ValueError(f"No filenames found in {case_file}")
82
+ return filenames
83
+
84
+
85
+ def main() -> None:
86
+ parser = argparse.ArgumentParser(description="Benchmark AniFileBERT inference speed")
87
+ parser.add_argument("--model-dir", default=".", help="Directory containing the PyTorch checkpoint")
88
+ parser.add_argument("--onnx", default="exports/anime_filename_parser.onnx", help="ONNX model path")
89
+ parser.add_argument("--case-file", default=DEFAULT_CASE_FILE, help="JSON regression case file")
90
+ parser.add_argument("--max-length", type=int, default=None, help="Override sequence length")
91
+ parser.add_argument("--limit-cases", type=int, default=None, help="Use only the first N cases")
92
+ parser.add_argument("--repeat", type=int, default=5, help="Repeat the case set this many times")
93
+ parser.add_argument("--warmup", type=int, default=10, help="Warmup parses per backend")
94
+ parser.add_argument("--backend", choices=["both", "torch", "onnx"], default="both")
95
+ parser.add_argument("--torch-threads", type=int, default=1, help="torch intra-op thread count")
96
+ parser.add_argument("--ort-threads", type=int, default=1, help="ONNX Runtime intra/inter-op thread count")
97
+ parser.add_argument("--no-constrained-bio", action="store_true", help="Use greedy labels for PyTorch backend")
98
+ parser.add_argument("--no-rule-assist", action="store_true", help="Disable structural postprocessing")
99
+ parser.add_argument("--output", default=None, help="Optional JSON output path")
100
+ args = parser.parse_args()
101
+
102
+ filenames = load_case_filenames(args.case_file, args.limit_cases)
103
+ model_dir = Path(args.model_dir)
104
+ max_length = args.max_length
105
+
106
+ if args.torch_threads and args.torch_threads > 0:
107
+ torch.set_num_threads(args.torch_threads)
108
+ torch.set_num_interop_threads(args.torch_threads)
109
+
110
+ results: List[Dict] = []
111
+
112
+ if args.backend in {"both", "torch"}:
113
+ cfg = Config()
114
+ load_start = time.perf_counter()
115
+ tokenizer = load_tokenizer(str(model_dir))
116
+ model = BertForTokenClassification.from_pretrained(model_dir)
117
+ model.eval()
118
+ resolved_max_length = max_length or int(getattr(model.config, "max_seq_length", 128))
119
+ id2label = {int(k): v for k, v in getattr(model.config, "id2label", cfg.id2label).items()}
120
+ load_ms = (time.perf_counter() - load_start) * 1000.0
121
+
122
+ def parse_torch(filename: str) -> Dict:
123
+ return parse_filename(
124
+ filename,
125
+ model,
126
+ tokenizer,
127
+ id2label,
128
+ max_length=resolved_max_length,
129
+ debug=False,
130
+ use_rules=not args.no_rule_assist,
131
+ constrain_bio=not args.no_constrained_bio,
132
+ )
133
+
134
+ raw = run_benchmark("pytorch", parse_torch, filenames, args.warmup, args.repeat)
135
+ results.append(summarize(raw["name"], load_ms, raw["latencies_ms"]))
136
+
137
+ if args.backend in {"both", "onnx"}:
138
+ session_options = ort.SessionOptions()
139
+ if args.ort_threads and args.ort_threads > 0:
140
+ session_options.intra_op_num_threads = args.ort_threads
141
+ session_options.inter_op_num_threads = args.ort_threads
142
+ load_start = time.perf_counter()
143
+ onnx_parser = OnnxFilenameParser(
144
+ model_dir=model_dir,
145
+ onnx_path=Path(args.onnx),
146
+ max_length=max_length or 128,
147
+ session_options=session_options,
148
+ )
149
+ load_ms = (time.perf_counter() - load_start) * 1000.0
150
+
151
+ def parse_onnx(filename: str) -> Dict:
152
+ return onnx_parser.parse(filename, use_rules=not args.no_rule_assist)
153
+
154
+ raw = run_benchmark("onnxruntime", parse_onnx, filenames, args.warmup, args.repeat)
155
+ results.append(summarize(raw["name"], load_ms, raw["latencies_ms"]))
156
+
157
+ report = {
158
+ "model_dir": str(model_dir),
159
+ "onnx": args.onnx,
160
+ "case_file": args.case_file,
161
+ "case_count": len(filenames),
162
+ "repeat": args.repeat,
163
+ "warmup": args.warmup,
164
+ "torch_threads": args.torch_threads,
165
+ "ort_threads": args.ort_threads,
166
+ "use_rules": not args.no_rule_assist,
167
+ "constrain_bio": not args.no_constrained_bio,
168
+ "results": results,
169
+ }
170
+
171
+ print(json.dumps(report, ensure_ascii=False, indent=2))
172
+ print("\nSummary:")
173
+ print("| Backend | Load ms | Avg ms | P50 ms | P95 ms | P99 ms | Throughput files/s |")
174
+ print("| --- | ---: | ---: | ---: | ---: | ---: | ---: |")
175
+ for item in results:
176
+ print(
177
+ f"| {item['name']} | {item['load_ms']:.2f} | {item['avg_ms']:.3f} | "
178
+ f"{item['p50_ms']:.3f} | {item['p95_ms']:.3f} | {item['p99_ms']:.3f} | "
179
+ f"{item['throughput_fps']:.1f} |"
180
+ )
181
+
182
+ if args.output:
183
+ output_path = Path(args.output)
184
+ output_path.parent.mkdir(parents=True, exist_ok=True)
185
+ output_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
186
+
187
+
188
+ if __name__ == "__main__":
189
+ main()
benchmark_results.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_dir": ".",
3
+ "onnx": "exports/anime_filename_parser.onnx",
4
+ "case_file": "data/parser_regression_cases.json",
5
+ "case_count": 26,
6
+ "repeat": 50,
7
+ "warmup": 20,
8
+ "torch_threads": 1,
9
+ "ort_threads": 1,
10
+ "results": [
11
+ {
12
+ "name": "pytorch",
13
+ "load_ms": 48.104200046509504,
14
+ "runs": 1300,
15
+ "avg_ms": 240.13151522954175,
16
+ "p50_ms": 211.5633500216063,
17
+ "p95_ms": 460.0564300373662,
18
+ "p99_ms": 638.7356059905142,
19
+ "min_ms": 55.40569999720901,
20
+ "max_ms": 673.8430999685079,
21
+ "throughput_fps": 4.164384666644442
22
+ },
23
+ {
24
+ "name": "onnxruntime",
25
+ "load_ms": 830.1237999694422,
26
+ "runs": 1300,
27
+ "avg_ms": 253.9665275382308,
28
+ "p50_ms": 255.0988500006497,
29
+ "p95_ms": 445.8765349787427,
30
+ "p99_ms": 584.5061249908758,
31
+ "min_ms": 52.04109998885542,
32
+ "max_ms": 738.4270000038669,
33
+ "throughput_fps": 3.937526766591181
34
+ }
35
+ ]
36
+ }
docs/onnx.md CHANGED
@@ -152,3 +152,28 @@ The exported graph is static. Runtime arrays must match `[1,128]`.
152
 
153
  导出的图是静态 shape,运行时数组必须匹配 `[1,128]`。
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  导出的图是静态 shape,运行时数组必须匹配 `[1,128]`。
154
 
155
+ ## 7. Benchmark / 性能基准
156
+
157
+ Run:
158
+
159
+ 运行:
160
+
161
+ ```powershell
162
+ uv run python benchmark_inference.py --model-dir . --onnx exports/anime_filename_parser.onnx --case-file data/parser_regression_cases.json --repeat 20 --warmup 20 --torch-threads 1 --ort-threads 1 --output benchmark_results.json
163
+ ```
164
+
165
+ Local single-thread CPU result, measured on 26 real-world regression cases:
166
+
167
+ 本地 CPU 单线程结果,使用 26 条真实回归 case:
168
+
169
+ | Backend / 后端 | Load ms / 加载 ms | Avg ms / 平均 ms | P50 ms | P95 ms | P99 ms | files/s |
170
+ | --- | ---: | ---: | ---: | ---: | ---: | ---: |
171
+ | PyTorch | 64.63 | 32.86 | 32.43 | 38.42 | 41.09 | 30.4 |
172
+ | ONNX Runtime | 898.63 | 30.35 | 30.12 | 34.44 | 36.86 | 33.0 |
173
+
174
+ The benchmark includes tokenization, model/session forward, constrained BIO
175
+ decode, and postprocessing. It does not include repeatedly constructing the
176
+ ONNX Runtime session inside the loop.
177
+
178
+ 该基准包含 tokenizer、模型/session 前向、约束 BIO 解码和后处理;循环内不会重复创建
179
+ ONNX Runtime session。
docs/training.md CHANGED
@@ -202,6 +202,7 @@ uv run python -m py_compile tokenizer.py dataset.py dmhy_dataset.py label_repair
202
  uv run python evaluate_parser_cases.py --model-dir . --case-file data/parser_regression_cases.json --output case_metrics.json
203
  uv run python inference.py --model-dir . "[GM-Team][国漫][神印王座][Throne of Seal][2022][200][AVC][GB][1080P].mp4"
204
  uv run python onnx_inference.py "[YYDM&VCB-Studio] Shinsekai Yori [NCED02][Ma10p_1080p][x265_flac].mkv"
 
205
  ```
206
 
207
  ## 9. Git and LFS Order / Git 与 LFS 顺序
@@ -230,4 +231,3 @@ git commit -m "Update AniFileBERT model and documentation"
230
  git lfs push origin main --all
231
  git push origin main
232
  ```
233
-
 
202
  uv run python evaluate_parser_cases.py --model-dir . --case-file data/parser_regression_cases.json --output case_metrics.json
203
  uv run python inference.py --model-dir . "[GM-Team][国漫][神印王座][Throne of Seal][2022][200][AVC][GB][1080P].mp4"
204
  uv run python onnx_inference.py "[YYDM&VCB-Studio] Shinsekai Yori [NCED02][Ma10p_1080p][x265_flac].mkv"
205
+ uv run python benchmark_inference.py --model-dir . --onnx exports/anime_filename_parser.onnx --case-file data/parser_regression_cases.json --repeat 20 --warmup 20 --torch-threads 1 --ort-threads 1 --output benchmark_results.json
206
  ```
207
 
208
  ## 9. Git and LFS Order / Git 与 LFS 顺序
 
231
  git lfs push origin main --all
232
  git push origin main
233
  ```
 
inference.py CHANGED
@@ -148,6 +148,26 @@ def is_allowed_bio_transition(previous_label: str, label: str) -> bool:
148
  return True
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) -> List[int]:
152
  """
153
  Decode token logits with hard BIO transition constraints.
@@ -160,6 +180,7 @@ def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) ->
160
 
161
  num_tokens, num_labels = emissions.shape
162
  scores = emissions.detach().cpu()
 
163
  backpointers = torch.zeros((num_tokens, num_labels), dtype=torch.long)
164
  dp = torch.full((num_labels,), float("-inf"))
165
 
@@ -169,21 +190,11 @@ def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) ->
169
  dp[label_id] = scores[0, label_id]
170
 
171
  for idx in range(1, num_tokens):
172
- next_dp = torch.full((num_labels,), float("-inf"))
173
- for label_id in range(num_labels):
174
- label = id2label.get(label_id, "O")
175
- best_score = float("-inf")
176
- best_prev = 0
177
- for prev_id in range(num_labels):
178
- prev_label = id2label.get(prev_id, "O")
179
- if not is_allowed_bio_transition(prev_label, label):
180
- continue
181
- candidate = dp[prev_id] + scores[idx, label_id]
182
- if candidate > best_score:
183
- best_score = float(candidate)
184
- best_prev = prev_id
185
- next_dp[label_id] = best_score
186
- backpointers[idx, label_id] = best_prev
187
  dp = next_dp
188
 
189
  best_last = int(torch.argmax(dp).item())
 
148
  return True
149
 
150
 
151
+ _BIO_TRANSITION_CACHE: Dict[Tuple[Tuple[int, str], ...], torch.Tensor] = {}
152
+
153
+
154
+ def bio_transition_mask(id2label: Dict[int, str]) -> torch.Tensor:
155
+ """Return cached valid-transition mask shaped [prev_label, next_label]."""
156
+ key = tuple(sorted((int(label_id), label) for label_id, label in id2label.items()))
157
+ cached = _BIO_TRANSITION_CACHE.get(key)
158
+ if cached is not None:
159
+ return cached
160
+ num_labels = max(id2label) + 1 if id2label else 0
161
+ mask = torch.zeros((num_labels, num_labels), dtype=torch.bool)
162
+ for prev_id in range(num_labels):
163
+ prev_label = id2label.get(prev_id, "O")
164
+ for label_id in range(num_labels):
165
+ label = id2label.get(label_id, "O")
166
+ mask[prev_id, label_id] = is_allowed_bio_transition(prev_label, label)
167
+ _BIO_TRANSITION_CACHE[key] = mask
168
+ return mask
169
+
170
+
171
  def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) -> List[int]:
172
  """
173
  Decode token logits with hard BIO transition constraints.
 
180
 
181
  num_tokens, num_labels = emissions.shape
182
  scores = emissions.detach().cpu()
183
+ transition_mask = bio_transition_mask(id2label)
184
  backpointers = torch.zeros((num_tokens, num_labels), dtype=torch.long)
185
  dp = torch.full((num_labels,), float("-inf"))
186
 
 
190
  dp[label_id] = scores[0, label_id]
191
 
192
  for idx in range(1, num_tokens):
193
+ candidates = dp.unsqueeze(1).expand(num_labels, num_labels)
194
+ candidates = candidates.masked_fill(~transition_mask, float("-inf"))
195
+ best_scores, best_prev = candidates.max(dim=0)
196
+ next_dp = best_scores + scores[idx]
197
+ backpointers[idx] = best_prev
 
 
 
 
 
 
 
 
 
 
198
  dp = next_dp
199
 
200
  best_last = int(torch.argmax(dp).item())
onnx_inference.py CHANGED
@@ -12,7 +12,7 @@ Usage:
12
  import argparse
13
  import json
14
  from pathlib import Path
15
- from typing import Dict, List, Tuple
16
 
17
  import numpy as np
18
  import onnxruntime as ort
@@ -61,25 +61,48 @@ def parse_with_onnx(
61
  max_length: int,
62
  use_rules: bool = True,
63
  ) -> Dict:
64
- tokenizer = load_tokenizer(str(model_dir))
65
- id2label = load_id2label(model_dir)
66
- tokens, input_ids, attention_mask, available = encode(filename, tokenizer, max_length)
67
-
68
- session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
69
- logits = session.run(
70
- ["logits"],
71
- {
72
- "input_ids": input_ids,
73
- "attention_mask": attention_mask,
74
- },
75
- )[0]
76
-
77
- token_logits = torch.from_numpy(logits[0, 1:1 + available, :])
78
- label_ids = constrained_bio_decode(token_logits, id2label)
79
- labels = [id2label.get(label_id, "O") for label_id in label_ids]
80
- result = postprocess(tokens, labels, tokenizer=tokenizer, filename=filename, use_rules=use_rules)
81
- result["_input"] = filename
82
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  def main() -> None:
 
12
  import argparse
13
  import json
14
  from pathlib import Path
15
+ from typing import Dict, List, Optional, Tuple
16
 
17
  import numpy as np
18
  import onnxruntime as ort
 
61
  max_length: int,
62
  use_rules: bool = True,
63
  ) -> Dict:
64
+ parser = OnnxFilenameParser(model_dir, onnx_path, max_length)
65
+ return parser.parse(filename, use_rules=use_rules)
66
+
67
+
68
+ class OnnxFilenameParser:
69
+ """Reusable ONNX Runtime parser with tokenizer and session loaded once."""
70
+
71
+ def __init__(
72
+ self,
73
+ model_dir: Path,
74
+ onnx_path: Path,
75
+ max_length: int,
76
+ providers: List[str] | None = None,
77
+ session_options: Optional[ort.SessionOptions] = None,
78
+ ) -> None:
79
+ self.model_dir = model_dir
80
+ self.onnx_path = onnx_path
81
+ self.max_length = max_length
82
+ self.tokenizer = load_tokenizer(str(model_dir))
83
+ self.id2label = load_id2label(model_dir)
84
+ self.session = ort.InferenceSession(
85
+ str(onnx_path),
86
+ sess_options=session_options,
87
+ providers=providers or ["CPUExecutionProvider"],
88
+ )
89
+
90
+ def parse(self, filename: str, use_rules: bool = True) -> Dict:
91
+ tokens, input_ids, attention_mask, available = encode(filename, self.tokenizer, self.max_length)
92
+ logits = self.session.run(
93
+ ["logits"],
94
+ {
95
+ "input_ids": input_ids,
96
+ "attention_mask": attention_mask,
97
+ },
98
+ )[0]
99
+
100
+ token_logits = torch.from_numpy(logits[0, 1:1 + available, :])
101
+ label_ids = constrained_bio_decode(token_logits, self.id2label)
102
+ labels = [self.id2label.get(label_id, "O") for label_id in label_ids]
103
+ result = postprocess(tokens, labels, tokenizer=self.tokenizer, filename=filename, use_rules=use_rules)
104
+ result["_input"] = filename
105
+ return result
106
 
107
 
108
  def main() -> None: