""" CPU-optimized handler for BERT-OJA-SkillLess (ONNX). Uses ONNX Runtime with all available CPU threads, small internal batches for cache efficiency, and numpy for fast softmax. """ from typing import Dict, List, Any import os import numpy as np from transformers import AutoTokenizer import onnxruntime as ort class EndpointHandler: def __init__(self, path=""): self.tokenizer = AutoTokenizer.from_pretrained(path) n_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)) opts = ort.SessionOptions() opts.intra_op_num_threads = n_threads opts.inter_op_num_threads = max(1, n_threads // 2) opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL onnx_path = os.path.join(path, "model.onnx") self.session = ort.InferenceSession( onnx_path, sess_options=opts, providers=["CPUExecutionProvider"], ) self.input_names = [i.name for i in self.session.get_inputs()] self.batch_size = 128 print(f"[handler] ONNX CPU ready — threads={n_threads}, batch={self.batch_size}") def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]: inputs = data.get("inputs", data.get("input", "")) if isinstance(inputs, str): inputs = [inputs] all_results = [] for i in range(0, len(inputs), self.batch_size): batch = inputs[i : i + self.batch_size] encoded = self.tokenizer( batch, padding=True, truncation=True, max_length=128, return_tensors="np", ) feed = {k: encoded[k] for k in self.input_names if k in encoded} logits = self.session.run(None, feed)[0] exp = np.exp(logits - logits.max(axis=-1, keepdims=True)) probs = exp / exp.sum(axis=-1, keepdims=True) for j in range(len(batch)): all_results.append([ {"label": "LABEL_0", "score": round(float(probs[j][0]), 6)}, {"label": "LABEL_1", "score": round(float(probs[j][1]), 6)}, ]) return all_results