Upload pipeline.py
Browse files- pipeline.py +197 -0
pipeline.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
from transformers import Pipeline, pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
+
from transformers.tokenization_utils_base import TruncationStrategy
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import html.parser
|
| 6 |
+
import unicodedata
|
| 7 |
+
import sys, os, re
|
| 8 |
+
|
| 9 |
+
class ReaccentPipeline(Pipeline):
    """Seq2seq pipeline that normalises (re-accents) French text.

    Wraps a HuggingFace encoder-decoder model: raw lines are cleaned
    (quote unification, whitespace squeezing) and suffixed with ``</s>``,
    tokenised, decoded with beam search, and returned as ``{"text": ...}``
    records.
    """

    # Lazily-built str.translate table mapping every non-printable code
    # point to None (built once, shared by all instances).
    _noprint_table = None

    def __init__(self, beam_size=5, batch_size=32, **kwargs):
        # beam_size is consumed by _forward; set it before super().__init__()
        # so it is available as soon as the parent pipeline is live.
        self.beam_size = beam_size
        # Bug fix: batch_size was previously captured here and never
        # forwarded, so the parent Pipeline never saw it and batching
        # silently did not happen. NOTE(review): assumes a transformers
        # version whose Pipeline.__init__ pops 'batch_size' — confirm.
        super().__init__(batch_size=batch_size, **kwargs)

    def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
        """Split pipeline kwargs into (preprocess, forward, postprocess) dicts.

        Any unrecognised keyword is passed straight through to
        ``model.generate`` via the forward params.
        """
        preprocess_params = {}
        if truncation is not None:
            preprocess_params["truncation"] = truncation

        forward_params = generate_kwargs

        postprocess_params = {}
        if clean_up_tokenisation_spaces is not None:
            postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces

        return preprocess_params, forward_params, postprocess_params

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Checks whether there might be something wrong with given input with regard to the model.

        Currently a no-op that accepts everything; kept as a hook for
        future length validation.
        """
        return True

    def make_printable(self, s):
        '''Replace non-printable characters in a string.

        Bug fix: the original referenced an undefined module-level
        NOPRINT_TRANS_TABLE and raised NameError when called. The table is
        now built lazily on first use from ``str.isprintable``.
        '''
        if ReaccentPipeline._noprint_table is None:
            ReaccentPipeline._noprint_table = {
                codepoint: None
                for codepoint in range(sys.maxunicode + 1)
                if not chr(codepoint).isprintable()
            }
        return s.translate(ReaccentPipeline._noprint_table)

    def normalise(self, line):
        """Clean one raw line and append the ``</s>`` end-of-sentence marker."""
        #line = unicodedata.normalize('NFKC', line)
        #line = self.make_printable(line)
        # Raw strings fix the invalid escape sequences of the original
        # patterns ('\“', '\"+'); the matched text is unchanged.
        for before, after in [(r'[«»“”]', '"'),   # unify double quotes
                              (r'[‘’]', "'"),     # unify single quotes
                              (r' +', ' '),       # squeeze repeated spaces
                              (r'"+', '"'),       # squeeze repeated double quotes
                              (r"'+", "'"),       # squeeze repeated single quotes
                              (r'^ *', ''),       # strip leading spaces
                              (r' *$', '')]:      # strip trailing spaces
            line = re.sub(before, after, line)
        return line.strip() + ' </s>'

    def _parse_and_tokenise(self, *args, truncation):
        """Normalise then tokenise a ``str`` or ``list[str]`` input.

        Returns the tokeniser's encoding (tensors in ``self.framework``
        format); raises ValueError for any other input type, or for batch
        input when the tokeniser has no pad token.
        """
        prefix = ""
        if isinstance(args[0], list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokeniser has a pad_token_id when using a batch input")
            args = ([prefix + arg for arg in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                f" `args[0]`: {args[0]} have the wrong format. They should be either of type `str` or type `list`"
            )
        inputs = [self.normalise(x) for x in args]
        inputs = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
        # (Removed a dead loop that converted ids back to token strings; its
        # result was never used.)
        # This is produced by tokenisers but is an invalid generate kwargs
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        return inputs

    def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
        """Pipeline preprocess step: return model-ready tokenised inputs."""
        return self._parse_and_tokenise(inputs, truncation=truncation, **kwargs)

    def _forward(self, model_inputs, **generate_kwargs):
        """Run beam-search generation.

        Returns ``{"output_ids": tensor}`` reshaped to
        (batch, n_sequences_per_input, seq_len).
        """
        in_b, input_length = model_inputs["input_ids"].shape

        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
        # Beam size always comes from the constructor, overriding any caller value.
        generate_kwargs['num_beams'] = self.beam_size
        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
        output_ids = self.model.generate(**model_inputs, **generate_kwargs)
        out_b = output_ids.shape[0]
        # Regroup the flat (in_b * n_return, len) output per input sentence.
        output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
        return {"output_ids": output_ids}

    def postprocess(self, model_outputs, clean_up_tokenisation_spaces=False):
        """Decode the generated ids of the first input into [{"text": ...}] records."""
        records = []
        for output_ids in model_outputs["output_ids"][0]:
            record = {
                "text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    # Bug fix: decode() expects the American spelling
                    # 'clean_up_tokenization_spaces'; the previous misspelt
                    # kwarg was silently ignored by the tokeniser.
                    clean_up_tokenization_spaces=clean_up_tokenisation_spaces,
                )
            }
            records.append(record)
        return records

    def correct_hallunications(self, orig, output):
        """Placeholder for hallucination correction; currently returns output unchanged.

        NOTE(review): the name contains a typo ('hallunications') but is
        kept for backward compatibility with any external callers.

        TODO:
        - align the original and output tokens
        - check that the correspondences are legitimate and correct if not
        - replace <EMOJI> symbols by the original ones
        """
        return output

    def __call__(self, *args, **kwargs):
        r"""
        Generate the output text(s) using text(s) given as inputs.
        Args:
            args (`str` or `List[str]`):
                Input text for the encoder.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenisation within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
                (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
                max_length instead of throwing an error down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).
        Return:
            A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:
            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """

        result = super().__call__(*args, **kwargs)
        # For a list-of-str input where each sentence produced exactly one
        # candidate, flatten [[rec], [rec], ...] into [rec, rec, ...].
        if (
            isinstance(args[0], list)
            and all(isinstance(el, str) for el in args[0])
            and all(len(res) == 1 for res in result)
        ):
            return [res[0] for res in result]
        return result
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def normalise_text(list_sents, batch_size=32, beam_size=5):
    """Normalise a list of sentences with the modern-French normalisation model.

    Loads the tokeniser and seq2seq model from the HuggingFace hub, builds a
    ReaccentPipeline and returns its output records for *list_sents*.
    """
    model_name = "rbawden/modern_french_normalisation"
    tokeniser = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
    normalisation_pipeline = ReaccentPipeline(
        model=model,
        tokenizer=tokeniser,
        batch_size=batch_size,
        beam_size=beam_size,
    )
    return normalisation_pipeline(list_sents)
|
| 162 |
+
|
| 163 |
+
def normalise_from_stdin(batch_size=32, beam_size=5):
    """Read sentences from STDIN, normalise them, print each result and return them.

    One normalised sentence is printed per input line; the list of output
    records is also returned to the caller.
    """
    model_name = "rbawden/modern_french_normalisation"
    tokeniser = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
    normalisation_pipeline = ReaccentPipeline(
        model=model,
        tokenizer=tokeniser,
        batch_size=batch_size,
        beam_size=beam_size,
    )
    normalised_outputs = normalisation_pipeline(list(sys.stdin))
    for output in normalised_outputs:
        print(output['text'].strip())
    return normalised_outputs
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
    parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
    parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
    args = parser.parse_args()

    if args.input_file is None:
        # STDIN mode prints the normalised sentences itself.
        normalise_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
    else:
        with open(args.input_file) as fp:
            list_sents = [line.strip() for line in fp]
        output_sents = normalise_text(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
        for output_sent in output_sents:
            # Bug fix: previously printed the raw {'text': ...} dict; print
            # the normalised text itself, matching the STDIN code path.
            print(output_sent['text'].strip())
|