from transformers import T5Tokenizer, MT5ForConditionalGeneration
from simpletransformers.t5 import T5Model
import datetime
import logging
import os


class Inference:
    """Generate paraphrases of a sentence with an mT5 seq2seq model."""

    # Characters ignored when deciding whether a proposal merely repeats the
    # original sentence: period, exclamation mark, space, question mark, comma.
    _IGNORED_CHARS_TABLE = str.maketrans("", "", ".! ?,")

    def __init__(self):
        # (model_name, device) -> (model, tokenizer); avoids reloading the
        # weights from disk on every get_paraphrases() call.
        self._loaded = {}

    def _discard_recommendations(self, original, proposal):
        """Return True if *proposal* is not a genuine paraphrase of *original*.

        The comparison is case-insensitive and ignores ``.``, ``!``, ``?``,
        ``,`` and spaces, so a proposal that differs from the original only in
        casing or those characters is discarded.
        """
        return original.lower().translate(
            self._IGNORED_CHARS_TABLE
        ) == proposal.lower().translate(self._IGNORED_CHARS_TABLE)

    def _load(self, model_name, device):
        """Return a cached ``(model, tokenizer)`` pair, with the model on *device*."""
        key = (model_name, str(device))
        if key not in self._loaded:
            # The model must live on the same device as the input tensors,
            # otherwise generate() fails for any non-CPU device.
            model = MT5ForConditionalGeneration.from_pretrained(model_name).to(device)
            tokenizer = T5Tokenizer.from_pretrained(model_name)
            self._loaded[key] = (model, tokenizer)
        return self._loaded[key]

    # https://github.com/Vamsi995/Paraphrase-Generator/blob/master/evaluate.py
    def get_paraphrases(
        self,
        model_name,
        sentence,
        temperature,
        prefix="paraphrase: ",
        n_predictions=2,
        top_k=120,
        max_length=256,
        device="cpu",
    ):
        """Return up to *n_predictions* distinct paraphrases of *sentence*.

        Args:
            model_name: name or path passed to ``from_pretrained`` for both the
                mT5 model and its tokenizer.
            sentence: the text to paraphrase.
            temperature: sampling temperature; ``> 0`` enables sampling
                (with top-k / top-p), otherwise plain beam search is used.
            prefix: task prefix prepended to *sentence* before encoding.
            n_predictions: maximum number of paraphrases returned.
            top_k: top-k cutoff used only when sampling.
            max_length: maximum generated sequence length.
            device: device the model and inputs are placed on.

        Returns:
            A list of at most *n_predictions* strings. Candidates identical to
            *sentence* (ignoring case and the characters ``.!?,`` and spaces)
            and duplicates are dropped, so the list may be shorter.
        """
        model, tokenizer = self._load(model_name, device)

        text = prefix + sentence + " "
        # A single sequence needs no padding (replaces the deprecated
        # pad_to_max_length=True of encode_plus).
        encoding = tokenizer(text, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        do_sample = temperature > 0
        logging.debug("do_sample: %s", do_sample)
        logging.debug("temperature: %s", temperature)

        # Sampling-only knobs: beam search ignores them (and recent
        # transformers versions warn), so pass them only when sampling.
        sampling_kwargs = (
            {"top_k": top_k, "top_p": 0.98, "temperature": temperature}
            if do_sample
            else {}
        )

        # https://huggingface.co/blog/how-to-generate
        # https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
        model_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=do_sample,
            max_length=max_length,
            # Ask for twice as many candidates since some get discarded below.
            num_beams=n_predictions * 2,
            num_return_sequences=n_predictions * 2,
            early_stopping=True,
            **sampling_kwargs,
        )
        logging.debug("%d predictions for %s", len(model_output), sentence)

        # Normalise typographic apostrophes on both sides *before* comparing,
        # so candidates differing only in ’ vs ' are treated as equal.
        source = sentence.replace("’", "'")
        outputs = []
        discarded = 0
        for output in model_output:
            generated_sent = tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            ).replace("’", "'")
            if generated_sent in outputs or self._discard_recommendations(
                source, generated_sent
            ):
                logging.debug("Discarded: %s - source:%s", generated_sent, sentence)
                discarded += 1
                continue
            outputs.append(generated_sent)
            if len(outputs) == n_predictions:
                break

        logging.debug("%d candidate(s) discarded for %s", discarded, sentence)
        return outputs


def main():
    """Paraphrase a sample Catalan sentence using the model at the current directory."""
    inference = Inference()
    sentence = "Aquesta és una associació sense ànim de lucre amb la missió de fomentar la presència i l'ús del català."
    model = os.getcwd()
    options = inference.get_paraphrases(model, sentence, 1.0)
    print(f"original: {sentence}")
    for option in options:
        print(f" {option}")


if __name__ == "__main__":
    main()