canovich committed on
Commit
00fb7d5
·
1 Parent(s): 817beb2

Upload code/ with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +342 -0
code/inference.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging, requests, os, io, glob, time
2
+ import json
3
+
4
+
5
+ from transformers import BertTokenizer
6
+ from transformers import PreTrainedModel
7
+ import torch
8
+
9
+ from fastai.text import *
10
+ import itertools
11
+ from typing import Optional, Dict, Union
12
+
13
+ from nltk import sent_tokenize
14
+
15
+ from transformers import(
16
+ AutoModelForSeq2SeqLM,
17
+
18
+ PreTrainedModel,
19
+ PreTrainedTokenizer,
20
+ )
21
+ from transformers import AutoTokenizer
22
+ import torch
23
+
24
+
25
class QGPipeline:
    """Answer-aware question generation built on a seq2seq model.

    Calling an instance with a text runs two generation passes: the first
    extracts candidate answer spans sentence by sentence (``ans_model``),
    the second generates one question per located answer (``model``).

    Type annotations are written as strings so they are not evaluated when
    the class body executes.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        ans_model: "PreTrainedModel",
        ans_tokenizer: "PreTrainedTokenizer",
        qg_format: str,
        use_cuda: bool,
    ):
        """Store the models/tokenizers and move the models to the device.

        Args:
            model: Model used for question generation.
            tokenizer: Tokenizer paired with ``model``.
            ans_model: Model used for answer extraction; may be the very
                same object as ``model``.
            ans_tokenizer: Tokenizer paired with ``ans_model``.
            qg_format: Stored on the instance; not read anywhere in this
                class.
            use_cuda: Use ``"cuda"`` when it is also available, else
                ``"cpu"``.

        Raises:
            AssertionError: If ``model`` is not an
                ``MT5ForConditionalGeneration`` instance (checked by class
                name).
        """
        self.model = model
        self.tokenizer = tokenizer

        self.ans_model = ans_model
        self.ans_tokenizer = ans_tokenizer

        self.qg_format = qg_format

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        # Avoid a second transfer when both roles share one model object.
        if self.ans_model is not self.model:
            self.ans_model.to(self.device)

        assert self.model.__class__.__name__ in ["MT5ForConditionalGeneration"]

        self.model_type = "mt5"

    def __call__(self, inputs: str):
        """Generate question/answer pairs for ``inputs``.

        Returns:
            A list of ``{'answer': ..., 'question': ...}`` dicts; empty when
            no answer was extracted or none could be located in its
            sentence.
        """
        # Collapse every run of whitespace (including newlines) to one space.
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if not flat_answers:
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        if not qg_examples:
            # Every extracted answer was dropped; do not tokenize an empty batch.
            return []

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        return [
            {'answer': example['answer'], 'question': que}
            for example, que in zip(qg_examples, questions)
        ]

    def _generate_questions(self, inputs):
        """Run beam-search generation on the prepared prompts and decode them."""
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=80,
            num_beams=4,
        )

        return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

    def _extract_answers(self, context):
        """Extract candidate answers for each sentence of ``context``.

        Returns:
            ``(sents, answers)`` where ``answers[i]`` is the list of answer
            strings produced for ``sents[i]``.
        """
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        if not inputs:
            # No sentences were found, so there is nothing to feed the model.
            return sents, []

        # Encode with the tokenizer that belongs to the answer model; the
        # decode step below uses the same one.
        inputs = self._tokenize(
            inputs, padding=True, truncation=True, tokenizer=self.ans_tokenizer
        )

        outs = self.ans_model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=80,
        )

        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

        # Answers are '<sep>'-delimited; the piece after the last '<sep>' is
        # discarded, and any literal '<pad>' text left in a piece is removed.
        answers = [
            [piece.replace("<pad>", "") for piece in item.split('<sep>')[:-1]]
            for item in dec
        ]
        return sents, answers

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512,
        tokenizer=None,
    ):
        """Batch-encode ``inputs`` into PyTorch tensors.

        Args:
            inputs: List of strings to encode.
            padding: When truthy, pad every sequence to ``max_length``;
                otherwise do not pad.
            truncation: Forwarded to the tokenizer.
            add_special_tokens: Forwarded to the tokenizer.
            max_length: Forwarded to the tokenizer.
            tokenizer: Tokenizer to use; defaults to ``self.tokenizer``.
        """
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        # The deprecated ``pad_to_max_length`` flag is intentionally not
        # passed: ``padding`` already states the strategy explicitly.
        return tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )

    def _prepare_inputs_for_ans_extraction(self, text):
        """Build one ``extract answers:`` prompt per sentence of ``text``.

        Prompt ``i`` contains every sentence, with sentence ``i`` wrapped in
        ``<hl>`` markers.

        Returns:
            ``(sents, inputs)``: the sentence list and the matching prompts.
        """
        sents = sent_tokenize(text)

        inputs = []
        for hl_index in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if j == hl_index:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()

            if self.model_type == "mt5":
                source_text = source_text + " </s>"

            inputs.append(source_text)

        return sents, inputs

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        """Build ``generate question:`` prompts with the answer highlighted.

        Args:
            sents: Sentences of the context.
            answers: ``answers[i]`` holds the answer strings for ``sents[i]``.

        Returns:
            A list of ``{"answer": ..., "source_text": ...}`` dicts. Answers
            that are blank after stripping, or that do not occur verbatim in
            their sentence, are skipped.
        """
        inputs = []
        for i, answer in enumerate(answers):
            for answer_text in answer:
                answer_text = answer_text.strip()
                if not answer_text:
                    # ``str.index("")`` is always 0 and would highlight nothing.
                    continue

                sent = sents[i]
                try:
                    ans_start_idx = sent.index(answer_text)
                except ValueError:
                    # The generated answer is not a substring of its sentence.
                    continue

                ans_end_idx = ans_start_idx + len(answer_text)
                sents_copy = sents[:]
                sents_copy[i] = (
                    f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_end_idx:]}"
                )

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "mt5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs
176
+
177
+
178
class MultiTaskQAQGPipeline(QGPipeline):
    """Pipeline that does question generation or question answering.

    The task is chosen from the type of the call argument: a string is
    treated as a context for question generation, anything else is expected
    to be a mapping with ``"question"`` and ``"context"`` keys for question
    answering.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``QGPipeline.__init__``."""
        super().__init__(**kwargs)

    def __call__(self, inputs: Union[Dict, str]):
        """Dispatch to question generation (str) or question answering (dict).

        Returns:
            For a string: whatever ``QGPipeline.__call__`` returns.
            Otherwise: the decoded answer string.
        """
        # isinstance (rather than an exact type check) so str subclasses are
        # also routed to question generation.
        if isinstance(inputs, str):
            return super().__call__(inputs)
        return self._extract_answer(inputs["question"], inputs["context"])

    def _prepare_inputs_for_qa(self, question, context):
        """Format the question-answering prompt for the model."""
        source_text = f"question: {question} context: {context}"
        if self.model_type == "mt5":
            source_text = source_text + " </s>"
        return source_text

    def _extract_answer(self, question, context):
        """Generate and decode the answer to ``question`` given ``context``."""
        source_text = self._prepare_inputs_for_qa(question, context)
        # Single-item batch, so no padding is requested.
        inputs = self._tokenize([source_text], padding=False)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=80,
        )

        return self.tokenizer.decode(outs[0], skip_special_tokens=True)
208
+
209
+
210
# Task registry: maps a task name to its pipeline class ("impl") and to the
# model identifier used when the caller does not supply one ("default").
SUPPORTED_TASKS = {
    "multitask-qa-qg": {
        "impl": MultiTaskQAQGPipeline,
        "default": {"model": "ozcangundes/mt5-multitask-qa-qg-turkish"},
    },
}
218
+
219
+
220
def pipelinex(
    task: str,
    model: Optional = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    qg_format: Optional[str] = "highlight",
    ans_model: Optional = None,
    ans_tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    use_cuda: Optional[bool] = True,
    **kwargs,
):
    """Build the pipeline object registered for ``task``.

    Args:
        task: Key into ``SUPPORTED_TASKS``.
        model: Model instance or identifier string; when ``None`` the task's
            default identifier is used. Strings are loaded with
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer: Tokenizer instance, identifier string, or a
            ``(identifier, kwargs_dict)`` tuple. When ``None`` and ``model``
            is a string, the model identifier is reused.
        qg_format: Forwarded to the pipeline constructor.
        ans_model: Accepted but not used; the pipeline receives ``model``
            for answer extraction as well.
        ans_tokenizer: Accepted but not used; the pipeline receives
            ``tokenizer`` for answer extraction as well.
        use_cuda: Forwarded to the pipeline constructor.
        **kwargs: Accepted but not used.

    Returns:
        An instance of the task's ``"impl"`` class.

    Raises:
        KeyError: If ``task`` is not in ``SUPPORTED_TASKS``.
        Exception: If no tokenizer is given and ``model`` is not a string.
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    targeted_task = SUPPORTED_TASKS[task]
    task_class = targeted_task["impl"]

    # Use the default model for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"]

    # Try to infer the tokenizer from the model name (if provided as str)
    if tokenizer is None:
        if not isinstance(model, str):
            # Impossible to guess what the right tokenizer is here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )
        tokenizer = model

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, tuple):
        # For tuple we have (tokenizer name, {kwargs})
        tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
    elif isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate model if needed
    if isinstance(model, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model)

    # One model/tokenizer pair serves both answer extraction and generation.
    return task_class(
        model=model,
        tokenizer=tokenizer,
        ans_model=model,
        ans_tokenizer=tokenizer,
        qg_format=qg_format,
        use_cuda=use_cuda,
    )
