import locale
import random

import nltk
import numpy as np
import torch
from nltk.tokenize import sent_tokenize
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Sentence-tokenizer data used by postprocesstext() below.
nltk.download('punkt')

# Workaround for environments (e.g. Colab) whose preferred encoding is not
# UTF-8, which can break file I/O in dependent libraries.
locale.getpreferredencoding = lambda: "UTF-8"

class Summarizer:
    """Abstractive summarization with a pretrained t5-base checkpoint."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = T5ForConditionalGeneration.from_pretrained('t5-base').to(self.device)
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
        self.set_seed(42)

    def set_seed(self, seed: int):
        # Seed every RNG in play so results are reproducible across runs.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def postprocesstext(self, content):
        # Re-capitalize each sentence of the generated summary. Note that
        # str.capitalize() also lowercases the rest of the sentence.
        sentences = [sent.capitalize() for sent in sent_tokenize(content)]
        return " ".join(sentences)

    def summarizer(self, text, model=None, tokenizer=None):
        if model is None:
            model = self.model
        if tokenizer is None:
            tokenizer = self.tokenizer

        # T5 expects a task prefix; collapse newlines so the input is a
        # single block of text.
        text = "summarize: " + text.strip().replace("\n", " ")

        # Tokenize, truncating to the model's 512-token input limit.
        encoding = tokenizer(
            text,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]

        # Beam search with an n-gram repetition penalty; the summary is
        # constrained to roughly 75-300 tokens.
        outs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            early_stopping=True,
            num_beams=3,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            min_length=75,
            max_length=300,
        )

        summary = tokenizer.decode(outs[0], skip_special_tokens=True)
        return self.postprocesstext(summary).strip()
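

# A minimal usage sketch. It assumes network access to fetch the t5-base
# weights on first run; the sample passage below is illustrative only.
if __name__ == "__main__":
    summarizer = Summarizer()
    passage = (
        "The transformer architecture relies on self-attention to relate "
        "tokens across a sequence. Because attention scores every token "
        "pair, transformers parallelize well on modern accelerators, which "
        "helped them displace recurrent networks for sequence-to-sequence "
        "tasks such as summarization."
    )
    print(summarizer.summarizer(passage))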