# visobert-multi-task / handler.py
# Ng Huong Duyen
# update dropout to 0.1 for all task heads
# 0c036ef
from typing import Dict, List, Any
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
import os
class EndpointHandler:
    """Callable multi-task text classifier built on a shared ViSoBERT encoder.

    One encoder feeds four task-specific heads (``sentiment``, ``topic``,
    ``hate_speech``, ``clickbait``).  Calling the instance with a request
    dictionary runs one task, or every task when ``task == "all"``, and
    returns a mapping ``{task_name: prediction}``.
    """

    def __init__(self, path: str):
        """Download the checkpoint and build the encoder and task heads.

        Args:
            path: Accepted for interface compatibility; not used, because
                all artefacts are fetched from the Hugging Face Hub.
        """
        MODEL_REPO = 'bie-nhd/visobert-multitask'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

        model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.pt")
        self.model = AutoModel.from_pretrained("uitnlp/visobert")
        # NOTE(review): torch.load unpickles the downloaded file, which can
        # execute arbitrary code.  Only safe while MODEL_REPO is trusted;
        # consider weights_only=True if the checkpoint allows it.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model_state_dict = checkpoint['encoder']
        self.model.load_state_dict(self.model_state_dict)
        self.model.eval()

        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(hidden_size, 4, 0.1),
            'topic': TaskClassificationHead(hidden_size, 10, 0.1),
            'hate_speech': TaskClassificationHead(hidden_size, 5, 0.1),
            'clickbait': TaskClassificationHead(hidden_size, 2, 0.1),
        })
        # NOTE(review): the original code never restored the head weights, so
        # the heads stayed randomly initialised.  The key name 'task_heads' is
        # an assumption (by analogy with 'encoder' / 'log_vars') -- confirm
        # against the training script that wrote model.pt.
        if 'task_heads' in checkpoint:
            self.task_heads.load_state_dict(checkpoint['task_heads'])
        else:
            print("Warning: checkpoint has no 'task_heads' entry; "
                  "task heads are randomly initialised")

        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])

        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)
        # Modules start in training mode; without this the heads' Dropout
        # layers stay active and make predictions non-deterministic.
        self.task_heads.eval()

        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'},
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])},
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics'],
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True,
            },
        }

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize the request text for a single task.

        Args:
            inputs: Request dictionary with a ``task`` key, the text under
                ``text`` (or ``inputs`` as a fallback) and, for dual-input
                tasks, an optional ``title``.

        Returns:
            The tokenizer output (PyTorch tensors, padded to ``max_length``).

        Raises:
            ValueError: If ``task`` is missing/unknown or no text is given.
        """
        task = inputs.get('task', None)
        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        text = inputs.get('text', inputs.get('inputs', None))
        if text is None:
            raise ValueError("'text' (or 'inputs') key is required in the input dictionary")

        title = inputs.get('title', None)
        config = self.task_config[task]
        max_length = 256 if task == 'clickbait' else 128

        # Dual-input tasks join title and body with the separator pair.  With
        # no title, fall back to the body alone rather than embedding the
        # literal string "None" in the model input.
        if config.get('dual_input', False) and title is not None:
            text = f"{title} </s></s> {text}"

        return self.tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run the encoder and the head for ``task`` on tokenized input.

        Args:
            task: Key into ``self.task_heads`` / ``self.task_config``.
            preprocessed: Tokenizer output as produced by ``preprocess``.

        Returns:
            For multi-label tasks ``{'labels': [...], 'scores': [...]}`` where
            ``labels`` are those with sigmoid probability above 0.5; otherwise
            ``{'label': str, 'confidence': float}`` from a softmax argmax.
            Only the first item of the batch is reported.
        """
        # The tokenizer returns CPU tensors; the model may live on the GPU.
        model_inputs = {
            name: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for name, value in preprocessed.items()
        }
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            cls_embedding = self.model(**model_inputs).last_hidden_state[:, 0, :]
            logits = self.task_heads[task](cls_embedding)

        config = self.task_config[task]
        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).detach().cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}

        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one task, or all tasks, on a request.

        Args:
            data: Request dictionary.  ``task`` (case-insensitive) is a task
                name or ``"all"``; the remaining keys are those read by
                ``preprocess``.  The dictionary is not modified.

        Returns:
            Mapping from each executed task name to its prediction.

        Raises:
            ValueError: If ``task`` is missing, not a string, or unknown.
        """
        task = data.get('task', None)
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        if not isinstance(task, str):
            raise ValueError(f"Invalid task: {task}")

        task = task.lower()
        if task == "all":
            selected = list(self.task_config)
        elif task in self.task_config:
            selected = [task]
        else:
            raise ValueError(f"Invalid task: {task}")

        results = {}
        for name in selected:
            # Hand preprocess a copy carrying the normalised task name so that
            # mixed-case names work and the caller's dict is left untouched.
            request = {**data, 'task': name}
            results[name] = self.predict(name, self.preprocess(request))
        return results
class TaskClassificationHead(torch.nn.Module):
    """Feed-forward classification head: Linear -> ReLU -> Dropout -> Linear.

    The hidden width is half of ``hidden_size``, but never smaller than
    ``num_labels``.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        """Build the head.

        Args:
            hidden_size: Width of the incoming feature vector.
            num_labels: Number of output logits.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        inner_dim = hidden_size // 2
        if inner_dim < num_labels:
            inner_dim = num_labels
        # Attribute name and layer order determine the state_dict keys
        # ("projection.0.*", "projection.3.*"); keep both stable.
        layers = [
            torch.nn.Linear(hidden_size, inner_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(inner_dim, num_labels),
        ]
        self.projection = torch.nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the projection stack."""
        logits = self.projection(hidden_states)
        return logits