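# TextPressoMachine: a small Streamlit demo that summarizes pasted text with a
# fine-tuned T5 checkpoint. Run locally with `streamlit run app.py` (assuming
# this file is saved as app.py).
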
import streamlit as st
import torch
from transformers import AutoTokenizer

from t5_model import T5  # local module defining the fine-tuned Lightning T5 wrapper

st.title("TextPressoMachine")

# Checkpoints offered in the UI. Note that only the T5 checkpoint is actually
# loaded below; the GPT2 entry shows in the selector but has no loading path yet.
models = {
    "T5 Small": "ZinebSN/t5_ckpt",
    "GPT2": "ZinebSN/GPT2_Summarier",
}

selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]

# The checkpoint was fine-tuned from t5-small, so the stock t5-small tokenizer applies.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Load the fine-tuned weights from the local PyTorch Lightning checkpoint.
# map_location keeps inference on CPU, so no GPU is required.
checkpoint_path = "./t5_epoch9.ckpt"
model = T5.load_from_checkpoint(checkpoint_path, map_location=torch.device("cpu"))
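
# Note: Streamlit reruns this script top-to-bottom on every widget
# interaction, so the checkpoint above is reloaded each time. A possible
# optimization (a sketch, assuming Streamlit >= 1.18, where st.cache_resource
# exists) is to wrap the load in a cached helper:
#
#     @st.cache_resource
#     def load_model(path):
#         return T5.load_from_checkpoint(path, map_location=torch.device("cpu"))
#
#     model = load_model(checkpoint_path)
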
input_text = st.text_area("Input the text to summarize", "")
if st.button("Summarize"):
    st.text("It may take a minute or two.")
    # T5 expects a task prefix. Tokenize with return_tensors="pt" so the
    # input already has the (batch, seq_len) shape that generate() expects,
    # rather than the bare 1-D id list the original code passed in.
    input_ids = tokenizer(
        "summarize: " + input_text,
        max_length=600,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids
    output_ids = model.generate(input_ids)
    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    st.header("Summary")
    st.markdown(generated_summary)