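"""Custom handler for a Hugging Face Inference Endpoint.

Serves a ViSoBERT (uitnlp/visobert) encoder fine-tuned for multi-task
Vietnamese social-media classification: sentiment, topic, hate speech
(multi-label) and clickbait detection. The shared encoder and the per-task
heads are restored from a single model.pt checkpoint.
"""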
import os
from typing import Any, Dict

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

class EndpointHandler:
    def __init__(self, path: str):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bie-nhd/visobert-multitask')

        # The checkpoint bundles the fine-tuned encoder together with the
        # per-task heads and the uncertainty-weighting log-variances.
        checkpoint = torch.load(os.path.join(path, 'model.pt'), map_location=self.device)

        # Restore the shared ViSoBERT encoder.
        self.model = AutoModel.from_pretrained('uitnlp/visobert')
        self.model.load_state_dict(checkpoint['encoder'])
        self.model.eval()

        # One lightweight classification head per task, sized to match the
        # label spaces declared in task_config below.
        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(hidden_size, 4, 0.2),
            'topic': TaskClassificationHead(hidden_size, 10, 0.2),
            'hate_speech': TaskClassificationHead(hidden_size, 5, 0.2),
            'clickbait': TaskClassificationHead(hidden_size, 2, 0.2)
        })
        # 'task_heads' is an assumed checkpoint key: without restoring these
        # weights the heads would run with their random initialisation.
        if 'task_heads' in checkpoint:
            self.task_heads.load_state_dict(checkpoint['task_heads'])
        self.task_heads.eval()  # disable the heads' dropout at inference

        # Per-task log-variances from homoscedastic-uncertainty multi-task
        # training; loaded for completeness, they play no role at inference.
        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])

        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)
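
        # For reference, the assumed layout of model.pt (hypothetical; it
        # should mirror however the training script called torch.save):
        #   torch.save({'encoder': encoder.state_dict(),
        #               'task_heads': task_heads.state_dict(),
        #               'log_vars': log_vars.state_dict()}, 'model.pt')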

        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'}
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])}
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True
            }
        }
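
        # Illustrative request bodies ('inputs' is accepted as an alias for
        # 'text'; clickbait additionally expects a 'title'):
        #   {'task': 'sentiment', 'text': '...'}
        #   {'task': 'clickbait', 'title': '...', 'text': '...'}
        #   {'task': 'all', 'text': '...'}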

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        task = inputs.get('task', None)
        title = inputs.get('title', None)
        text = inputs.get('text', inputs.get('inputs', None))

        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")
        if text is None:
            raise ValueError("'text' (or 'inputs') key is required in the input dictionary")

        config = self.task_config[task]
        max_length = 256 if task == 'clickbait' else 128

        # Clickbait scores a title/body pair joined by the sentence-pair
        # separator; a missing title falls back to an empty string so the
        # 'all' route still works on plain text.
        if config.get('dual_input', False):
            text = f"{title or ''} </s></s> {text}"

        encoding = self.tokenizer(text, padding='max_length', truncation=True,
                                  max_length=max_length, return_tensors='pt')
        # Move the input tensors onto the same device as the model.
        return {k: v.to(self.device) for k, v in encoding.items()}

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        config = self.task_config[task]

        # Classify from the [CLS]-position embedding of the shared encoder.
        with torch.no_grad():
            pooled = self.model(**preprocessed).last_hidden_state[:, 0, :]
            logits = self.task_heads[task](pooled)

        if config['type'] == 'multi_label':
            # Independent sigmoid per label; report every label above 0.5.
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}
        else:
            # Mutually exclusive labels: softmax over the logits, take argmax.
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
            pred_idx = int(np.argmax(probs))
            return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        task = data.get('task', None)
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        task = task.lower()

        # 'all' runs every configured task over the same payload and returns
        # a mapping from task name to its result.
        if task == 'all':
            results = {}
            for _t in self.task_config:
                data['task'] = _t
                preprocessed = self.preprocess(data)
                results[_t] = self.predict(_t, preprocessed)
            return results

        if task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        preprocessed = self.preprocess(data)
        return self.predict(task, preprocessed)

class TaskClassificationHead(torch.nn.Module):
    """Two-layer MLP classification head with a bottleneck projection."""

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        # Halve the encoder width, but never below the number of labels.
        bottleneck = max(hidden_size // 2, num_labels)
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, bottleneck),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(bottleneck, num_labels),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.projection(hidden_states)
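

# A minimal local smoke test (hypothetical payloads; assumes model.pt sits in
# the current directory, as it would in the deployed repository).
if __name__ == '__main__':
    handler = EndpointHandler(path='.')
    # Single task, using the 'inputs' alias for 'text'.
    print(handler({'task': 'sentiment', 'inputs': 'Quán này phục vụ rất tốt!'}))
    # Clickbait scores a title/body pair.
    print(handler({'task': 'clickbait', 'title': 'Bạn sẽ không tin nổi điều này',
                   'text': 'Nội dung bài báo ...'}))
    # 'all' fans the same text out to every task.
    print(handler({'task': 'all', 'text': 'Hôm nay trời đẹp quá.'}))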