bie-nhd committed on
Commit
5b2de76
·
1 Parent(s): 38b4475

create app.py and handlers

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +25 -0
  3. handler.py +133 -0
  4. test_handler.py +17 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ .gradio
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr

from handler import EndpointHandler

# Load the multi-task model once at startup; every request reuses this instance.
handler = EndpointHandler('bie-nhd/visobert-multitask')


def predict(text, task):
    """Classify `text` for the selected `task` and return the result dict.

    `task` is one of the dropdown choices; "all" runs every configured task.
    """
    result = handler({"task": task, "text": text})
    return result


with gr.Blocks() as app:
    gr.Markdown("# Text Classification")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="text")
            # NOTE(review): "clickbait" is intentionally absent — it needs a
            # title input this UI does not collect. Confirm before adding it.
            task_dropdown = gr.Dropdown(label="task", choices=["all", "sentiment", "topic", "hate_speech"])
            predict_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.Textbox(label="Result")

    predict_btn.click(fn=predict, inputs=[text_input, task_dropdown], outputs=output)


app.launch(share=True)
handler.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
5
+ from huggingface_hub import hf_hub_download
6
+ import os
7
+
8
class EndpointHandler:
    """Inference handler for a ViSoBERT multi-task text classifier.

    Wraps a shared ViSoBERT encoder plus one small classification head per
    task (sentiment, topic, hate_speech, clickbait).  Instances are callable
    with ``{"task": ..., "text": ..., "title": ...}``; task ``"all"`` runs
    every configured task on the same text.
    """

    def __init__(self, path: str):
        """Download weights from the Hub and assemble encoder + task heads.

        Args:
            path: Hub repo id holding the tokenizer and the ``model.pt``
                checkpoint (e.g. ``'bie-nhd/visobert-multitask'``).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        model_path = hf_hub_download(repo_id=path, filename="model.pt")
        # Base architecture only; fine-tuned encoder weights are loaded below.
        self.model = AutoModel.from_pretrained("uitnlp/visobert")

        checkpoint = torch.load(model_path, map_location=self.device)
        self.model_state_dict = checkpoint['encoder']
        self.model.load_state_dict(self.model_state_dict)
        self.model.eval()

        self.task_heads = torch.nn.ModuleDict({
            'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.2),
            'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.2),
            'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.2),
            'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.2)
        })
        # Restore fine-tuned head weights when the checkpoint provides them;
        # otherwise the heads keep their random init and predictions are
        # meaningless.  NOTE(review): confirm the checkpoint key name.
        if 'task_heads' in checkpoint:
            self.task_heads.load_state_dict(checkpoint['task_heads'])
        # eval() turns off the dropout inside each head at inference time.
        self.task_heads.eval()

        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])
        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)

        # Per-task output contract: label count, single- vs multi-label
        # decoding, and index -> human-readable label mapping.
        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'}
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': {i: label for i, label in enumerate(['Spam', 'News', 'Academic', 'Other', 'Service', 'Jobs', 'Personal', 'Social', 'Help', 'Events'])}
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True
            }
        }

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize one request into model-ready tensors on ``self.device``.

        Args:
            inputs: Request dict with ``task``, ``text`` (or ``inputs``),
                and optionally ``title`` for dual-input tasks.

        Raises:
            ValueError: If ``task`` is missing or unknown.
        """
        task = inputs.get('task', None)
        title = inputs.get('title', None)
        text = inputs.get('text', inputs.get('inputs', None))

        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        config = self.task_config[task]
        max_length = 256 if task == 'clickbait' else 128
        # Dual-input tasks join title and body with the </s></s> separator.
        # Tokenize exactly once (the original tokenized twice for these).
        if config.get('dual_input', False):
            text = f"{title} </s></s> {text}"
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

        # Move inputs onto the same device as the model (fixes a crash when
        # the model lives on CUDA but the tensors stayed on CPU).
        return {k: v.to(self.device) for k, v in encoding.items()}

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run encoder + task head and decode logits into labels.

        Returns ``{'labels': [...], 'scores': [...]}`` for multi-label tasks
        (threshold 0.5) or ``{'label': ..., 'confidence': ...}`` otherwise.
        """
        config = self.task_config[task]
        # Inference only — no_grad avoids building the autograd graph.
        with torch.no_grad():
            # First-token ([CLS]) hidden state feeds the task head.
            cls_state = self.model(**preprocessed).last_hidden_state[:, 0, :]
            logits = self.task_heads[task](cls_state)

        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [config['label_list'][i] for i in np.where(probs > 0.5)[0]]
            return {'labels': active, 'scores': probs.tolist()}
        else:
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
            pred_idx = int(np.argmax(probs))
            return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request; ``task == 'all'`` runs every configured task.

        Raises:
            ValueError: If ``task`` is missing or not a known task / "all".
        """
        task = data.get('task', None)
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        task = task.lower()

        results = {}
        if task == "all":
            # Work on a copy so the caller's dict is not mutated.
            request = dict(data)
            for _t in self.task_config:
                request['task'] = _t
                preprocessed = self.preprocess(request)
                results[_t] = self.predict(_t, preprocessed)
            return results
        elif task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")
        preprocessed = self.preprocess(data)
        results[task] = self.predict(task, preprocessed)
        return results
120
+
121
class TaskClassificationHead(torch.nn.Module):
    """Small two-layer MLP mapping encoder states to per-task logits."""

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        # Shrink to half the encoder width, but never below the label count.
        width = max(hidden_size // 2, num_labels)
        layers = [
            torch.nn.Linear(hidden_size, width),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(width, num_labels),
        ]
        self.projection = torch.nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states of shape (..., hidden_size) to logits."""
        return self.projection(hidden_states)
test_handler.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

from huggingface_hub import login

from handler import EndpointHandler

# Authenticate against the Hub to access the private/gated model repo.
# Read the token from the environment instead of hard-coding it: login("")
# passes an invalid empty token, and a literal token would risk being
# committed to version control.
login(os.environ["HF_TOKEN"])

handler = EndpointHandler('bie-nhd/visobert-multitask')

# Smoke-test a single task and the "all" fan-out on Vietnamese samples.
print(handler({'task': 'topic', 'text': 'Thầy cô cho em hỏi đăng ký học ghép môn Kinh tế vi mô ở đâu ạ? Em cảm ơn. ⭐'}))
print(handler({'task': 'all', 'text': 'Thầy cô trường này bị sao ấy??? Học kỳ trước em đăng ký học 2 môn online, giờ học kỳ này vào đăng ký học lại thì bị khóa tài khoản không cho đăng ký học nữa. Em phải làm sao ạ???'}))