# Provenance: uploaded by mpalinski with huggingface_hub (commit 443d436, verified).
"""
CPU-optimized handler for BERT-OJA-SkillLess (ONNX).
Uses ONNX Runtime with all available CPU threads, small internal batches
for cache efficiency, and numpy for fast softmax.
"""
from typing import Dict, List, Any
import os
import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort
class EndpointHandler:
    """CPU ONNX inference handler: tokenizes input text and returns
    per-class softmax scores for each input."""

    def __init__(self, path: str = ""):
        """Load the tokenizer and ONNX session from *path*, tuned for the host CPU.

        Args:
            path: Directory containing the tokenizer files and ``model.onnx``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Respect an explicit OMP_NUM_THREADS; otherwise use all visible cores.
        n_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = n_threads
        # Sequential execution rarely benefits from many inter-op threads.
        opts.inter_op_num_threads = max(1, n_threads // 2)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        onnx_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(
            onnx_path, sess_options=opts, providers=["CPUExecutionProvider"],
        )
        self.input_names = [i.name for i in self.session.get_inputs()]
        # Small internal batch keeps activations cache-friendly on CPU.
        self.batch_size = 128
        print(f"[handler] ONNX CPU ready — threads={n_threads}, batch={self.batch_size}")

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        """Run classification over the request payload.

        Args:
            data: Request dict; text is read from ``data["inputs"]`` (or the
                fallback ``data["input"]`` key) as a string or list of strings.

        Returns:
            One list per input text, each holding a ``{"label", "score"}``
            dict for every class (``LABEL_0``, ``LABEL_1``, ...), with
            scores rounded to six decimals.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]
        all_results: List[List[Dict[str, float]]] = []
        for start in range(0, len(inputs), self.batch_size):
            batch = inputs[start : start + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True,
                max_length=128, return_tensors="np",
            )
            # Feed only the tensors the ONNX graph actually declares as inputs.
            feed = {k: encoded[k] for k in self.input_names if k in encoded}
            logits = self.session.run(None, feed)[0]
            # Numerically stable softmax over the class axis.
            shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
            probs = shifted / shifted.sum(axis=-1, keepdims=True)
            # Emit one entry per class instead of hard-coding exactly two —
            # identical output for binary heads, and correct (rather than
            # crashing/truncating) for any other number of classes.
            num_classes = probs.shape[-1]
            for row in probs:
                all_results.append([
                    {"label": f"LABEL_{c}", "score": round(float(row[c]), 6)}
                    for c in range(num_classes)
                ])
        return all_results