"""
Custom handler for BERT-OJA-SkillLess on HF Inference Endpoints.
Uses ONNX Runtime with CUDA for 2-4x faster inference.
"""
from typing import Dict, List, Any
import numpy as np
from transformers import AutoTokenizer
class EndpointHandler:
    """Sequence-classification handler for HF Inference Endpoints.

    Loads the tokenizer and model found at ``path``. ONNX Runtime with the
    CUDA execution provider is tried first; if loading it fails for any
    reason, the handler falls back to a PyTorch model (FP16 when CUDA is
    available, otherwise CPU).
    """

    def __init__(self, path=""):
        """Load the tokenizer and the model.

        Args:
            path: Location passed to the ``from_pretrained`` loaders.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.batch_size = 1024
        # Set up front so both backends expose the same attributes.
        self.device = None
        self._torch = None
        try:
            from optimum.onnxruntime import ORTModelForSequenceClassification

            self.model = ORTModelForSequenceClassification.from_pretrained(
                path, export=True, provider="CUDAExecutionProvider",
            )
            self._use_ort = True
            print(f"[handler] Loaded ONNX model on CUDA (batch_size={self.batch_size})")
        except Exception as e:
            # Deliberate best-effort: any ONNX failure (missing package,
            # export error, no CUDA provider) falls back to PyTorch.
            print(f"[handler] ONNX failed ({e}), falling back to PyTorch FP16")
            import torch
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.eval()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                self.model = self.model.to(self.device).half()
            self._use_ort = False
            self._torch = torch

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Classify the text(s) in ``data``.

        Args:
            data: Request payload. Texts are read from ``"inputs"`` (or
                ``"input"`` as a fallback key); a single string is treated
                as a one-element batch. A payload with neither key is
                scored as one empty string.

        Returns:
            One list per input text, in input order. Each inner list holds
            one ``{"label": "LABEL_<i>", "score": <probability>}`` dict per
            model output class, with scores rounded to 6 decimals.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results = []
        for start in range(0, len(inputs), self.batch_size):
            batch = inputs[start : start + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True, max_length=128,
                return_tensors="np" if self._use_ort else "pt",
            )
            if self._use_ort:
                probs = self._ort_probs(encoded)
            else:
                probs = self._torch_probs(encoded)
            all_results.extend(self._format_scores(probs))
        return all_results

    def _ort_probs(self, encoded):
        """Run the ONNX model and return softmax probabilities as nested lists."""
        logits = self.model(**encoded).logits
        if hasattr(logits, "detach"):
            # Torch tensor (possibly on GPU): move to host before converting.
            logits = logits.detach().cpu().numpy()
        logits = np.asarray(logits)
        if logits.dtype == np.float16:
            # Half precision cannot carry the 6 decimals reported to callers.
            logits = logits.astype(np.float32)
        # Subtract the row max for a numerically stable softmax.
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return (exp / exp.sum(axis=-1, keepdims=True)).tolist()

    def _torch_probs(self, encoded):
        """Run the PyTorch model and return softmax probabilities as nested lists."""
        torch = self._torch
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        with torch.no_grad():
            logits = self.model(**encoded).logits
            # Upcast so the softmax is not computed in FP16 on the CUDA path.
            probs = torch.softmax(logits.float(), dim=-1)
        # One device-to-host transfer per batch instead of one per score.
        return probs.cpu().tolist()

    @staticmethod
    def _format_scores(prob_rows):
        """Turn rows of class probabilities into lists of label/score dicts."""
        return [
            [
                {"label": f"LABEL_{idx}", "score": round(float(p), 6)}
                for idx, p in enumerate(row)
            ]
            for row in prob_rows
        ]
|