File size: 5,583 Bytes
5b2de76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c036ef
 
 
 
5b2de76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Dict, List, Any
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel,  AutoConfig
from huggingface_hub import hf_hub_download
import os

class EndpointHandler:
    """Multi-task inference handler for the ViSoBERT multi-task model.

    Wraps a shared ViSoBERT encoder with one classification head per task
    (sentiment, topic, hate_speech, clickbait) and exposes the HF Inference
    Endpoints `__call__(data)` contract.
    """

    def __init__(self, path: str):
        """Load tokenizer, encoder weights, task heads and task metadata.

        Args:
            path: Model directory supplied by the serving runtime. Unused —
                the tokenizer and checkpoint are fetched from the Hub instead.
        """
        MODEL_REPO = 'bie-nhd/visobert-multitask'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bie-nhd/visobert-multitask')
        model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.pt")
        # Base encoder architecture; its weights are overwritten below with the
        # fine-tuned 'encoder' state from the multi-task checkpoint.
        self.model = AutoModel.from_pretrained("uitnlp/visobert")

        checkpoint = torch.load(model_path, map_location=self.device)
        self.model_state_dict = checkpoint['encoder']
        self.model.load_state_dict(self.model_state_dict)
        self.model.eval()

        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.1),
            'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.1),
            'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.1),
            'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.1)
        })
        # NOTE(review): the original only restored 'encoder' and 'log_vars'
        # from the checkpoint, leaving the heads at random init. Load head
        # weights if the checkpoint carries them; confirm the key name
        # against the training script.
        if 'task_heads' in checkpoint:
            self.task_heads.load_state_dict(checkpoint['task_heads'])
        # BUG FIX: the heads contain Dropout layers — without eval() the
        # predictions were stochastic at inference time.
        self.task_heads.eval()
        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])
        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)

        # Per-task decoding metadata. 'single_label' tasks use argmax over a
        # softmax; 'multi_label' thresholds sigmoid probabilities at 0.5.
        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'}
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])}
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True
            }
        }

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize the request payload for the given task.

        Args:
            inputs: Request dict with 'task', 'text' (or 'inputs'), and
                optionally 'title' for dual-input tasks.

        Returns:
            Tokenizer encoding with PyTorch tensors.

        Raises:
            ValueError: If 'task' is missing or unknown.
        """
        task = inputs.get('task')
        title = inputs.get('title')
        text = inputs.get('text', inputs.get('inputs'))

        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        config = self.task_config[task]
        # Clickbait pairs title + body, so it gets a longer context window.
        max_length = 256 if task == 'clickbait' else 128
        # Dual-input tasks join title and body with the sentence-pair marker.
        # (The original tokenized twice for dual input; once is enough.)
        if config.get('dual_input', False):
            text = f"{title} </s></s> {text}"
        return self.tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run the encoder and task head, then decode logits into labels.

        Args:
            task: Key into self.task_config / self.task_heads.
            preprocessed: Tokenizer encoding produced by preprocess().

        Returns:
            For multi-label tasks: {'labels': [...], 'scores': [...]}.
            For single-label tasks: {'label': str, 'confidence': float}.
        """
        config = self.task_config[task]
        # BUG FIX: tokenizer output lives on CPU — move it to the model's
        # device so inference works when CUDA is available.
        inputs = {k: v.to(self.device) for k, v in preprocessed.items()}
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            cls_state = self.model(**inputs).last_hidden_state[:, 0, :]  # [CLS] token
            logits = self.task_heads[task](cls_state)

        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Serving entry point.

        Args:
            data: Request dict; 'task' is required and may be 'all' to run
                every configured task on the same payload.

        Returns:
            Mapping of task name -> prediction dict.

        Raises:
            ValueError: If 'task' is missing or unknown.
        """
        task = data.get('task')
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        task = task.lower()

        if task != 'all' and task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")
        tasks = list(self.task_config) if task == 'all' else [task]

        results = {}
        for t in tasks:
            # Copy so the caller's payload is not mutated (the original wrote
            # data['task'] in place, a surprising side effect for callers).
            payload = {**data, 'task': t}
            results[t] = self.predict(t, self.preprocess(payload))
        return results

class TaskClassificationHead(torch.nn.Module):
    """Two-layer MLP classification head over an encoder representation.

    Projects the hidden vector down to a bottleneck of size
    ``max(hidden_size // 2, num_labels)``, applies ReLU and dropout, then
    maps the result to per-label logits.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        # The bottleneck never shrinks below the label count.
        inner_dim = max(hidden_size // 2, num_labels)
        layers = [
            torch.nn.Linear(hidden_size, inner_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(inner_dim, num_labels),
        ]
        self.projection = torch.nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map (..., hidden_size) hidden states to (..., num_labels) logits."""
        return self.projection(hidden_states)