Spaces:
Runtime error
Runtime error
"""
Predict a summary for a given input text using a fine-tuned T5 model.
"""
import re
import string
import os

# Must be set BEFORE importing transformers so the model cache lands locally.
os.environ['TRANSFORMERS_CACHE'] = './.cache'

import contractions
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def clean_text(texts: str) -> str:
    """Normalize raw text: lowercase, strip punctuation, flatten newlines.

    Parameters
    ----------
    texts : str
        Raw input text.

    Returns
    -------
    str
        Lowercased text with ASCII punctuation removed and every
        newline replaced by a single space.
    """
    lowered = texts.lower()
    # One C-level pass removes all ASCII punctuation characters.
    no_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r"\n", " ", no_punct)
def inference_t5(text: str) -> str:
    """
    Predict the summary for an input text.
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text
    """
    # Normalize the input before tokenization.
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FIXME(security): hard-coded HF auth token committed in source.
    # Rotate this token and load it from an environment variable instead
    # (e.g. os.environ["HF_TOKEN"]).
    hf_token = "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"

    # NOTE(perf): tokenizer and model are reloaded on every call; for a
    # serving path, load them once at module/app startup and reuse.
    tokenizer = AutoTokenizer.from_pretrained(
        "Linggg/t5_summary", use_auth_token=hf_token
    )
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
             .to(device))

    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    generated_ids = model.generate(
        # BUG FIX: tensors must live on the same device as the model;
        # previously they stayed on CPU and generate() crashed on CUDA.
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )
    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)
# if __name__ == "__main__":
#     text = input('Entrez votre phrase à résumer : ')
#     print('summary:', inference_t5(text))