# NOTE(review): removed non-Python extraction artifacts that preceded the code
# (pager header with file size, commit hashes, and a run of line numbers).
from typing import Dict, List, Any
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import os
class EndpointHandler:
    """Multi-task text-classification handler for a shared ViSoBERT encoder.

    Loads one encoder plus four per-task classification heads and routes
    incoming requests to a single task ('sentiment', 'topic', 'hate_speech',
    'clickbait') or to all of them at once (task == 'all').
    """

    def __init__(self, path: str):
        """Load tokenizer, encoder weights, and task heads from `path`/model.pt.

        Args:
            path: directory containing the 'model.pt' checkpoint with at least
                the keys 'encoder' and 'log_vars'.
        """
        # Run on GPU when available; the checkpoint is remapped to that device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bie-nhd/visobert-multitask')
        # NOTE(review): torch.load on an untrusted file unpickles arbitrary
        # objects; prefer weights_only=True if the torch version supports it.
        checkpoint = torch.load(os.path.join(path, 'model.pt'), map_location=self.device)
        self.model_state_dict = checkpoint['encoder']
        self.model = AutoModel.from_pretrained('uitnlp/visobert')
        self.model.load_state_dict(self.model_state_dict)
        self.model.eval()
        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(hidden_size, 4, 0.2),
            'topic': TaskClassificationHead(hidden_size, 10, 0.2),
            'hate_speech': TaskClassificationHead(hidden_size, 5, 0.2),
            'clickbait': TaskClassificationHead(hidden_size, 2, 0.2),
        })
        # NOTE(review): the head weights are never restored from the checkpoint
        # (only 'encoder' and 'log_vars' are loaded), so the heads keep their
        # random initialization — confirm whether the checkpoint carries e.g.
        # checkpoint['task_heads'] that should be loaded here.
        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])
        # The heads contain Dropout layers: they must be in eval mode at
        # inference time (the original code only did this for the encoder).
        self.task_heads.eval()
        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)
        # Per-task decoding configuration: label maps for single-label tasks,
        # a threshold-decoded label list for the multi-label task.
        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'}
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])}
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                # Clickbait scores a (title, text) pair rather than text alone.
                'dual_input': True
            }
        }

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize the request payload for its task and move tensors to the
        model's device.

        Args:
            inputs: dict with 'task', and 'text' (or 'inputs'); 'title' is used
                for dual-input tasks.

        Returns:
            Tokenizer output dict (input_ids, attention_mask, ...) on
            `self.device`.

        Raises:
            ValueError: if 'task' is missing or unknown.
        """
        task = inputs.get('task', None)
        title = inputs.get('title', None)
        text = inputs.get('text', inputs.get('inputs', None))
        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")
        config = self.task_config[task]
        # Clickbait pairs are longer, so they get a larger context window.
        max_length = 256 if task == 'clickbait' else 128
        # Dual-input tasks join title and text with the RoBERTa-style
        # separator.  Tokenize exactly once (the original encoded the bare
        # text first and then re-encoded for dual-input tasks).
        if config.get('dual_input', False):
            source = f"{title} </s></s> {text}"
        else:
            source = text
        encoding = self.tokenizer(source, padding='max_length', truncation=True,
                                  max_length=max_length, return_tensors='pt')
        # FIX: inputs must live on the same device as the model — without this
        # a CUDA deployment fails with a device-mismatch error.
        return {key: tensor.to(self.device) for key, tensor in encoding.items()}

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run encoder + task head and decode the logits.

        Returns:
            Multi-label tasks: {'labels': [...], 'scores': [...]} with a 0.5
            sigmoid threshold.  Single-label tasks: {'label': ...,
            'confidence': ...} from the softmax argmax.
        """
        config = self.task_config[task]
        # FIX: inference only — no_grad avoids building an autograd graph
        # (and so no .detach() is needed below).
        with torch.no_grad():
            # Classify from the [CLS] (first-token) representation.
            cls_repr = self.model(**preprocessed).last_hidden_state[:, 0, :]
            logits = self.task_heads[task](cls_repr)
        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: request payload; must contain 'task' (a task name or 'all').

        Returns:
            The single task's prediction dict, or — for 'all' — a dict mapping
            each task name to its prediction.

        Raises:
            ValueError: if 'task' is missing or not a known task.
        """
        task = data.get('task', None)
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        task = task.lower()
        if task == "all":
            results = {}
            for sub_task in self.task_config:
                # preprocess() reads the task from the payload, so it is
                # overwritten per iteration.
                data['task'] = sub_task
                results[sub_task] = self.predict(sub_task, self.preprocess(data))
            return results
        if task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")
        return self.predict(task, self.preprocess(data))
class TaskClassificationHead(torch.nn.Module):
    """Two-layer MLP classifier applied on top of a pooled encoder output.

    Projects a `hidden_size`-dim vector through a ReLU bottleneck (half the
    hidden size, but never narrower than the label count) down to
    `num_labels` logits.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        # The bottleneck never shrinks below the number of output classes.
        mid_dim = max(hidden_size // 2, num_labels)
        layers = [
            torch.nn.Linear(hidden_size, mid_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(mid_dim, num_labels),
        ]
        self.projection = torch.nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return per-label logits for `hidden_states`."""
        logits = self.projection(hidden_states)
        return logits