# Snapshot metadata (from the page this file was copied from): commit 98dc5b0, 2,074 bytes.
import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer
import random
import numpy as np
import nltk
# Fetch the NLTK resources this module needs at import time.
# NLTK >= 3.8.2 builds the Punkt sentence tokenizer used by sent_tokenize from
# the 'punkt_tab' resource rather than the pickled 'punkt' one, so both are
# requested to support old and new NLTK releases. With the default
# raise_on_error=False, an id unknown to an older NLTK index is reported and
# download() returns False instead of raising.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('brown')
nltk.download('wordnet')
from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize
import locale
# Force the "preferred" encoding to UTF-8 (presumably a workaround for an
# environment whose locale reports a non-UTF-8 encoding -- confirm).
# The replacement mirrors the stdlib signature getpreferredencoding(do_setlocale=True),
# so callers that pass the flag, e.g. getpreferredencoding(False), keep working.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
class Summarizer:
    """Abstractive text summarizer backed by a T5 sequence-to-sequence model.

    The model and tokenizer are loaded once in ``__init__``; the model is moved
    to the first CUDA device when one is available, otherwise it stays on CPU.
    """

    def __init__(self, model_name="t5-base"):
        """Load the model/tokenizer and seed every RNG with 42.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``
                for both the model and the tokenizer (default ``'t5-base'``).
        """
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.set_seed(42)

    def set_seed(self, seed: int):
        """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    @staticmethod
    def _capitalize_first(sentence):
        """Upper-case the first character of *sentence*, leaving the rest untouched."""
        return sentence[:1].upper() + sentence[1:]

    def postprocesstext(self, content):
        """Capitalize the first letter of every sentence in *content*.

        Sentences are split with NLTK's ``sent_tokenize``. Unlike
        ``str.capitalize``, the remainder of each sentence is kept as-is, so
        acronyms and proper nouns ("NASA", "Paris") are not lower-cased.

        Returns:
            The sentences re-joined, each prefixed by a single space (the
            result therefore starts with a space unless it is empty).
        """
        return "".join(" " + self._capitalize_first(sent) for sent in sent_tokenize(content))

    def summarizer(self, text, model=None, tokenizer=None, *, max_input_length=512,
                   min_length=75, max_length=300, num_beams=3):
        """Return an abstractive summary of *text*.

        Args:
            text: Input document. It is stripped, newlines become spaces, and
                the ``"summarize: "`` task prefix is prepended.
            model: Optional model overriding ``self.model``.
            tokenizer: Optional tokenizer overriding ``self.tokenizer``.
            max_input_length: Token budget for the encoded input; longer
                inputs are truncated.
            min_length: Minimum length (tokens) of the generated summary.
            max_length: Maximum length (tokens) of the generated summary.
            num_beams: Beam-search width.

        Returns:
            The post-processed summary, or ``""`` when *text* is empty or
            whitespace-only.
        """
        text = text.strip().replace("\n", " ")
        if not text:
            # Nothing to summarize: generating from the bare prefix would still
            # be forced to emit at least min_length tokens of filler.
            return ""
        if model is None:
            model = self.model
        if tokenizer is None:
            tokenizer = self.tokenizer
        # Put the inputs on the model's own device, so a caller-supplied model
        # that lives somewhere other than self.device still works.
        device = getattr(model, "device", self.device)
        encoding = tokenizer(
            "summarize: " + text,
            max_length=max_input_length,
            padding=False,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        outs = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            early_stopping=True,
            num_beams=num_beams,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            min_length=min_length,
            max_length=max_length,
        )
        # Only one sequence is requested, so decode just the first row.
        summary = tokenizer.decode(outs[0], skip_special_tokens=True)
        return self.postprocesstext(summary).strip()