# Streamlit demo app: summarize user-provided text with a fine-tuned T5 checkpoint.
# Third-party imports grouped together (PEP 8); local package import last.
# All original imports are kept — AutoModelForSeq2SeqLM and pipeline are
# currently unused here but may be relied on by code outside this view.
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM  # unused at present; kept for compatibility
from transformers import AutoTokenizer
from transformers import pipeline  # unused at present; kept for compatibility

from t5_model import T5

st.title("TextPressoMachine")
# UI label -> Hugging Face Hub identifier for the selectable summarizers.
models = {
    "T5 Small": "ZinebSN/t5_ckpt",
    "GPT2": "ZinebSN/GPT2_Summarier",
}
selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]

# NOTE(review): the tokenizer and checkpoint below are hard-coded to T5 —
# selecting "GPT2" in the radio still loads the T5 tokenizer and checkpoint
# (model_name is never used for loading). Confirm whether per-model loading
# was intended before wiring it up.
tokenizer = AutoTokenizer.from_pretrained('t5-small')

# Load the fine-tuned T5 Lightning checkpoint onto CPU, so the app also runs
# on machines without a GPU.
checkpoint_path = './t5_epoch9.ckpt'
model = T5.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
# Text box for the document to summarize; summary is produced on button press.
input_text = st.text_area("Input the text to summarize", "")
if st.button("Summarize"):
    st.text("It may take a minute or two.")
    # T5 is a text-to-text model: prepend the task prefix, then tokenize to a
    # fixed length of 600 (padded/truncated) as the checkpoint was trained.
    text_input_ids = tokenizer(
        'summarize: ' + input_text,
        max_length=600,
        padding="max_length",
        truncation=True,
    ).input_ids
    # NOTE(review): text_input_ids is a flat list, so this tensor is 1-D;
    # Hugging Face generate() normally expects a batched (1, seq_len) tensor.
    # The project T5 wrapper may handle this — confirm, and add a batch
    # dimension (torch.tensor([text_input_ids])) if it does not.
    output_ids = model.generate(torch.tensor(text_input_ids))
    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    st.header("Summary")
    st.markdown(generated_summary)