266
+
267
+ ################################################################################################
268
+
269
+
270
+
271
# Content type accepted by input_fn and produced by output_fn.
JSON_CONTENT_TYPE = 'application/json'

# Module-level logger with its own threshold lowered to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
275
+
276
+
277
+
278
def model_fn(model_dir=None):
    """Load and return the seq2seq model.

    The model is loaded with ``AutoModelForSeq2SeqLM.from_pretrained`` using
    the fixed identifier ``"canovich/myprivateee"``.

    Args:
        model_dir: Ignored. Accepted (with a default, so zero-argument calls
            keep working) so that a host which passes the model directory
            positionally can call this function.
            NOTE(review): the handler names suggest a SageMaker-style
            inference host, which calls ``model_fn(model_dir)`` - confirm
            against the deployment target.

    Returns:
        The loaded model.
    """
    return AutoModelForSeq2SeqLM.from_pretrained("canovich/myprivateee")
283
+
284
+
285
def predict_fn(input, model, tokenizer):
    """Run the multitask QA/QG pipeline on an already-deserialized input.

    A new pipeline is built from ``model`` and ``tokenizer`` on every call,
    and the elapsed wall-clock time is printed to stdout.

    Args:
        input: Value handed to the pipeline call (a string or a dict with
            "question" and "context" keys, per the pipeline's dispatch).
        model: Model instance or identifier forwarded to ``pipelinex``.
        tokenizer: Tokenizer instance or identifier forwarded to
            ``pipelinex``.

    Returns:
        Whatever the pipeline call returns.
    """
    logger.info("Calling model")
    started_at = time.time()

    qa_qg_pipeline = pipelinex("multitask-qa-qg", tokenizer=tokenizer, model=model)
    result = qa_qg_pipeline(input)

    elapsed = time.time() - started_at
    print("--- Inference time: %s seconds ---" % elapsed)
    return result
