hackathon_demo / app.py
kukr3207's picture
Upload 2 files
ddabd78
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer
# Models and tokenizers are large, non-data resources: cache the single shared
# instance with ``cache_resource`` rather than ``cache_data``, which is meant
# for serializable data and hands back a fresh copy on every cache hit.
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the T5 tokenizer and model once and reuse them across reruns.

    Streamlit's built-in cache spinner is disabled in favour of the custom
    spinner message shown below while the weights are loaded.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): the canonical Hub id appears to be lowercase ``t5-base``;
    # the original identifier is kept unchanged here - confirm before renaming.
    with st.spinner('Loading the model, please wait...'):
        tokenizer = AutoTokenizer.from_pretrained('T5-base')
        # ``AutoModelWithLMHead`` is deprecated; for encoder-decoder
        # checkpoints such as T5 the replacement is ``AutoModelForSeq2SeqLM``.
        model = AutoModelForSeq2SeqLM.from_pretrained('T5-base', return_dict=True)
    st.success("Done")
    return tokenizer, model
def summarize(input, tokenizer, model):
    """Generate a summary of ``input`` using a T5-style tokenizer and model.

    Args:
        input: Raw text to summarize. The name shadows the builtin, but it is
            kept so that existing keyword callers continue to work.
        tokenizer: Object providing ``encode`` and ``decode``.
        model: Object providing ``generate``.

    Returns:
        The decoded summary text with special tokens removed.
    """
    # T5 picks its task from a plain-text prefix; the prefix must be spelled
    # exactly "summarize: " for the model to recognise the summarization task.
    token_ids = tokenizer.encode(
        "summarize: " + input,
        return_tensors='pt',
        max_length=1024,
        truncation=True,
    )
    generated = model.generate(token_ids, min_length=80, max_length=1024)
    # ``generate`` returns a batch; only one sequence was submitted. Special
    # tokens are skipped so markers like <pad> / </s> do not leak into the
    # text shown to the user.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def process_input_data(input_data, tokenizer, model):
    """Summarize ``input_data`` and wrap the result in a display message.

    Args:
        input_data: Text to be summarized.
        tokenizer: Tokenizer forwarded to ``summarize``.
        model: Model forwarded to ``summarize``.

    Returns:
        A string consisting of a "Summarized text" heading followed by the
        summary produced by ``summarize``.
    """
    summary = summarize(input_data, tokenizer, model)
    return f"Summarized text :\n\n {summary}"
def main():
    """Render the summarizer page and handle a Submit click.

    Loads the tokenizer and model through ``load_model``, shows a text area
    for input, and on Submit displays the summary returned by
    ``process_input_data``.
    """
    tokenizer, model = load_model()
    st.title("Text Summarizer")
    input_data = st.text_area("Enter input data", height=200)
    if st.button("Submit"):
        # Without this guard an empty submission would still run generation
        # on the bare task prefix and display meaningless text as a "summary".
        if not input_data or not input_data.strip():
            st.warning("Please enter some text to summarize.")
            return
        with st.spinner("Processing the input data"):
            output_text = process_input_data(input_data, tokenizer, model)
        st.success(output_text)
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()