import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import time
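
# NOTE: T5Tokenizer is the SentencePiece-based "slow" tokenizer, so the
# sentencepiece package must be installed alongside streamlit, transformers,
# and torch (typically listed in the Space's requirements.txt).
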
# --- PAGE CONFIG ---
st.set_page_config(page_title="AI Summarizer", page_icon="📝", layout="centered")
# --- CUSTOM CSS THEME ---
st.markdown("""
<style>
.stApp {
background-color: #0E1117;
color: #FFFFFF;
}
.stTextArea textarea {
background-color: #1B1D21 !important;
color: #58a6ff !important;
border: 1px solid #005fb8 !important;
}
.stButton>button {
background-color: #005fb8;
color: white;
border-radius: 5px;
border: none;
width: 100%;
font-weight: bold;
transition: 0.3s;
}
.stButton>button:hover {
background-color: #58a6ff;
color: black;
box-shadow: 0px 0px 15px #58a6ff;
}
h1, h2, h3 {
color: #58a6ff !important;
}
</style>
""", unsafe_allow_html=True)
# --- SPEED-OPTIMIZED MODEL LOADING ---
@st.cache_resource
def load_model():
    model_name = "mahirmasud/t5-summarizer"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    # --- OPTIMIZATION START ---
    # 1. Move to GPU if available, else optimize for CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        # 2. Dynamic quantization: int8 Linear layers are ~2-3x faster on CPU
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    else:
        model = model.to(device)
        # 3. Use half precision (fp16) on GPU
        model = model.half()

    model.eval()
    # --- OPTIMIZATION END ---
    return tokenizer, model, device
tokenizer, model, device = load_model()
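
# Note: @st.cache_resource runs load_model() once per server process and
# shares the returned objects across sessions and reruns, so the download and
# quantization cost is paid only on the first request.
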
# --- UI ---
st.title("📝 AI News Summarizer")
# Personalized AI Greeting
with st.chat_message("assistant", avatar="🤖"):
    st.write("Hello! I am your **AI Summarizer**. 🚀")
    st.write("I'm now running on an **optimized engine** to give you faster responses. Paste any article below!")
# Input Area
st.markdown("### πŸ“₯ Input Article")
input_text = st.text_area(
label="Paste your text here:",
height=250,
placeholder="Paste the news content here...",
label_visibility="collapsed"
)
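
# Note: the tokenizer call below truncates input to 512 tokens, so only
# roughly the first 512 tokens of a very long article reach the model.
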
# Execution Logic
if st.button("✨ Generate Faster Summary"):
    if not input_text.strip():
        st.warning("I need some text to work with!")
    else:
        start_time = time.time()  # Track speed

        with st.status("🚀 Processing with high-speed engine...", expanded=True) as status:
            # Use inference_mode for maximum speed
            with torch.inference_mode():
                text = "summarize: " + input_text
                inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)

                # Optimized generation parameters
                outputs = model.generate(
                    inputs,
                    max_length=224,
                    num_beams=2,  # Reduced from 4 to 2 for speed (minimal quality loss)
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    early_stopping=True,
                    use_cache=True  # Reuses past hidden states to speed up generation
                )
                summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

            end_time = time.time()
            duration = round(end_time - start_time, 2)
            status.update(label=f"✅ Done in {duration}s", state="complete", expanded=False)

        # Display Results
        st.markdown("---")
        st.subheader("🎯 The Bottom Line")
        st.info(summary)
        st.caption(f"Engine: {device.upper()} | Speed: {duration} seconds")
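
# To run locally (assuming Streamlit and the other dependencies are
# installed): streamlit run app.py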