heooo committed on
Commit
e02c258
·
1 Parent(s): 64b258b

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +159 -0
handler.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from pathlib import Path
3
+ import torch
4
+ from transformers import (
5
+ BartConfig,
6
+ BartForConditionalGeneration,
7
+ PreTrainedTokenizerFast,
8
+ )
9
+
10
class EndpointHandler():
    """Inference endpoint for Korean text summarization with KoBART.

    Builds the ``hyunwoongko/kobart`` architecture from its config, restores
    fine-tuned weights from a local state-dict file at ``path``, and serves
    beam-search summarization requests on CPU.
    """

    def __init__(self, path=""):
        # Instantiate from config (random weights), then load the fine-tuned
        # state dict into the inner ``model.model`` (the bare BART
        # encoder/decoder, not the LM-head wrapper).
        config = BartConfig.from_pretrained("hyunwoongko/kobart")
        self.model = BartForConditionalGeneration(config).eval().to('cpu')
        self.model.model.load_state_dict(torch.load(
            path,
            map_location='cpu',
        ))
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/kobart")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize ``data["inputs"]`` (a string, or a list of strings).

        Returns ``{"summarization": result}`` where ``result`` is a single
        string when the input was a string, else a list of strings.
        """
        model = self.model
        tokenizer = self.tokenizer

        # Fixed generation parameters. Output length is bounded by
        # ``max_len_a * source_length + max_len_b``.
        beam = 5
        sampling = False
        temperature = 1.0
        sampling_topk = -1
        sampling_topp = -1
        length_penalty = 1.0
        max_len_a = 1
        max_len_b = 50
        no_repeat_ngram_size = 4
        return_tokens = False
        bad_words_ids = None

        data_in = data.pop("inputs", data)
        is_single = isinstance(data_in, str)
        texts = [data_in] if is_single else data_in

        tokenized = tokenize(tokenizer, texts)
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        # Always suppress <unk>; merge with any caller-provided bad word ids.
        unk_ids = [[tokenizer.convert_tokens_to_ids("<unk>")]]

        generated = model.generate(
            input_ids.to('cpu'),
            attention_mask=attention_mask.to('cpu'),
            use_cache=True,
            early_stopping=False,
            decoder_start_token_id=tokenizer.bos_token_id,
            num_beams=beam,
            do_sample=sampling,
            temperature=temperature,
            top_k=sampling_topk if sampling_topk > 0 else None,
            # BUG FIX: the original gated top_p on ``sampling_topk``; it must
            # be gated on ``sampling_topp``.
            top_p=sampling_topp if sampling_topp > 0 else None,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=unk_ids if not bad_words_ids else bad_words_ids + unk_ids,
            length_penalty=length_penalty,
            max_length=max_len_a * len(input_ids[0]) + max_len_b,
        )

        if return_tokens:
            # Return raw token strings rather than decoded text.
            output = [
                tokenizer.convert_ids_to_tokens(ids)
                for ids in generated.tolist()
            ]
            summ_result = output[0] if is_single else output
        else:
            output = tokenizer.batch_decode(
                generated.tolist(),
                skip_special_tokens=True,
            )
            summ_result = (output[0].strip() if is_single
                           else [o.strip() for o in output])

        return {"summarization": summ_result}
93
+
94
def tokenize(
    tokenizer,
    texts: List[str],
    max_len: int = 1024,
) -> Dict:
    """Tokenize ``texts`` for KoBART: prepend <s>, pad/truncate, append <eos>.

    ``max_length`` is set to ``max_len - 1`` so that the <eos> inserted by
    ``add_bos_eos_tokens`` never pushes a sequence past ``max_len``.

    Returns the tokenizer's batch mapping with "input_ids" and
    "attention_mask" tensors, each one column longer after <eos> insertion.
    """
    if isinstance(texts, str):
        texts = [texts]

    # <s> is prepended manually because add_special_tokens=False below.
    texts = [f"<s> {text}" for text in texts]
    eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    eos_list = [eos] * len(texts)  # one <eos> id per sequence

    tokens = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
        max_length=max_len - 1,  # reserve one slot for the appended <eos>
    )

    return add_bos_eos_tokens(tokenizer, tokens, eos_list)
118
+
119
def add_bos_eos_tokens(tokenizer, tokens, eos_list):
    """Insert an <eos> id right after the last non-pad token of each row.

    ``tokens`` is a mapping holding "input_ids" and "attention_mask" tensors
    of shape (batch, seq); ``eos_list`` supplies one <eos> id per row. The
    matching attention-mask position is set to 1. Returns the same mapping
    with both tensors one column longer.
    """
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    token_added_ids, token_added_masks = [], []

    # Hoisted: the pad id is loop-invariant; the original resolved it once
    # per element of every row.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")

    for input_id, atn_mask, eos in zip(input_ids, attention_mask, eos_list):
        non_pad_idx = [i for i, val in enumerate(input_id) if val != pad_id]
        # Insert after the last real token; at position 0 if all padding.
        idx_to_add = max(non_pad_idx) + 1 if non_pad_idx else 0

        eos_tensor = torch.tensor([eos], requires_grad=False)
        one_mask = torch.tensor([1], requires_grad=False)

        input_id = torch.cat([
            input_id[:idx_to_add],
            eos_tensor,
            input_id[idx_to_add:],
        ]).long()

        atn_mask = torch.cat([
            atn_mask[:idx_to_add],
            one_mask,
            atn_mask[idx_to_add:],
        ]).long()

        token_added_ids.append(input_id.unsqueeze(0))
        token_added_masks.append(atn_mask.unsqueeze(0))

    tokens["input_ids"] = torch.cat(token_added_ids, dim=0)
    tokens["attention_mask"] = torch.cat(token_added_masks, dim=0)
    return tokens