| import streamlit as st
|
| import torch
|
| from transformers import BartForConditionalGeneration, BartTokenizer
|
| import PyPDF2
|
| import requests
|
| from io import StringIO, BytesIO
|
| import nltk
|
| from nltk.tokenize import sent_tokenize
|
| import spacy
|
| import numpy as np
|
| from typing import List, Tuple, Optional
|
| import time
|
| import re
|
|
|
|
|
# sent_tokenize() needs NLTK's Punkt sentence model. Newer NLTK releases load it from the
# "punkt_tab" resource while older ones use "punkt"; make sure both are available so
# tokenization does not raise LookupError at runtime.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)
|
|
|
|
|
@st.cache_resource
def _load_bart_model_cached(model_name: str):
    """Load and memoize (model, tokenizer, device) for model_name.

    Exceptions propagate on purpose: st.cache_resource does not store a result when
    the function raises, so a failed download can be retried on the next call instead
    of being cached for the lifetime of the server.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()  # inference only; make eval mode explicit
    return model, tokenizer, device


def load_bart_model(model_name: str = "facebook/bart-large-cnn"):
    """Load a BART summarization model and its tokenizer.

    Successful loads are cached per model_name by _load_bart_model_cached.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        (model, tokenizer, device) on success, where device is "cuda" or "cpu";
        (None, None, None) on failure, after reporting the error with st.error.
    """
    try:
        model, tokenizer, device = _load_bart_model_cached(model_name)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None

    st.info(f"Using device: {device}")
    return model, tokenizer, device
|
|
|
|
|
def clean_text(text: str, allowed_punct: str = ".,!?;:'\"’‘“”()%$-–—") -> str:
    """Normalize text before summarization.

    Every character that is not a word character, whitespace, or one of
    ``allowed_punct`` is replaced by a space, then all whitespace runs are collapsed
    to single spaces and the result is stripped.

    Args:
        text: Raw input text.
        allowed_punct: Punctuation to keep. The default keeps sentence punctuation plus
            apostrophes, quotes, hyphens/dashes, parentheses, "%" and "$" so that
            contractions ("it's"), compounds ("state-of-the-art") and figures
            ("50%", "$10") survive intact.

    Returns:
        The cleaned, single-spaced text ("" for whitespace-only input).
    """
    # Replace disallowed characters first, THEN collapse whitespace; doing it in the
    # other order leaves runs of spaces wherever a symbol was removed.
    text = re.sub(rf"[^\w\s{re.escape(allowed_punct)}]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()
|
|
|
def _split_long_sentence(sentence: str, max_chunk_length: int) -> List[str]:
    """Break a sentence longer than max_chunk_length characters at word boundaries.

    Returns [sentence] unchanged when it already fits. A single word longer than the
    limit is kept whole rather than cut mid-word.
    """
    if len(sentence) <= max_chunk_length:
        return [sentence]

    pieces: List[str] = []
    current = ""
    for word in sentence.split():
        if current and len(current) + 1 + len(word) > max_chunk_length:
            pieces.append(current)
            current = word
        else:
            current = f"{current} {word}" if current else word
    if current:
        pieces.append(current)
    return pieces


def split_into_chunks(text: str, max_chunk_length: int = 1024) -> List[str]:
    """Split text into sentence-aligned chunks for the summarizer.

    Sentences (from nltk's sent_tokenize) are packed greedily, space-joined, into chunks
    of at most max_chunk_length characters. A sentence that is longer than the limit by
    itself is first split at word boundaries. Without this, text lacking sentence
    punctuation (scraped pages, some PDFs) would become one huge chunk that the tokenizer
    then silently truncates.

    Args:
        text: Input text.
        max_chunk_length: Maximum chunk size in characters (not tokens).

    Returns:
        List of chunks; empty for empty input.
    """
    chunks: List[str] = []
    current = ""
    for sentence in sent_tokenize(text):
        for piece in _split_long_sentence(sentence, max_chunk_length):
            # "+ 1" accounts for the space that joins piece onto the current chunk.
            if current and len(current) + 1 + len(piece) > max_chunk_length:
                chunks.append(current)
                current = piece
            else:
                current = f"{current} {piece}" if current else piece
    if current:
        chunks.append(current)
    return chunks
|
|
|
|
|
def extract_text_from_pdf(pdf_file) -> str:
    """Extract and clean the text layer of every page of a PDF.

    Args:
        pdf_file: Anything PyPDF2.PdfReader accepts (path or binary file-like object,
            e.g. a Streamlit UploadedFile).

    Returns:
        The pages' text joined with newlines and passed through clean_text. If reading
        fails part-way, the error is shown with st.error and the text gathered so far
        is returned ("" if nothing was read).
    """
    pages: List[str] = []
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page in pdf_reader.pages:
            # Guard against a missing text layer (e.g. scanned pages): treat a None
            # result as empty instead of letting "None + str" abort the whole document.
            pages.append(page.extract_text() or "")
    except Exception as e:
        st.error(f"Error reading PDF: {str(e)}")
    # Collect then join once rather than growing a string with += per page.
    return clean_text("\n".join(pages))
|
|
|
def _html_to_text(raw_html: str) -> str:
    """Reduce an HTML document to its (approximate) visible text, whitespace-collapsed."""
    import html

    # Remove comments first so a commented-out "<script>" cannot confuse the next step.
    text = re.sub(r"(?s)<!--.*?-->", " ", raw_html)
    # Drop elements whose bodies are never displayed text (JavaScript, CSS, templates);
    # stripping only the tags would feed their code into the summarizer.
    text = re.sub(r"(?is)<(script|style|noscript|template)\b.*?</\1\s*>", " ", text)
    text = re.sub(r"<[^>]+>", " ", text)
    # Decode entities (&amp;, &#8217;, &nbsp;, ...) only after tags are gone, so an
    # escaped "&lt;tag&gt;" in the page stays literal text.
    text = html.unescape(text)
    return re.sub(r"\s+", " ", text)


def extract_text_from_url(url: str) -> str:
    """Fetch a web page (e.g. a Wikipedia article) and return its cleaned text.

    Args:
        url: Page URL to GET (10 s timeout, desktop-browser User-Agent).

    Returns:
        Cleaned page text, or "" when the response status is not 200 or the request
        fails; in both cases the problem is reported with st.error.
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers, timeout=10)

        if response.status_code == 200:
            return clean_text(_html_to_text(response.text))

        st.error(f"Failed to fetch URL: Status {response.status_code}")
        return ""
    except Exception as e:
        st.error(f"Error fetching URL: {str(e)}")
        return ""
