import streamlit as st
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import PyPDF2
import requests
from io import StringIO, BytesIO
import nltk
from nltk.tokenize import sent_tokenize
import spacy
import numpy as np
from typing import List, Tuple, Optional
import time
import re

# Ensure NLTK sentence-tokenizer data is present (downloaded once).
# NOTE(review): newer NLTK releases also require 'punkt_tab' — confirm
# against the deployed nltk version.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')


# Cache the BART model and tokenizer across Streamlit reruns
@st.cache_resource
def load_bart_model(model_name: str = "facebook/bart-large-cnn"):
    """Load BART model and tokenizer.

    Args:
        model_name: Hugging Face model identifier to load.

    Returns:
        Tuple of (model, tokenizer, device) on success, or
        (None, None, None) if loading fails (error shown via st.error).
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name)
        model = model.to(device)
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None


def clean_text(text: str) -> str:
    """Clean and preprocess text.

    Collapses runs of whitespace and removes special characters while
    keeping word characters and basic punctuation (. , ! ? ; :).
    """
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove special characters but keep basic punctuation
    text = re.sub(r'[^\w\s.,!?;:]', ' ', text)
    return text.strip()


def split_into_chunks(text: str, max_chunk_length: int = 1024) -> List[str]:
    """Split text into sentence-aligned chunks for model processing.

    Args:
        text: Input text to split.
        max_chunk_length: Maximum chunk size in characters.

    Returns:
        List of chunks, each built from whole sentences and kept under
        max_chunk_length (a single sentence longer than the limit still
        becomes its own chunk).
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        # +1 accounts for the joining space so a chunk never exceeds
        # the limit by the separator (off-by-one in the original).
        if len(current_chunk) + len(sentence) + 1 < max_chunk_length:
            current_chunk += " " + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


def extract_text_from_pdf(pdf_file) -> str:
    """Extract and clean text from an uploaded PDF file.

    Returns whatever text could be extracted (possibly empty); read
    errors are reported via st.error instead of raising.
    """
    text = ""
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page in pdf_reader.pages:
            # extract_text() may return None (e.g. image-only pages);
            # guard against "None + str" TypeError.
            text += (page.extract_text() or "") + "\n"
    except Exception as e:
        st.error(f"Error reading PDF: {str(e)}")
    return clean_text(text)
