"""
CPU-optimized handler for BERT-OJA-SkillLess (ONNX).
Uses ONNX Runtime with all available CPU threads, small internal batches
for cache efficiency, and numpy for fast softmax.
"""
from typing import Dict, List, Any
import os
import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Honor OMP_NUM_THREADS if set; otherwise use every available core.
        n_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = n_threads
        # inter-op threads only apply in ORT_PARALLEL mode; harmless here.
        opts.inter_op_num_threads = max(1, n_threads // 2)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        onnx_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(
            onnx_path, sess_options=opts, providers=["CPUExecutionProvider"],
        )
        self.input_names = [i.name for i in self.session.get_inputs()]
        # Requests are split into chunks of this size before inference.
        self.batch_size = 128
        print(f"[handler] ONNX CPU ready — threads={n_threads}, batch={self.batch_size}")

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]
        all_results = []
        # Process the request in fixed-size chunks to bound peak memory.
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i : i + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True,
                max_length=128, return_tensors="np",
            )
            # Feed only the tensors the ONNX graph actually declares
            # (e.g. skip token_type_ids if the export dropped that input).
            feed = {k: encoded[k] for k in self.input_names if k in encoded}
            logits = self.session.run(None, feed)[0]
            # Numerically stable softmax: subtract the row max before exp.
            exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
            probs = exp / exp.sum(axis=-1, keepdims=True)
            for j in range(len(batch)):
                all_results.append([
                    {"label": "LABEL_0", "score": round(float(probs[j][0]), 6)},
                    {"label": "LABEL_1", "score": round(float(probs[j][1]), 6)},
                ])
        return all_results
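

# Minimal local smoke test; a sketch, not part of the serving contract.
# Inference Endpoints instantiate EndpointHandler(path) once at startup and
# call it per request with a {"inputs": ...} payload. The local model
# directory "." below is an assumption: it must contain model.onnx plus the
# tokenizer files for this to run.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    out = handler({"inputs": ["first example text", "second example text"]})
    # Expect one [LABEL_0, LABEL_1] score pair per input, each pair of
    # probabilities summing to ~1.0.
    print(out)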