File size: 2,224 Bytes
443d436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
CPU-optimized handler for BERT-OJA-SkillLess (ONNX).
Uses ONNX Runtime with all available CPU threads, small internal batches
for cache efficiency, and numpy for fast softmax.
"""
from typing import Dict, List, Any
import os
import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort


class EndpointHandler:
    """Inference-endpoint handler for an ONNX sequence-classification model on CPU.

    Loads a tokenizer and an ONNX Runtime session from *path*, then serves
    ``__call__`` requests by tokenizing the text(s), running the model in
    small internal batches, and returning per-label softmax scores.
    """

    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Honor OMP_NUM_THREADS when set; otherwise use every visible core
        # (fall back to 4 if os.cpu_count() returns None).
        n_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))

        opts = ort.SessionOptions()
        opts.intra_op_num_threads = n_threads
        # A single sequential graph gains little from inter-op parallelism;
        # keep it modest but never below 1.
        opts.inter_op_num_threads = max(1, n_threads // 2)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

        onnx_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(
            onnx_path, sess_options=opts, providers=["CPUExecutionProvider"],
        )
        # Feed only inputs the graph declares — tokenizers may emit extra
        # keys (e.g. token_type_ids) that the exported model does not take.
        self.input_names = [i.name for i in self.session.get_inputs()]
        # Internal micro-batch size, chosen for CPU cache efficiency.
        self.batch_size = 128
        print(f"[handler] ONNX CPU ready — threads={n_threads}, batch={self.batch_size}")

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        """Classify one text or a list of texts.

        Parameters
        ----------
        data:
            Dict with key ``"inputs"`` (or legacy ``"input"``) holding a
            single string or a list of strings.

        Returns
        -------
        One list per input text, each containing a ``{"label", "score"}``
        dict for every model output class (``LABEL_0``, ``LABEL_1``, ...),
        with scores rounded to 6 decimals. An empty input list yields ``[]``.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results: List[List[Dict[str, float]]] = []
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i : i + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True,
                max_length=128, return_tensors="np",
            )
            feed = {k: encoded[k] for k in self.input_names if k in encoded}
            logits = self.session.run(None, feed)[0]

            # Numerically stable softmax: subtract the row max before exp.
            exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
            probs = exp / exp.sum(axis=-1, keepdims=True)

            # Generalized over the number of output classes rather than
            # hard-coding two labels (identical output for 2-class models).
            n_labels = probs.shape[-1]
            for row in probs:
                all_results.append([
                    {"label": f"LABEL_{k}", "score": round(float(row[k]), 6)}
                    for k in range(n_labels)
                ])

        return all_results