"""Streamlit app: summarize pasted articles with a fine-tuned T5 model.

The model is loaded once per session (``st.cache_resource``) and optimized
for the host: dynamic int8 quantization on CPU, half precision on GPU.
"""

import time

import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# --- PAGE CONFIG ---
st.set_page_config(page_title="AI Summarizer", page_icon="📝", layout="centered")

# --- CUSTOM CSS THEME ---
# NOTE(review): the CSS payload is currently empty — placeholder for a theme.
st.markdown("""
""", unsafe_allow_html=True)


# --- SPEED-OPTIMIZED MODEL LOADING ---
@st.cache_resource
def load_model():
    """Load and optimize the summarization model once per session.

    Returns:
        tuple: ``(tokenizer, model, device)`` where ``device`` is
        ``"cuda"`` when a GPU is available, otherwise ``"cpu"``.
    """
    model_name = "mahirmasud/t5-summarizer"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        # Dynamic quantization of the Linear layers to int8: roughly
        # 2x-3x faster inference on CPU with minimal quality loss.
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    else:
        # fp16 on GPU halves memory traffic; safe for inference-only use.
        model = model.to(device)
        model = model.half()

    model.eval()  # disable dropout/training-mode layers for inference
    return tokenizer, model, device


tokenizer, model, device = load_model()

# --- UI ---
st.title("📝 AI News Summarizer")

# Personalized AI greeting
with st.chat_message("assistant", avatar="🤖"):
    st.write("Hello! I am your **AI Summarizer**. 🚀")
    # Rejoined into a single valid literal (the original string was split
    # across a raw newline, which is a SyntaxError in a plain string).
    st.write(
        "I'm now running on an **optimized engine** to give you faster responses. "
        "Paste any article below!"
    )

# Input area
st.markdown("### 📥 Input Article")
input_text = st.text_area(
    label="Paste your text here:",
    height=250,
    placeholder="Paste the news content here...",
    label_visibility="collapsed",
)

# Execution logic
if st.button("✨ Generate Faster Summary"):
    if not input_text.strip():
        st.warning("I need some text to work with!")
    else:
        start_time = time.time()  # track wall-clock latency for the caption
        with st.status("🚀 Processing with high-speed engine...", expanded=True) as status:
            # inference_mode skips autograd bookkeeping entirely — faster
            # and lighter than no_grad for pure inference.
            with torch.inference_mode():
                text = "summarize: " + input_text  # T5 task prefix
                inputs = tokenizer.encode(
                    text, return_tensors="pt", max_length=512, truncation=True
                ).to(device)
                outputs = model.generate(
                    inputs,
                    max_length=224,
                    num_beams=4,  # beam-search width (quality/speed trade-off)
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    early_stopping=True,
                    use_cache=True,  # reuse past key/values to speed up decoding
                )
                summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

            duration = round(time.time() - start_time, 2)
            status.update(label=f"✅ Done in {duration}s", state="complete", expanded=False)

        # Display results
        st.markdown("---")
        st.subheader("🎯 The Bottom Line")
        st.info(summary)
        st.caption(f"Engine: {device.upper()} | Speed: {duration} seconds")