Spaces:
Runtime error
Runtime error
| ### Imports | |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
| from transformers import BartForConditionalGeneration, BartTokenizer | |
| from transformers import T5ForConditionalGeneration, T5Tokenizer | |
| from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer | |
| import torch | |
| from config import config | |
| ### Classes and functions | |
| ##========================================================================================================== | |
class SummarizationUtilities:
    """Pegasus-based abstractive summarization helper.

    Wraps a PegasusTokenizer / PegasusForConditionalGeneration pair and
    exposes the tokenize -> generate -> decode pipeline through a single
    get_summary() / summarize() call.
    """

    # Class-level defaults; every instance overwrites these in __init__.
    model_name = None
    device = None
    tokenizer = None
    model = None

    def __init__(self, model_name="google/pegasus-xsum", device=None, model_path=None):
        """Load the Pegasus checkpoint and tokenizer.

        Arguments:
            model_name: informational model identifier (not used to load weights).
            device: torch device string ("cuda"/"cpu"); auto-detected when None.
            model_path: local checkpoint path; defaults to
                config.pegasus_model_path, resolved at call time instead of at
                class-definition time.
        """
        self.model_name = model_name
        # BUG FIX: the original assigned the (None) return value of
        # detect_available_cuda_device() to self.device, and then moved the
        # model to the raw `device` argument instead of the detected device.
        self.device = device if device is not None else self.detect_available_cuda_device()
        if model_path is None:
            model_path = config.pegasus_model_path
        self.tokenizer = PegasusTokenizer.from_pretrained(model_path)
        self.model = PegasusForConditionalGeneration.from_pretrained(model_path).to(self.device)

    def detect_available_cuda_device(self):
        """Set and return "cuda" when a GPU is available, otherwise "cpu"."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # BUG FIX: the original returned None, clobbering self.device in __init__.
        return self.device

    def tokenize(self, src_text, truncation=True, padding="longest", return_tensors="pt"):
        """Tokenize src_text and move the resulting batch to the active device."""
        return self.tokenizer(
            src_text, truncation=truncation, padding=padding, return_tensors=return_tensors
        ).to(self.device)

    def generate(self, batch):
        """Run model.generate on a tokenized batch and return the token ids."""
        return self.model.generate(**batch)

    def decode_generated_text(self, generated_text, skip_special_tokens=True):
        """Decode generated token ids back into a list of strings."""
        return self.tokenizer.batch_decode(generated_text, skip_special_tokens=skip_special_tokens)

    def get_summary(self, src_text):
        """Summarize src_text (string or list of strings); returns a list of summaries."""
        batch = self.tokenize(src_text)
        generated_text = self.generate(batch)
        return self.decode_generated_text(generated_text)

    def summarize(self, src_text):
        """Alias for get_summary(), kept for API compatibility."""
        # The original duplicated get_summary()'s body verbatim; delegate instead.
        return self.get_summary(src_text)
class BARTSummarizer:
    """BART-based abstractive summarizer loaded from a local checkpoint.

    Background on BART echoing its input when misconfigured:
    https://stackoverflow.com/questions/66639722/why-does-huggingfaces-bart-summarizer-replicate-the-given-input-text
    """

    def __init__(self, device=None, model_path=None):
        """Load the BART tokenizer and model.

        Arguments:
            device: torch device; auto-detected (CUDA if available) when None.
            model_path: local checkpoint path; defaults to config.bart_model_path,
                resolved at call time instead of at class-definition time.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_path is None:
            model_path = config.bart_model_path
        self.tokenizer = BartTokenizer.from_pretrained(model_path)
        # BUG FIX: the model was never moved to self.device while the inputs
        # were, which raises a device-mismatch error when CUDA is available.
        self.model = BartForConditionalGeneration.from_pretrained(model_path).to(self.device)

    def summarize(self, text):
        """Return a beam-search summary (num_beams=4, max_length=200) of text."""
        inputs = self.tokenizer([text], truncation=True, padding="longest", return_tensors="pt").to(self.device)
        summary_ids = self.model.generate(inputs["input_ids"], num_beams=4, max_length=200, early_stopping=True)
        return self.tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)
class T5Summarizer:
    """Abstractive summarizer backed by a locally stored T5 checkpoint."""

    def __init__(self, device=None, model_path=config.t5_model_path):
        """Load tokenizer and model from model_path; prefer CUDA when present."""
        if device:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path).to(self.device)

    def summarize(self, text):
        """Return a single summary string for text."""
        encoded = self.tokenizer.encode_plus(text, return_tensors="pt", truncation=True, padding="longest").to(self.device)
        output_ids = self.model.generate(encoded.input_ids)
        return self.tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
class ProphetNetSummarizer:
    """Abstractive summarizer wrapping a local ProphetNet checkpoint."""

    def __init__(self, device=None, model_path=config.prophetnet_model_path):
        """Load tokenizer and model from model_path; prefer CUDA when present."""
        if device:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
        self.model = ProphetNetForConditionalGeneration.from_pretrained(model_path).to(self.device)

    def summarize(self, text):
        """Return a single summary string for text."""
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, padding="longest").to(self.device)
        generated = self.model.generate(encoded.input_ids)
        return self.tokenizer.decode(generated.squeeze(), skip_special_tokens=True)