299
def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
    """Deserialize the request body and return its ``"text"`` field.

    Args:
        request_body: The request payload. A ``str``/``bytes``/``bytearray``
            value is parsed as JSON first; any other value is assumed to be
            an already-parsed mapping and is indexed directly.
        content_type: Declared content type; only ``application/json`` is
            supported.

    Returns:
        The value stored under the ``"text"`` key of the payload.

    Raises:
        Exception: If ``content_type`` is not ``application/json``.
        json.JSONDecodeError: If a serialized body is not valid JSON.
        KeyError: If the payload has no ``"text"`` key.
    """
    logger.info('Deserializing the input data.')
    if content_type != JSON_CONTENT_TYPE:
        raise Exception('Requested unsupported ContentType in content_type: {}'.format(content_type))

    # A serialized body cannot be indexed with a string key; decode it first.
    # json.loads accepts str as well as bytes/bytearray.
    if isinstance(request_body, (str, bytes, bytearray)):
        request_body = json.loads(request_body)

    return request_body["text"]
305
+
306
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
    """Serialize ``prediction`` to JSON.

    Args:
        prediction: JSON-serializable prediction result.
        accept: Requested response content type; only ``application/json``
            is supported.

    Returns:
        A ``(json_string, accept)`` tuple.

    Raises:
        Exception: If ``accept`` is not ``application/json``.
    """
    logger.info('Serializing the generated output.')
    if accept != JSON_CONTENT_TYPE:
        raise Exception('Requested unsupported ContentType in Accept: {}'.format(accept))
    serialized = json.dumps(prediction)
    return serialized, accept
311
+
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+
321
+
322
+
323
+
324
+
325
+
326
+
327
+
328
+
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+