| import torch | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| import random | |
| import numpy as np | |
| from nltk.tokenize import sent_tokenize | |
class T5_Base:
    """Thin wrapper around a T5 model + tokenizer for abstractive summarization."""

    def __init__(self, path, device, model_max_length):
        """Load model and tokenizer from *path*; place the model on *device*.

        Args:
            path: model name or local directory for ``from_pretrained``.
            device: torch device string, e.g. ``"cpu"`` or ``"cuda"``.
            model_max_length: tokenizer truncation limit.
        """
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.tokenizer = T5Tokenizer.from_pretrained(path, model_max_length=model_max_length)
        self.device = torch.device(device)
        # BUG FIX: the model was never moved to the target device, so with
        # device="cuda" generate() crashed (inputs on GPU, weights on CPU).
        self.model = self.model.to(self.device)

    def set_seed(self, seed):
        """Seed all RNGs (random, numpy, torch CPU + CUDA) for reproducibility.

        BUG FIX: the original signature ``set_seed(seed)`` was missing
        ``self``, so an instance call ``obj.set_seed(42)`` passed the instance
        itself as the seed and raised TypeError.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable

    def preprocess(self, data):
        """Strip surrounding whitespace and flatten newlines to single spaces."""
        return data.strip().replace('\n', ' ')

    def post_process(self, data):
        """Capitalize each sentence and re-join them with single spaces.

        BUG FIX: the original did ``final += sent + " " + sent``, appending
        every sentence twice and duplicating the whole summary.
        """
        return " ".join(sent.capitalize() for sent in sent_tokenize(data))

    def getSummary(self, data):
        """Return an abstractive summary (75-300 generated tokens) of *data*."""
        data = self.preprocess(data)
        prepared = "summarize: " + data
        encoded = self.tokenizer.encode_plus(
            prepared,
            max_length=512,
            padding=False,  # replaces deprecated pad_to_max_length=False
            truncation=True,
            return_tensors='pt',
        ).to(self.device)
        summary_ids = self.model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            early_stopping=True,
            num_beams=3,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            min_length=75,
            max_length=300,
        )
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return self.post_process(summary).strip()