|
|
from transformers import T5Tokenizer, MT5ForConditionalGeneration |
|
|
from simpletransformers.t5 import T5Model |
|
|
import datetime |
|
|
import logging |
|
|
import os |
|
|
|
|
|
|
|
|
class Inference:
    """Generate paraphrases for a sentence with an mT5 seq2seq model."""

    def _discard_recommendations(self, original, proposal):
        """Return True when *proposal* is effectively the same text as *original*.

        The comparison is case-insensitive; a second pass also ignores the
        characters '.', '!', ' ', '?' and ',' so that candidates differing
        only in punctuation/spacing are discarded as non-paraphrases.
        """
        proposal = proposal.lower()
        original = original.lower()

        if proposal == original:
            return True

        # Strip punctuation and whitespace, then compare again, so that
        # e.g. "Hello, world." and "hello world" count as identical.
        chars = [".", "!", " ", "?", ","]
        for char in chars:
            proposal = proposal.replace(char, "")
            original = original.replace(char, "")

        if proposal == original:
            return True

        return False

    def get_paraphrases(
        self,
        model_name,
        sentence,
        temperature,
        prefix="paraphrase: ",
        n_predictions=2,
        top_k=120,
        max_length=256,
        device="cpu",
    ):
        """Return up to *n_predictions* distinct paraphrases of *sentence*.

        Parameters
        ----------
        model_name : str
            Path or Hugging Face hub id of the mT5 checkpoint to load.
        sentence : str
            Source text to paraphrase.
        temperature : float
            Sampling temperature; a value <= 0 disables sampling entirely.
        prefix : str
            Task prefix prepended to the input text.
        n_predictions : int
            Maximum number of paraphrases to return.
        top_k, max_length, device
            Standard generation / placement parameters.
        """
        model = MT5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5Tokenizer.from_pretrained(model_name)

        discarded = 0
        text = prefix + sentence + " </s>"
        encoding = tokenizer.encode_plus(
            text, pad_to_max_length=True, return_tensors="pt"
        )
        input_ids, attention_masks = encoding["input_ids"].to(device), encoding[
            "attention_mask"
        ].to(device)

        # Sampling only makes sense with a positive temperature.
        do_sample = temperature > 0
        print(f"do_sample: {do_sample}")
        print(f"temperature: {temperature}")

        # Over-generate (2x the requested count) so near-duplicates can be
        # discarded and n_predictions distinct outputs still remain.
        model_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            do_sample=do_sample,
            max_length=max_length,
            top_k=top_k,
            num_beams=n_predictions * 2,
            top_p=0.98,
            temperature=temperature,
            early_stopping=True,
            num_return_sequences=n_predictions * 2,
        )
        logging.debug(f"{len(model_output)} predictions for {sentence}")
        outputs = []
        for output in model_output:
            generated_sent = tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            # Normalize apostrophes BEFORE the duplicate check, otherwise two
            # candidates differing only in apostrophe style both get stored.
            generated_sent = generated_sent.replace("’", "'")
            if (
                self._discard_recommendations(sentence, generated_sent) is False
                and generated_sent not in outputs
            ):
                outputs.append(generated_sent)
            else:
                logging.debug(f"Discarded: {generated_sent} - source:{sentence}")
                discarded += 1  # was `discaded = +1`, which just assigned 1

            if len(outputs) == n_predictions:
                break

        logging.debug(f"{discarded} candidates discarded for {sentence}")
        return outputs
|
|
|
|
|
|
|
|
def main():
    """Demo entry point: paraphrase one Catalan sentence using the model in CWD."""
    source = "Aquesta és una associació sense ànim de lucre amb la missió de fomentar la presència i l'ús del català."
    model_path = os.getcwd()

    engine = Inference()
    paraphrases = engine.get_paraphrases(model_path, source, 1.0)

    print(f"original: {source}")
    for candidate in paraphrases:
        print(f"   {candidate}")


if __name__ == "__main__":
    main()
|
|
|