Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from transformers import PegasusTokenizer, PegasusForConditionalGeneration | |
def load_summarizers():
    """Load the summarization models and their tokenizers.

    Returns:
        dict: maps a display name ("Pegasus", "T5", "BART") to a
        ``(model, tokenizer)`` tuple, with the model moved to GPU when
        CUDA is available, CPU otherwise.
    """
    models = {
        "Pegasus": "google/pegasus-cnn_dailymail",
        "T5": "Overglitch/t5-small-cnn-dailymail",
        "BART": "facebook/bart-large-cnn",
    }
    # Loop-invariant: pick the target device once, not per model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    summarizers = {}
    for model_name, model_path in models.items():
        # BUG FIX: the original selected PegasusTokenizer here but then
        # unconditionally overwrote it with AutoTokenizer after loading the
        # model, making the branch dead code. Keep the selected tokenizer.
        if model_name == "Pegasus":
            tokenizer = PegasusTokenizer.from_pretrained(model_path)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        summarizers[model_name] = (model, tokenizer)
    return summarizers
def abstractive_summary(summarizers, model_name, text, max_length, num_beams):
    """Summarize *text* with the model registered under *model_name*.

    Args:
        summarizers: mapping of model name to ``(model, tokenizer)`` tuples.
        model_name: key identifying which summarizer to use.
        text: the input document to summarize.
        max_length: maximum token length of the generated summary.
        num_beams: beam count for beam-search decoding.

    Returns:
        str: the decoded summary with special tokens stripped.
    """
    model, tokenizer = summarizers[model_name]
    # Tokenize on CPU, truncate to the encoder's limit, then move the
    # tensors onto whatever device the model lives on.
    encoded = tokenizer(
        text, return_tensors="pt", max_length=1024, truncation=True
    )
    encoded = encoded.to(model.device)
    generation_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "early_stopping": True,
    }
    summary_ids = model.generate(encoded["input_ids"], **generation_kwargs)
    first_sequence = summary_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)