# Scraped from a Hugging Face Space file view (file size: 5,583 bytes; commits 5b2de76, 0c036ef).
from typing import Dict, List, Any
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
import os
class EndpointHandler:
    """Inference handler for a multi-task text classifier.

    A shared encoder feeds one classification head per task (``sentiment``,
    ``topic``, ``hate_speech``, ``clickbait``). Calling the handler with a
    request dictionary returns a mapping from task name to prediction.
    """

    def __init__(self, path: str):
        """Download the checkpoint and build the encoder, heads and task table.

        Args:
            path: Accepted for interface compatibility; this constructor does
                not read it and always loads from the hard-coded hub repos.
        """
        MODEL_REPO = 'bie-nhd/visobert-multitask'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

        model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.pt")
        self.model = AutoModel.from_pretrained("uitnlp/visobert")
        # NOTE(security): torch.load unpickles the file; this is only
        # acceptable because the checkpoint comes from a trusted repository.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model_state_dict = checkpoint['encoder']
        self.model.load_state_dict(self.model_state_dict)
        self.model.eval()

        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(hidden_size, 4, 0.1),
            'topic': TaskClassificationHead(hidden_size, 10, 0.1),
            'hate_speech': TaskClassificationHead(hidden_size, 5, 0.1),
            'clickbait': TaskClassificationHead(hidden_size, 2, 0.1),
        })
        # NOTE(review): the previous code never restored trained head weights,
        # so the heads ran with their random initialisation. The key name
        # 'task_heads' is an assumption about the checkpoint layout -- confirm
        # against the training script. When the key is absent nothing changes.
        head_state = checkpoint.get('task_heads')
        if head_state is not None:
            self.task_heads.load_state_dict(head_state)
        # Heads contain Dropout; without eval() it stays active at inference
        # and makes predictions non-deterministic.
        self.task_heads.eval()

        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])

        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)

        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'},
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])},
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics'],
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True,
            },
        }

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize the request text for the task named in ``inputs['task']``.

        Args:
            inputs: Request dictionary. Reads ``task``, ``text`` (falling back
                to ``inputs``) and, for dual-input tasks, ``title``.

        Returns:
            The tokenizer output for the (possibly title-prefixed) text.

        Raises:
            ValueError: If the task is missing/unknown or no text is given.
        """
        task = inputs.get('task')
        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        text = inputs.get('text')
        if text is None:
            text = inputs.get('inputs')
        if text is None:
            raise ValueError("'text' (or 'inputs') key is required in the input dictionary")

        config = self.task_config[task]
        max_length = 256 if task == 'clickbait' else 128
        if config.get('dual_input', False):
            # A missing title becomes an empty string rather than the literal
            # text "None" that f-string formatting would otherwise produce.
            title = inputs.get('title')
            title = '' if title is None else title
            text = f"{title} </s></s> {text}"

        # Tokenize exactly once, after the final input string is known.
        return self.tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run the encoder and the head for ``task`` and decode the result.

        Args:
            task: A key of ``self.task_config``.
            preprocessed: Mapping of model input names to tensors, as returned
                by ``preprocess``.

        Returns:
            For ``multi_label`` tasks ``{'labels': [...], 'scores': [...]}``
            where labels are those with sigmoid probability above 0.5; for
            other tasks ``{'label': ..., 'confidence': ...}`` for the arg-max
            class of the softmax distribution.
        """
        config = self.task_config[task]
        # Inputs must live on the same device as the model.
        model_inputs = {name: tensor.to(self.device) for name, tensor in preprocessed.items()}
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            first_token_state = self.model(**model_inputs).last_hidden_state[:, 0, :]
            logits = self.task_heads[task](first_token_state)

        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}

        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Handle one request and return predictions keyed by task name.

        Args:
            data: Request dictionary. ``task`` is matched case-insensitively
                against the configured tasks; ``"all"`` runs every task. The
                dictionary is not modified.

        Returns:
            ``{task_name: prediction}`` for each task that was run.

        Raises:
            ValueError: If ``task`` is missing, not a string, or unknown.
        """
        task = data.get('task')
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        if not isinstance(task, str):
            raise ValueError(f"Invalid task: {task}")

        task = task.lower()
        if task == "all":
            selected = list(self.task_config)
        elif task in self.task_config:
            selected = [task]
        else:
            raise ValueError(f"Invalid task: {task}")

        results = {}
        for name in selected:
            # Work on a copy carrying the normalised task name so the caller's
            # dict is left untouched and preprocess sees the validated name.
            request = {**data, 'task': name}
            results[name] = self.predict(name, self.preprocess(request))
        return results
class TaskClassificationHead(torch.nn.Module):
    """Two-layer classification head producing per-label logits.

    The input is projected to a bottleneck of
    ``max(hidden_size // 2, num_labels)`` units, passed through ReLU and
    dropout, and then mapped to ``num_labels`` outputs.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        """Build the projection stack.

        Args:
            hidden_size: Size of the last dimension of the input features.
            num_labels: Number of output logits.
            dropout: Dropout probability applied after the ReLU.
        """
        super().__init__()
        # Never let the bottleneck be narrower than the number of outputs.
        bottleneck = max(hidden_size // 2, num_labels)
        # Attribute name and layer order determine the state_dict keys
        # (projection.0.*, projection.3.*), so they must stay as they are.
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, bottleneck),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(bottleneck, num_labels),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the logits computed from ``hidden_states``."""
        return self.projection(hidden_states)