|
|
|
|
|
def summarize_with_bart(
    text: str,
    model,
    tokenizer,
    device: str,
    max_length: int = 150,
    min_length: int = 40,
    do_sample: bool = False,
    num_beams: int = 4,
    length_penalty: float = 2.0,
    reduce_threshold_words: int = 200,
) -> str:
    """Summarize text with a BART sequence-to-sequence model.

    The text is split into sentence-aligned chunks (split_into_chunks, 1000 chars),
    each chunk is summarized, and the partial summaries are joined with spaces. If
    the joined result has more than reduce_threshold_words words, it is summarized
    once more with the same generation settings.

    Args:
        text: Text to summarize.
        model: Loaded BartForConditionalGeneration (or compatible) model.
        tokenizer: Matching tokenizer.
        device: Device the model lives on ("cuda"/"cpu"); inputs are moved there.
        max_length: Maximum summary length in tokens, passed to generate().
        min_length: Minimum summary length in tokens; clamped to max_length and to
            the tokenized input length of each chunk.
        do_sample: Sample instead of pure beam search (applies to every pass).
        num_beams: Beam count for generate().
        length_penalty: Length penalty for beam search.
        reduce_threshold_words: Word count above which the joined chunk summaries
            are summarized again.

    Returns:
        The summary; text itself if it is empty or shorter than 10 words; "" if an
        error occurred (reported with st.error).
    """
    if not text or len(text.split()) < 10:
        return text

    # A min_length above max_length would suppress EOS for the entire output
    # (the sidebar sliders allow min up to 100 with max down to 50).
    min_length = min(min_length, max_length)

    def _generate(passage: str) -> str:
        # One generate() call for one passage (truncated to BART's 1024-token window).
        inputs = tokenizer(
            passage,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        input_len = inputs["input_ids"].shape[-1]
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            # Forcing a short chunk to emit more tokens than it contains only
            # invites padding/hallucination, so cap min_length by the input length.
            min_length=min(min_length, input_len),
            length_penalty=length_penalty,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=do_sample,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    try:
        chunks = split_into_chunks(text, max_chunk_length=1000)
        combined_summary = " ".join(_generate(chunk) for chunk in chunks)

        # Many chunks yield a long concatenation; condense it with one more pass
        # using the same settings (including do_sample) as the chunk pass.
        if len(combined_summary.split()) > reduce_threshold_words:
            return _generate(combined_summary)

        return combined_summary
    except Exception as e:
        st.error(f"Error during summarization: {str(e)}")
        return ""
|
|
|
|
|
# Input-method labels shared by the radio widget and the branches that read it.
_DIRECT_TEXT = "📝 Direct Text"
_UPLOAD_FILE = "📁 Upload File"
_WEBSITE_URL = "🌐 Website URL"

# Sample paragraph offered by the "Try with Example Text" expander.
_EXAMPLE_TEXT = (
    "Artificial Intelligence (AI) is transforming industries across the globe. "
    "From healthcare to finance, AI algorithms are being deployed to solve complex problems, "
    "automate processes, and generate insights from massive datasets. Machine learning, "
    "a subset of AI, enables computers to learn from data without being explicitly programmed. "
    "Deep learning, powered by neural networks, has achieved remarkable success in areas like "
    "image recognition, natural language processing, and autonomous vehicles. "
    "However, AI also raises important ethical considerations around bias, privacy, "
    "and job displacement. As AI continues to evolve, it's crucial to develop responsible "
    "AI frameworks that ensure these technologies benefit society while mitigating potential risks. "
    "The future of AI holds tremendous promise, but requires careful stewardship and collaboration "
    "between technologists, policymakers, and the public."
)

# Page-level CSS for the header, sub-header, stats card and summary box classes.
_CUSTOM_CSS = """
<style>
.main-header {
    font-size: 2.5rem;
    color: #4A90E2;
    text-align: center;
    margin-bottom: 1rem;
}
.sub-header {
    font-size: 1.2rem;
    color: #666;
    text-align: center;
    margin-bottom: 2rem;
}
.stats-card {
    background-color: #f0f2f6;
    padding: 1rem;
    border-radius: 10px;
    margin: 0.5rem 0;
}
.summary-box {
    border: 2px solid #4A90E2;
    border-radius: 10px;
    padding: 1rem;
    background-color: #f8f9fa;
}
</style>
"""


def _render_sidebar():
    """Draw the sidebar; return (model_option, max_length, min_length, do_sample, num_beams)."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        model_option = st.selectbox(
            "Choose BART model:",
            [
                "facebook/bart-large-cnn",
                "facebook/bart-large-xsum",
                "sshleifer/distilbart-cnn-12-6",
            ],
            help="BART-large-cnn is best for general summarization",
        )

        st.subheader("Summary Settings")
        # These values are passed to generate() as max_length/min_length, which count
        # tokens, not words, so the labels say "tokens".
        max_length = st.slider(
            "Maximum summary length (tokens)",
            min_value=50,
            max_value=500,
            value=150,
            step=10,
        )
        min_length = st.slider(
            "Minimum summary length (tokens)",
            min_value=10,
            max_value=100,
            value=40,
            step=5,
        )
        # The slider ranges overlap (max can be 50 while min goes to 100): keep min <= max.
        if min_length > max_length:
            st.warning("Minimum length is above the maximum; using the maximum for both.")
            min_length = max_length

        with st.expander("Advanced Options"):
            do_sample = st.checkbox(
                "Use sampling (more creative)",
                value=False,
                help="When enabled, uses sampling instead of beam search",
            )
            num_beams = st.slider(
                "Number of beams",
                min_value=1,
                max_value=8,
                value=4,
                help="Higher values produce better results but are slower",
            )

        st.markdown("---")
        st.info(
            "**Model Information:**\n"
            "- BART-large-CNN: Fine-tuned on CNN/Daily Mail\n"
            "- Parameters: 400 million\n"
            "- Best for: Article summarization"
        )

        if st.button("🔄 Load Model", type="secondary"):
            with st.spinner("Loading BART model..."):
                model, _tokenizer, device = load_bart_model(model_option)
            if model is not None:
                st.success(f"Model loaded successfully on {device}!")

    return model_option, max_length, min_length, do_sample, num_beams


def _get_input_text() -> str:
    """Render the input-method picker and return the text to summarize ("" if none yet)."""
    # A pending "Load Example Text" request must be applied before the widgets it targets
    # are created: Streamlit forbids writing a widget's state after it was rendered.
    pending_example = st.session_state.pop("example_loaded", None)
    if pending_example is not None:
        st.session_state["input_method"] = _DIRECT_TEXT
        st.session_state["direct_text"] = pending_example

    input_method = st.radio(
        "Choose input method:",
        [_DIRECT_TEXT, _UPLOAD_FILE, _WEBSITE_URL],
        horizontal=True,
        key="input_method",
    )

    if input_method == _DIRECT_TEXT:
        return st.text_area(
            "Enter your text here:",
            height=300,
            placeholder="Paste or type your text here...",
            help="Minimum 100 words for best results",
            key="direct_text",
        )

    if input_method == _UPLOAD_FILE:
        # Only formats that are actually handled below are offered (no DOCX parser here).
        uploaded_file = st.file_uploader(
            "Upload a file",
            type=["txt", "pdf"],
            help="Supports TXT and PDF files",
        )
        if not uploaded_file:
            return ""
        file_ext = uploaded_file.name.split('.')[-1].lower()
        if file_ext == 'pdf':
            with st.spinner("Extracting text from PDF..."):
                return extract_text_from_pdf(uploaded_file)
        if file_ext == 'txt':
            # utf-8-sig drops a BOM; undecodable bytes become U+FFFD instead of crashing.
            return uploaded_file.getvalue().decode("utf-8-sig", errors="replace")
        st.warning("Please upload a PDF or TXT file")
        return ""

    # Website URL
    url = st.text_input(
        "Enter URL:",
        placeholder="https://en.wikipedia.org/wiki/...",
        help="Supports Wikipedia and other websites",
    )
    if url and st.button("Fetch Content", type="secondary"):
        with st.spinner("Fetching content from URL..."):
            st.session_state["fetched_url"] = url
            st.session_state["fetched_text"] = extract_text_from_url(url)
    # st.button is True only during the click's own rerun; keep the fetched text in
    # session state so pressing "Generate Summary" (a later rerun) still sees it.
    if url and st.session_state.get("fetched_url") == url:
        return st.session_state.get("fetched_text", "")
    return ""


def _render_text_overview(text_input: str) -> None:
    """Show word/sentence/character counts and a 500-character preview of the input."""
    with st.expander("📊 Text Statistics", expanded=True):
        col_a, col_b, col_c = st.columns(3)
        with col_a:
            st.metric("Words", len(text_input.split()))
        with col_b:
            st.metric("Sentences", len(sent_tokenize(text_input)))
        with col_c:
            st.metric("Characters", len(text_input))

    with st.expander("📄 Preview Original Text"):
        st.text(text_input[:500] + "..." if len(text_input) > 500 else text_input)


def _render_summary_results(text_input: str, summary: str, processing_time: float) -> None:
    """Display the summary, its statistics, a download button and a side-by-side comparison."""
    import html

    # Each st.* call is its own element, so a <div> opened in one st.markdown call and
    # closed in another never wraps the text: build the box in a single call. The summary
    # derives from untrusted PDF/web text and is rendered as HTML, so escape it; "$" is
    # escaped too since Streamlit markdown may treat $...$ as LaTeX.
    safe_summary = html.escape(summary).replace("$", "&#36;")
    st.markdown(f'<div class="summary-box">{safe_summary}</div>', unsafe_allow_html=True)

    input_words = len(text_input.split())
    summary_words = len(summary.split())

    st.markdown("### 📈 Summary Statistics")
    col1_stat, col2_stat, col3_stat, col4_stat = st.columns(4)
    with col1_stat:
        # Numeric delta: Streamlit formats the sign itself, also when the summary is longer.
        st.metric("Summary Words", summary_words, delta=summary_words - input_words)
    with col2_stat:
        # input_words >= 10 here (the caller rejects shorter input).
        reduction = (input_words - summary_words) / input_words * 100
        st.metric("Reduction", f"{reduction:.1f}%")
    with col3_stat:
        st.metric("Processing Time", f"{processing_time:.2f}s")
    with col4_stat:
        # True division: integer division showed "1:0" whenever the summary was longer.
        ratio = input_words / summary_words if summary_words else 0
        st.metric("Compression Ratio", f"1:{ratio:.1f}")

    st.download_button(
        label="📥 Download Summary",
        data=summary,
        file_name="bart_summary.txt",
        mime="text/plain",
        use_container_width=True,
    )

    with st.expander("🔍 Compare Original vs Summary"):
        col_orig, col_sum = st.columns(2)
        with col_orig:
            st.write("**Original (first 200 words):**")
            st.write(" ".join(text_input.split()[:200]) + "...")
        with col_sum:
            st.write("**Summary:**")
            st.write(summary)


def _render_summary(text_input: str, model_option: str, max_length: int,
                    min_length: int, do_sample: bool) -> None:
    """Render the summary column: validate input, run BART on demand, show results."""
    st.subheader("📤 Generated Summary")

    if not text_input:
        st.info("👈 Enter text on the left to generate a summary")
        return
    if len(text_input.split()) < 10:
        st.warning("Please enter at least 10 words for summarization")
        return
    if not st.button("🚀 Generate Summary", type="primary", use_container_width=True):
        return

    with st.spinner("Generating summary with BART..."):
        model, tokenizer, device = load_bart_model(model_option)
        if model is None or tokenizer is None:
            st.error("Failed to load model. Please check your internet connection.")
            return

        start_time = time.time()
        summary = summarize_with_bart(
            text_input,
            model,
            tokenizer,
            device,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
        )
        processing_time = time.time() - start_time

    if not summary:
        st.error("Failed to generate summary. Please try again.")
        return

    _render_summary_results(text_input, summary, processing_time)


def _render_example_loader() -> None:
    """Offer a button that loads _EXAMPLE_TEXT into the direct-text input on the next run."""
    with st.expander("📚 Try with Example Text"):
        if st.button("Load Example Text"):
            # Consumed by _get_input_text() at the top of the rerun.
            st.session_state.example_loaded = _EXAMPLE_TEXT
            st.rerun()


def _render_footer() -> None:
    """Draw the credits row and the copyright caption."""
    st.markdown("---")
    col_footer1, col_footer2, col_footer3 = st.columns(3)
    with col_footer1:
        st.markdown("**Powered by:**")
        # A link needs visible text; "[](...)" renders nothing clickable.
        st.markdown("[Hugging Face](https://huggingface.co)")
    with col_footer2:
        st.markdown("**Model:**")
        st.markdown("[BART-large-CNN](https://huggingface.co/facebook/bart-large-cnn)")
    with col_footer3:
        st.markdown("**Built with:**")
        st.markdown("[Streamlit](https://streamlit.io) | [PyTorch](https://pytorch.org)")

    st.caption("© 2024 BART Summarizer | Deploy your own on [Hugging Face Spaces](https://huggingface.co/spaces)")


def main():
    """Streamlit entry point: build the page and connect the text input to the BART summary."""
    st.set_page_config(
        page_title="BART Text Summarizer",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
    st.markdown('<h1 class="main-header">🤖 BART Text Summarizer</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Powered by Facebook\'s BART-large-CNN model from Hugging Face</p>', unsafe_allow_html=True)
    st.markdown("---")

    model_option, max_length, min_length, do_sample, num_beams = _render_sidebar()
    # TODO(review): num_beams is collected from the sidebar but not forwarded to
    # summarize_with_bart; pass it through if that function's signature accepts it.

    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("📥 Input Text")
        text_input = _get_input_text()
        if text_input:
            _render_text_overview(text_input)

    with col2:
        _render_summary(text_input, model_option, max_length, min_length, do_sample)

    _render_example_loader()
    _render_footer()
|
|
|
if __name__ == "__main__":
    # Module executed as a script: build the whole app through main().
    main()