File size: 2,594 Bytes
da97e2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | from typing import Dict, List, Any
import json
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
class EndpointHandler:
    """Multi-label token-classification handler around a fine-tuned BERT model.

    Each label's sigmoid probability is thresholded independently, so a single
    token may receive several labels (or none, reported as ``["O"]``).
    """

    def __init__(self, path="", threshold=0.5, max_length=128):
        """Load the tokenizer and model and move the model to the best device.

        Args:
            path: Location passed to ``BertForTokenClassification.from_pretrained``
                for the fine-tuned weights.
            threshold: Sigmoid probability a label must exceed to be assigned to
                a token. Defaults to 0.5, the previously hard-coded value.
            max_length: Maximum sequence length (special tokens included) kept
                after truncation. Defaults to 128, the previously hard-coded value.
        """
        # NOTE(review): the tokenizer always comes from the public
        # "bert-base-uncased" checkpoint, not from ``path``; this assumes the
        # fine-tuned model was trained with that vocabulary — confirm.
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.model = BertForTokenClassification.from_pretrained(path)
        self.model.eval()  # inference mode (e.g. disables dropout)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.threshold = threshold
        self.max_length = max_length
        # Model output index -> tag name. The B-/I- prefixes plus "O" suggest
        # BIO-style tags per category; the order must match the label order the
        # checkpoint was trained with — confirm against the training config.
        self.id2label = {
            0: "O",
            1: "B-STEREO",
            2: "I-STEREO",
            3: "B-GEN",
            4: "I-GEN",
            5: "B-UNFAIR",
            6: "I-UNFAIR",
            7: "B-EXCL",
            8: "I-EXCL",
            9: "B-FRAME",
            10: "I-FRAME",
            11: "B-ASSUMP",
            12: "I-ASSUMP",
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Label every token of the input sentence.

        Args:
            data: Request payload; the sentence is read from ``data["inputs"]``.

        Returns:
            One ``{"token": str, "labels": List[str]}`` dict per non-special
            token, in sequence order. ``labels`` holds every label whose
            probability exceeds ``self.threshold`` (ascending label id), or
            ``["O"]`` when none does. A missing or empty input yields a single
            ``{"error": ...}`` dict instead.
        """
        sentence = data.get("inputs", "")
        if not sentence:
            return [{"error": "Input 'inputs' is required."}]

        # NOTE(review): if "inputs" is a list, the tokenizer batches it but only
        # the first sequence's labels are returned below.
        inputs = self.tokenizer(
            sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits

        # Threshold once and copy to host as nested Python lists, instead of
        # querying a (possibly GPU) tensor element by element per token.
        active = (torch.sigmoid(logits[0]) > self.threshold).cpu().tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(
            inputs["input_ids"][0].tolist()
        )

        # Skip framing tokens ([CLS], [SEP], [PAD], ...) but keep [UNK]: it
        # stands for real input text, and dropping it misaligned the output
        # with the sentence. Built once per call rather than once per token.
        skip = frozenset(self.tokenizer.all_special_tokens) - {
            self.tokenizer.unk_token
        }
        return self._decode(tokens, active, skip)

    def _decode(
        self, tokens: List[str], active: List[List[bool]], skip: frozenset
    ) -> List[Dict[str, Any]]:
        """Pair each non-skipped token with the names of its active labels.

        Args:
            tokens: Token strings, one per sequence position.
            active: Per-position list of booleans, one per label id, True when
                that label's probability exceeded the threshold.
            skip: Token strings to omit from the output.

        Returns:
            ``{"token", "labels"}`` dicts in order; ``labels`` falls back to
            ``["O"]`` when no label is active for the token.
        """
        result = []
        for token, flags in zip(tokens, active):
            if token in skip:
                continue
            labels = [self.id2label[idx] for idx, on in enumerate(flags) if on]
            result.append({"token": token, "labels": labels or ["O"]})
        return result