Week3Day2Task / app.py
ms1449's picture
Update app.py
61872d5 verified
raw
history blame
1.4 kB
import streamlit as st
from transformers import pipeline, AutoTokenizer
# Hugging Face model id used both to build the summarization pipeline
# and to look up the tokenizer's maximum input length at the bottom of the page.
model = "facebook/bart-large-cnn"
@st.cache_resource
def load_summarizer():
    """Construct the summarization pipeline once; Streamlit caches it across reruns."""
    summarization_pipeline = pipeline("summarization", model=model)
    return summarization_pipeline
def get_model_max_length(model_name):
    """Return the maximum input length (in tokens) for *model_name*'s tokenizer.

    Loads the tokenizer and reads ``model_max_length``; falls back to 1024
    (BART's limit) if the tokenizer cannot be loaded.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer.model_max_length
    # Narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit and hide genuine interpreter signals.
    except Exception:
        return 1024  # default value if unable to determine
def generate_summary(text):
    """Summarize *text* with the cached BART pipeline.

    The input is tokenized with truncation to the model's maximum input
    length before summarization, so over-long documents do not error out.
    """
    summarizer = load_summarizer()
    # Cap at 1024 in case model_max_length is a sentinel huge value.
    max_input_length = min(summarizer.tokenizer.model_max_length, 1024)
    token_ids = summarizer.tokenizer.encode(
        text, truncation=True, max_length=max_input_length
    )
    # skip_special_tokens=True keeps <s>/</s> markers out of the decoded
    # string; without it they are re-inserted as literal text and fed back
    # into the summarizer as part of the document.
    truncated_text = summarizer.tokenizer.decode(token_ids, skip_special_tokens=True)
    summary = summarizer(
        truncated_text, max_new_tokens=150, min_new_tokens=40, do_sample=False
    )[0]['summary_text']
    return summary
# --- Streamlit page layout (statements render in order on every rerun) ---
st.title("Text Summarization Tool")
st.write("Using the BART-large-CNN model.")

input_text = st.text_area("Enter the text:", height=200)

if st.button("Generate Summary"):
    # .strip() so whitespace-only input is treated as empty instead of
    # being sent to the model (which the original `if input_text:` allowed).
    if input_text.strip():
        with st.spinner("Generating summary..."):
            summary = generate_summary(input_text)
        st.subheader("Summary:")
        st.write(summary)
    else:
        st.warning("Please enter some text to summarize.")

st.write("---")
st.write("Model: facebook/bart-large-cnn")
st.write(f"Max input length: {get_model_max_length(model)}")