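"""Streamlit demo that summarizes text with either a fine-tuned GPT-2 or a
fine-tuned T5-small model from the Hugging Face Hub.

Run with: streamlit run app.py  (assuming this file is saved as app.py)
"""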
import streamlit as st
import torch
import torch.nn.functional as F

# Import GPT2 Model and Tokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Import T5 Model and Tokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

st.title("Text Presso Machine ☕️")

# Fine-tuned summarization checkpoints hosted on the Hugging Face Hub
models = {
    "T5 Small": "ZinebSN/T5_summarizer",
    "GPT2": "ZinebSN/GPT2_summarizer",
}

selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]

# Cache the tokenizer and model so they are not re-downloaded and re-loaded
# on every Streamlit rerun
@st.cache_resource
def load_model(name, use_gpt2):
    if use_gpt2:
        return GPT2Tokenizer.from_pretrained(name), GPT2LMHeadModel.from_pretrained(name)
    return AutoTokenizer.from_pretrained(name), AutoModelForSeq2SeqLM.from_pretrained(name)

tokenizer, model = load_model(model_name, selected_model == "GPT2")
    
# Inference function for GPT2: sample `length` tokens autoregressively
def gpt2_summarize(input_text, tokenizer, model, length):
    # Encode the prompt with the special tokens used during fine-tuning
    input_ids = tokenizer.encode_plus(
        f"<bos> {input_text} <sep>", truncation=True, max_length=1024
    ).input_ids
    generated = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
    summary_ids = []
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=generated)
            # Sample the next token from the softmax over the last position
            next_token_logits = outputs[0][0, -1, :]
            next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
            summary_ids.append(next_token.item())
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            # Keep the running context within GPT-2's 1024-token window
            generated = generated[:, -1024:]
    # Decode only the newly generated tokens; tracking them in a separate list
    # avoids losing the prompt offset when the context window is truncated above
    tokens = tokenizer.convert_ids_to_tokens(summary_ids, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(tokens)
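# Note: the manual loop above makes the sampling explicit; Hugging Face's
# built-in generation utility should produce equivalent results and is
# typically faster. A minimal sketch (assuming the same prompt encoding):
#
#   output = model.generate(generated, do_sample=True, max_new_tokens=length,
#                           pad_token_id=tokenizer.eos_token_id)
#   summary = tokenizer.decode(output[0, generated.shape[1]:],
#                              skip_special_tokens=True)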

# Inference function for T5: prepend the task prefix and use generate()
def t5_summarize(input_text, tokenizer, model):
    inputs = tokenizer(
        "summarize: " + input_text,
        truncation=True, padding="max_length", max_length=600, return_tensors="pt",
    )
    output_sequence = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=100,
    )
    summary = tokenizer.batch_decode(output_sequence, skip_special_tokens=True)
    return summary[0]
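# Optional variant: generate() defaults to greedy decoding here; beam search
# usually yields more fluent summaries at the cost of extra compute, e.g.
#   model.generate(..., num_beams=4, early_stopping=True, max_new_tokens=100)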
    

input_text = st.text_area("Input the text to summarize", "", height=300)
if st.button("Summarize"):
    with st.spinner("Summarizing... this may take a minute or two."):
        if selected_model == "GPT2":
            summary = gpt2_summarize(input_text, tokenizer, model, 30)
        else:
            summary = t5_summarize(input_text, tokenizer, model)

    st.header("Summary")
    st.markdown(summary)