| """ |
| CPU-optimized handler for BERT-OJA-SkillLess (ONNX). |
| Uses ONNX Runtime with all available CPU threads, small internal batches |
| for cache efficiency, and numpy for fast softmax. |
| """ |
| from typing import Dict, List, Any |
| import os |
| import numpy as np |
| from transformers import AutoTokenizer |
| import onnxruntime as ort |
|
|
|
|
class EndpointHandler:
    """Inference endpoint for a sequence-classification model exported to ONNX.

    Runs entirely on CPU via ONNX Runtime, tokenizing requests in small
    internal batches and applying a numerically stable numpy softmax to
    the model's logits.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and build a CPU-tuned ONNX Runtime session.

        Args:
            path: Directory containing the tokenizer files and ``model.onnx``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Respect an explicit OMP_NUM_THREADS override; otherwise use every
        # available core (fall back to 4 if the count is undetermined).
        n_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))

        opts = ort.SessionOptions()
        opts.intra_op_num_threads = n_threads
        # Sequential execution keeps inter-op threads mostly idle, so half
        # the intra-op count (minimum 1) is enough.
        opts.inter_op_num_threads = max(1, n_threads // 2)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

        onnx_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(
            onnx_path, sess_options=opts, providers=["CPUExecutionProvider"],
        )
        # Only tensors the graph actually declares as inputs may be fed.
        self.input_names = [i.name for i in self.session.get_inputs()]
        # Internal micro-batch size; small batches keep the working set
        # cache-friendly regardless of request size.
        self.batch_size = 128
        print(f"[handler] ONNX CPU ready — threads={n_threads}, batch={self.batch_size}")

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        """Classify one text or a list of texts.

        Args:
            data: Request payload; reads ``inputs`` (preferred) or ``input``,
                either a single string or a list of strings.

        Returns:
            One list per input text, each holding a ``{"label", "score"}``
            dict for every class, scores rounded to 6 decimal places.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results: List[List[Dict[str, float]]] = []
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i : i + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True,
                max_length=128, return_tensors="np",
            )
            # NOTE(review): assumes the tokenizer's numpy dtypes match the
            # graph's expected input dtypes — confirm for this export.
            feed = {k: encoded[k] for k in self.input_names if k in encoded}
            logits = self.session.run(None, feed)[0]

            # Numerically stable softmax: subtract the row max before exp.
            exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
            probs = exp / exp.sum(axis=-1, keepdims=True)

            # Generalized over the model's actual class count: the original
            # hard-coded LABEL_0/LABEL_1 and silently dropped any further
            # classes. Output is unchanged for binary models.
            n_labels = probs.shape[-1]
            for row in probs:
                all_results.append([
                    {"label": f"LABEL_{k}", "score": round(float(row[k]), 6)}
                    for k in range(n_labels)
                ])

        return all_results
|
|