|
|
from typing import Dict, List, Any |
|
|
import torch |
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import os |
|
|
|
|
|
class EndpointHandler:
    """Inference handler serving several classification tasks from one encoder.

    A shared encoder (``uitnlp/visobert`` weights overwritten by the checkpoint's
    ``'encoder'`` entry) feeds one ``TaskClassificationHead`` per task listed in
    ``self.task_config``. Requests select a task, or ``"all"``, via the
    ``task`` key of the payload.
    """

    def __init__(self, path: str):
        """Load the tokenizer, encoder, task heads and checkpoint state.

        Args:
            path: Directory containing ``model.pt``. The checkpoint is read as a
                mapping with at least the ``'encoder'`` and ``'log_vars'`` entries.
        """
        import warnings

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Per-task metadata. The head output sizes below are derived from
        # ``num_labels`` so that the config and the heads cannot drift apart.
        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'},
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])},
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics'],
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True,
            },
        }

        self.tokenizer = AutoTokenizer.from_pretrained('bie-nhd/visobert-multitask')

        checkpoint = torch.load(os.path.join(path, 'model.pt'), map_location=self.device)

        self.model_state_dict = checkpoint['encoder']
        self.model = AutoModel.from_pretrained('uitnlp/visobert')
        self.model.load_state_dict(self.model_state_dict)

        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            task: TaskClassificationHead(hidden_size, cfg['num_labels'], 0.2)
            for task, cfg in self.task_config.items()
        })

        # NOTE(review): the original code never restored the head weights, so
        # every head ran with random initialisation. The 'task_heads' key is an
        # assumption about the training script's checkpoint layout -- confirm.
        head_state = checkpoint.get('task_heads')
        if head_state is not None:
            self.task_heads.load_state_dict(head_state)
        else:
            warnings.warn(
                "Checkpoint has no 'task_heads' entry; task heads keep their random initialisation."
            )

        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])

        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)

        # Both the encoder AND the heads must be in eval mode; otherwise the
        # heads' Dropout layers stay active and predictions are random.
        self.model.eval()
        self.task_heads.eval()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize a request for a single task.

        Args:
            inputs: Request dict. Reads ``task`` (case-insensitive), ``text``
                (falling back to ``inputs``) and, for dual-input tasks, ``title``.

        Returns:
            The tokenizer output (``return_tensors='pt'``), padded and truncated
            to 256 tokens for ``clickbait`` and 128 tokens otherwise.

        Raises:
            ValueError: if the task is missing or unknown, or no text is given.
        """
        task = inputs.get('task', None)
        if isinstance(task, str):
            # Normalise so that ``__call__``'s lowercased validation and this check agree.
            task = task.lower()
        if task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        title = inputs.get('title', None)
        text = inputs.get('text', inputs.get('inputs', None))
        if text is None:
            raise ValueError("Input text is required under the 'text' or 'inputs' key")

        config = self.task_config[task]
        max_length = 256 if task == 'clickbait' else 128

        if config.get('dual_input', False):
            # Title and body are joined with a "</s></s>" separator string. A
            # missing title becomes "" rather than the literal text "None".
            text = f"{'' if title is None else title} </s></s> {text}"

        return self.tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run the encoder and the task head on tokenized inputs.

        Only the first item of the batch is reported.

        Args:
            task: A key of ``self.task_config``.
            preprocessed: Tokenizer output as returned by ``preprocess``.

        Returns:
            For multi-label tasks, ``{'labels': [...], 'scores': [...]}``, where
            labels are those with sigmoid score > 0.5. Otherwise
            ``{'label': str, 'confidence': float}`` taken from the softmax argmax.
        """
        config = self.task_config[task]

        # The tokenizer produces CPU tensors; move them to wherever the model lives.
        model_inputs = {
            name: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for name, value in preprocessed.items()
        }

        with torch.no_grad():
            hidden = self.model(**model_inputs).last_hidden_state
            # Classify from the first token position of each sequence.
            logits = self.task_heads[task](hidden[:, 0, :])

        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}

        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request dict with a ``task`` key (case-insensitive; ``"all"``
                runs every configured task) plus the fields read by ``preprocess``.
                The dict is not modified.

        Returns:
            The ``predict`` result for a single task, or a dict mapping each
            task name to its result when ``task == "all"``.

        Raises:
            ValueError: if ``task`` is missing or unknown.
        """
        task = data.get('task', None)
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        task = task.lower()

        if task == "all":
            # Build a per-task copy of the request instead of mutating the caller's dict.
            return {
                name: self.predict(name, self.preprocess({**data, 'task': name}))
                for name in self.task_config
            }

        if task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        return self.predict(task, self.preprocess({**data, 'task': task}))
|
|
|
|
|
class TaskClassificationHead(torch.nn.Module):
    """Two-layer MLP mapping ``hidden_size``-wide vectors to ``num_labels`` scores.

    The layout is Linear -> ReLU -> Dropout -> Linear, held in ``self.projection``.
    The attribute name and layer order determine the state-dict keys
    (``projection.0.*`` and ``projection.3.*``), so keep them stable.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        # Halve the width, but never go narrower than the number of outputs.
        inner_width = hidden_size // 2
        if inner_width < num_labels:
            inner_width = num_labels

        layers = [
            torch.nn.Linear(hidden_size, inner_width),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(inner_width, num_labels),
        ]
        self.projection = torch.nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return unnormalised per-label scores for ``hidden_states``."""
        scores = self.projection(hidden_states)
        return scores