"""
Custom handler for BERT-OJA-SkillLess on HF Inference Endpoints.
Uses ONNX Runtime with CUDA for 2-4x faster inference.
"""
from typing import Dict, List, Any
import numpy as np
from transformers import AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Number of texts tokenized and scored per forward pass.
        self.batch_size = 1024

        try:
            # Preferred path: export the checkpoint to ONNX and run it on
            # the CUDA execution provider.
            from optimum.onnxruntime import ORTModelForSequenceClassification
            self.model = ORTModelForSequenceClassification.from_pretrained(
                path, export=True, provider="CUDAExecutionProvider",
            )
            self._use_ort = True
            print(f"[handler] Loaded ONNX model on CUDA (batch_size={self.batch_size})")
        except Exception as e:
            print(f"[handler] ONNX failed ({e}), falling back to PyTorch FP16")
            import torch
            from transformers import AutoModelForSequenceClassification
            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.eval()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                # Half precision is only a win on GPU; keep FP32 on CPU.
                self.model = self.model.to(self.device).half()
            self._use_ort = False
            self._torch = torch

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        # Accept both {"inputs": ...} and {"input": ...}; a bare string
        # becomes a one-element batch.
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results = []
        # Chunk large payloads so a single request cannot exhaust GPU memory.
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i : i + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True, max_length=128,
                return_tensors="np" if self._use_ort else "pt",
            )

            if self._use_ort:
                logits = self.model(**encoded).logits
                # Depending on the optimum version, logits may come back as a
                # numpy array or a torch tensor (possibly on GPU); normalize
                # to numpy on the host before the softmax.
                if hasattr(logits, "cpu"):
                    logits = logits.cpu().numpy()
                # Numerically stable softmax.
                exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
                probs = exp / exp.sum(axis=-1, keepdims=True)
                for j in range(len(batch)):
                    all_results.append([
                        {"label": "LABEL_0", "score": round(float(probs[j][0]), 6)},
                        {"label": "LABEL_1", "score": round(float(probs[j][1]), 6)},
                    ])
            else:
                torch = self._torch
                encoded = {k: v.to(self.device) for k, v in encoded.items()}
                # Score without building autograd graphs.
                with torch.no_grad():
                    logits = self.model(**encoded).logits
                    probs = torch.softmax(logits, dim=-1)
                for j in range(len(batch)):
                    all_results.append([
                        {"label": "LABEL_0", "score": round(probs[j][0].item(), 6)},
                        {"label": "LABEL_1", "score": round(probs[j][1].item(), 6)},
                    ])

        return all_results
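

# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract; Endpoints instantiates EndpointHandler itself and calls it with
# the request payload). The "./model" directory is a hypothetical local
# checkpoint path; point it at a real checkpoint to run this.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint
    print(handler({"inputs": ["first example text", "second example text"]}))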