File size: 1,782 Bytes
3e4ee15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import os
from typing import Dict, List, Any

import numpy as np
import torch
class EndpointHandler:
    """Multi-label text classification handler with per-class thresholds.

    The model and tokenizer are loaded from ``path`` through the
    ``transformers`` auto classes, and a ``thresholds.npy`` array is loaded
    from the same location. Calling the instance returns every label whose
    sigmoid score is at least that label's threshold.
    """

    def __init__(self, path=""):
        """Load the model, tokenizer and per-class thresholds.

        Args:
            path: Location handed to ``from_pretrained``; ``thresholds.npy``
                is looked up inside it.
        """
        # Function-scope import: merely importing this module does not
        # import transformers.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Build the path with os.path.join instead of f"{path}/...": with an
        # empty ``path`` the f-string produced "/thresholds.npy" (filesystem
        # root) rather than a file relative to the working directory.
        # The array is indexed by class id in __call__.
        self.thresholds = np.load(os.path.join(path, "thresholds.npy"))

        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the request text and return the labels that pass threshold.

        Args:
            data: Request payload. The ``"inputs"`` entry is popped (removed)
                from this dict and passed to the tokenizer; when the key is
                absent, the dict itself is passed instead.

        Returns:
            One ``{"label": ..., "score": ...}`` dict for every class whose
            sigmoid probability is >= its threshold, ordered by score from
            highest to lowest. The list is empty when no class qualifies.

        Note:
            Only index 0 of the model's logits is scored.
            NOTE(review): a payload holding several texts would therefore
            have all but the first dropped silently -- confirm this is
            intended.
        """
        inputs_text = data.pop("inputs", data)

        # Every request is truncated/padded to exactly 128 tokens.
        inputs = self.tokenizer(
            inputs_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits[0]
            # Sigmoid rather than softmax: each class is scored on its own.
            probs = torch.sigmoid(logits).cpu().numpy()

        id2label = self.model.config.id2label
        predictions = [
            {"label": id2label[idx], "score": float(prob)}
            for idx, prob in enumerate(probs)
            if prob >= self.thresholds[idx]
        ]
        predictions.sort(key=lambda item: item["score"], reverse=True)
        return predictions