def extract_text_from_url(url: str) -> str:
    """Extract text from Wikipedia or other web pages.

    Args:
        url: The page URL to fetch.

    Returns:
        Cleaned visible text, or "" on any HTTP/network error (the
        error is surfaced to the user via st.error).
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code == 200:
            import html as _html  # stdlib; local import keeps header untouched
            text = response.text
            # Remove <script>/<style> blocks entirely BEFORE stripping
            # tags, otherwise their code/CSS leaks into the "text".
            text = re.sub(r'(?is)<(script|style)[^>]*>.*?</\1>', ' ', text)
            # Simple HTML tag stripping
            text = re.sub(r'<[^>]+>', ' ', text)
            # Decode entities such as &amp; and &#8217;
            text = _html.unescape(text)
            text = re.sub(r'\s+', ' ', text)
            return clean_text(text)
        st.error(f"Failed to fetch URL: Status {response.status_code}")
        return ""
    except Exception as e:
        st.error(f"Error fetching URL: {str(e)}")
        return ""


def summarize_with_bart(
    text: str,
    model,
    tokenizer,
    device: str,
    max_length: int = 150,
    min_length: int = 40,
    do_sample: bool = False,
    num_beams: int = 4,
) -> str:
    """Summarize text using a BART model.

    Long inputs are split into chunks, each chunk is summarized, and the
    combined result is summarized once more if it is still long.

    Args:
        text: Input text to summarize.
        model: Loaded BartForConditionalGeneration instance.
        tokenizer: Matching BartTokenizer.
        device: "cuda" or "cpu" — where the model lives.
        max_length: Maximum summary length in tokens.
        min_length: Minimum summary length in tokens.
        do_sample: Use sampling instead of pure beam search.
        num_beams: Beam width for generation (was hard-coded to 4; now
            a parameter so the UI's beam slider can be honored).

    Returns:
        The summary string, the original text if it is too short to
        summarize, or "" on error (reported via st.error).
    """
    if not text or len(text.split()) < 10:
        return text  # Return original if too short
    try:
        # Split text into chunks if too long
        chunks = split_into_chunks(text, max_chunk_length=1000)
        summaries = []
        for chunk in chunks:
            inputs = tokenizer(
                chunk,
                max_length=1024,
                truncation=True,
                return_tensors="pt"
            ).to(device)
            # Generate summary for this chunk
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=do_sample
            )
            summary = tokenizer.decode(
                summary_ids[0],
                skip_special_tokens=True
            )
            summaries.append(summary)
        # Combine chunk summaries
        combined_summary = " ".join(summaries)
        # If combined summary is still too long, summarize it again
        if len(combined_summary.split()) > 200:
            inputs = tokenizer(
                combined_summary,
                max_length=1024,
                truncation=True,
                return_tensors="pt"
            ).to(device)
            final_summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=num_beams,
                early_stopping=True
            )
            final_summary = tokenizer.decode(
                final_summary_ids[0],
                skip_special_tokens=True
            )
            return final_summary
        return combined_summary
    except Exception as e:
        st.error(f"Error during summarization: {str(e)}")
        return ""
return "" # Streamlit UI def main(): st.set_page_config( page_title="BART Text Summarizer", page_icon="🤖", layout="wide", initial_sidebar_state="expanded" ) # Custom CSS st.markdown(""" """, unsafe_allow_html=True) # Header st.markdown('
Powered by Facebook\'s BART-large-CNN model from Hugging Face
', unsafe_allow_html=True) st.markdown("---") # Sidebar with st.sidebar: st.header("⚙️ Configuration") # Model selection model_option = st.selectbox( "Choose BART model:", [ "facebook/bart-large-cnn", "facebook/bart-large-xsum", "sshleifer/distilbart-cnn-12-6" ], help="BART-large-cnn is best for general summarization" ) # Summary length st.subheader("Summary Settings") max_length = st.slider( "Maximum summary length (words)", min_value=50, max_value=500, value=150, step=10 ) min_length = st.slider( "Minimum summary length (words)", min_value=10, max_value=100, value=40, step=5 ) # Advanced options with st.expander("Advanced Options"): do_sample = st.checkbox( "Use sampling (more creative)", value=False, help="When enabled, uses sampling instead of beam search" ) num_beams = st.slider( "Number of beams", min_value=1, max_value=8, value=4, help="Higher values produce better results but are slower" ) st.markdown("---") # Model info st.info(""" **Model Information:** - BART-large-CNN: Fine-tuned on CNN/Daily Mail - Parameters: 400 million - Best for: Article summarization """) # Load model button if st.button("🔄 Load Model", type="secondary"): with st.spinner("Loading BART model..."): model, tokenizer, device = load_bart_model(model_option) if model: st.success(f"Model loaded successfully on {device}!") # Main content col1, col2 = st.columns([1, 1]) with col1: st.subheader("📥 Input Text") # Input method selection input_method = st.radio( "Choose input method:", ["📝 Direct Text", "📄 Upload File", "🌐 Website URL"], horizontal=True ) text_input = "" if input_method == "📝 Direct Text": text_input = st.text_area( "Enter your text here:", height=300, placeholder="Paste or type your text here...", help="Minimum 100 words for best results" ) elif input_method == "📄 Upload File": uploaded_file = st.file_uploader( "Upload a file", type=['txt', 'pdf', 'docx'], help="Supports TXT, PDF, and DOCX files" ) if uploaded_file: file_ext = uploaded_file.name.split('.')[-1].lower() if 
file_ext == 'pdf': with st.spinner("Extracting text from PDF..."): text_input = extract_text_from_pdf(uploaded_file) elif file_ext == 'txt': stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) text_input = stringio.read() else: st.warning("Please upload a PDF or TXT file") elif input_method == "🌐 Website URL": url = st.text_input( "Enter URL:", placeholder="https://en.wikipedia.org/wiki/...", help="Supports Wikipedia and other websites" ) if url and st.button("Fetch Content", type="secondary"): with st.spinner("Fetching content from URL..."): text_input = extract_text_from_url(url) # Display text stats if text_input: words = text_input.split() sentences = sent_tokenize(text_input) with st.expander("📊 Text Statistics", expanded=True): col_a, col_b, col_c = st.columns(3) with col_a: st.metric("Words", len(words)) with col_b: st.metric("Sentences", len(sentences)) with col_c: st.metric("Characters", len(text_input)) # Preview with st.expander("🔍 Preview Original Text"): preview_length = min(500, len(text_input)) st.text(text_input[:preview_length] + "..." if len(text_input) > preview_length else text_input) with col2: st.subheader("📤 Generated Summary") if text_input and len(text_input.split()) >= 10: if st.button("🚀 Generate Summary", type="primary", use_container_width=True): with st.spinner("Generating summary with BART..."): # Load model if not already loaded model, tokenizer, device = load_bart_model(model_option) if model and tokenizer: # Generate summary start_time = time.time() summary = summarize_with_bart( text_input, model, tokenizer, device, max_length=max_length, min_length=min_length, do_sample=do_sample ) processing_time = time.time() - start_time if summary: # Display summary st.markdown('