Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import BartForConditionalGeneration, BartTokenizer | |
| import PyPDF2 | |
| import requests | |
| from io import StringIO, BytesIO | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
| import spacy | |
| import numpy as np | |
| from typing import List, Tuple, Optional | |
| import time | |
| import re | |
# Make sure the Punkt sentence-tokenizer data used by sent_tokenize is present.
# Older NLTK releases load the pickled "punkt" resource, newer ones load
# "punkt_tab"; checking both keeps sent_tokenize working on either.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        # Only hit the network when the resource is not already installed.
        nltk.download(_punkt_resource)
# Cache the BART model and tokenizer
@st.cache_resource(show_spinner=False)
def _load_bart_cached(model_name: str, device: str):
    """Instantiate the tokenizer and model once per (model_name, device).

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of rebuilding them on every button click. Exceptions
    propagate to the caller, so a failed load is not stored in the cache.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return model.to(device), tokenizer


def load_bart_model(model_name: str = "facebook/bart-large-cnn"):
    """Load a BART model and tokenizer, reusing a cached copy when possible.

    Args:
        model_name: Hugging Face model identifier to load.

    Returns:
        ``(model, tokenizer, device)`` on success, where ``device`` is
        ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        ``(None, None, None)`` when loading fails; the error is reported
        through ``st.error`` rather than raised.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model, tokenizer = _load_bart_cached(model_name, device)
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
| # Text preprocessing functions | |
_DISALLOWED_CHARS_RE = re.compile(r'[^\w\s.,!?;:]')
_WHITESPACE_RUN_RE = re.compile(r'\s+')


def clean_text(text: str) -> str:
    """Clean and preprocess text.

    Every character that is not a word character, whitespace, or one of
    ``. , ! ? ; :`` is replaced by a space, then each run of whitespace is
    collapsed to a single space and the ends are stripped.

    Args:
        text: Raw text to normalise.

    Returns:
        The cleaned text, with no consecutive, leading or trailing whitespace.
    """
    # Replace special characters first: each replacement inserts a space, so
    # collapsing whitespace has to happen afterwards or runs of spaces survive.
    text = _DISALLOWED_CHARS_RE.sub(' ', text)
    text = _WHITESPACE_RUN_RE.sub(' ', text)
    return text.strip()
def _split_oversized_sentence(sentence: str, max_chunk_length: int) -> List[str]:
    """Break a sentence that cannot fit in one chunk into word-aligned pieces.

    Sentences shorter than ``max_chunk_length`` are returned unchanged as a
    single-element list. Longer ones are re-packed word by word so each piece
    stays below the limit; a single word that is itself too long is kept
    whole as its own piece.
    """
    if len(sentence) < max_chunk_length:
        return [sentence]

    pieces: List[str] = []
    words_in_piece: List[str] = []
    piece_length = 0
    for word in sentence.split():
        # Account for the joining space when the piece already has words.
        added = len(word) + (1 if words_in_piece else 0)
        if words_in_piece and piece_length + added >= max_chunk_length:
            pieces.append(" ".join(words_in_piece))
            words_in_piece, piece_length = [word], len(word)
        else:
            words_in_piece.append(word)
            piece_length += added
    if words_in_piece:
        pieces.append(" ".join(words_in_piece))
    return pieces


def split_into_chunks(text: str, max_chunk_length: int = 1024) -> List[str]:
    """Split text into chunks for processing.

    The text is split into sentences with ``sent_tokenize`` and sentences are
    packed greedily into chunks whose character length stays below
    ``max_chunk_length``. A sentence that is too long on its own (for example
    punctuation-free text) is split on word boundaries instead of being
    emitted as one oversized chunk.

    Args:
        text: Text to split.
        max_chunk_length: Upper bound, in characters, for each chunk.

    Returns:
        The chunks in document order, each stripped of surrounding whitespace.
        Empty when ``sent_tokenize`` yields no sentences.
    """
    chunks: List[str] = []
    current_chunk = ""
    for sentence in sent_tokenize(text):
        for piece in _split_oversized_sentence(sentence, max_chunk_length):
            if len(current_chunk) + len(piece) < max_chunk_length:
                current_chunk += " " + piece
            else:
                # The piece does not fit: flush what has been collected and
                # start a new chunk with this piece.
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = piece
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
| # Text extraction functions | |
def extract_text_from_pdf(pdf_file) -> str:
    """Read every page of a PDF and return its cleaned text.

    Page texts are gathered in order, each followed by a newline. If reading
    fails part-way, the error is shown with ``st.error`` and whatever was
    gathered before the failure is still cleaned and returned.
    """
    page_texts = []
    try:
        for page in PyPDF2.PdfReader(pdf_file).pages:
            page_texts.append(page.extract_text() + "\n")
    except Exception as err:
        st.error(f"Error reading PDF: {str(err)}")
    return clean_text("".join(page_texts))
# Elements whose content is never visible page text, plus HTML comments.
_NON_CONTENT_HTML_RE = re.compile(
    r'<(script|style|noscript)\b[^>]*>.*?</\1\s*>|<!--.*?-->',
    re.IGNORECASE | re.DOTALL,
)
_ANY_HTML_TAG_RE = re.compile(r'<[^>]+>')


def _html_to_text(markup: str) -> str:
    """Reduce an HTML document to its visible text on a single line."""
    import html

    # Drop script/style/noscript bodies and comments first; plain tag
    # stripping would otherwise leave their JavaScript/CSS in the text.
    visible = _NON_CONTENT_HTML_RE.sub(' ', markup)
    visible = _ANY_HTML_TAG_RE.sub(' ', visible)
    # Decode entities only after tags are gone so an escaped "&lt;" is not
    # mistaken for markup.
    visible = html.unescape(visible)
    return re.sub(r'\s+', ' ', visible)


def extract_text_from_url(url: str) -> str:
    """Extract text from Wikipedia or other web pages.

    Args:
        url: Address of the page to download.

    Returns:
        The cleaned visible text of the page, or ``""`` when the request
        fails or the response status is not 200. Failures are reported with
        ``st.error`` rather than raised.
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            st.error(f"Failed to fetch URL: Status {response.status_code}")
            return ""
        return clean_text(_html_to_text(response.text))
    except Exception as e:
        st.error(f"Error fetching URL: {str(e)}")
        return ""
| # Summarization functions | |
def _bart_generate(
    text: str,
    model,
    tokenizer,
    device: str,
    *,
    max_length: int,
    min_length: int,
    do_sample: bool,
    num_beams: int,
) -> str:
    """Run one tokenize -> generate -> decode pass and return the summary."""
    inputs = tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    generation_kwargs = {
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": num_beams,
        "do_sample": do_sample,
    }
    if num_beams > 1:
        # These two options only apply to beam search.
        generation_kwargs["length_penalty"] = 2.0
        generation_kwargs["early_stopping"] = True
    # Unpacking passes the attention mask along with the input ids.
    summary_ids = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_with_bart(
    text: str,
    model,
    tokenizer,
    device: str,
    max_length: int = 150,
    min_length: int = 40,
    do_sample: bool = False,
    num_beams: int = 4
) -> str:
    """Summarize text using BART model.

    The text is split into chunks, each chunk is summarized, and the chunk
    summaries are joined with spaces. If the joined result is longer than
    200 words it is summarized once more.

    Args:
        text: Text to summarize.
        model: Model object exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized inputs are moved to.
        max_length: ``max_length`` forwarded to ``model.generate``.
        min_length: ``min_length`` forwarded to ``model.generate``; lowered
            to stay below ``max_length`` when the two conflict.
        do_sample: ``do_sample`` forwarded to every ``model.generate`` call.
        num_beams: ``num_beams`` forwarded to ``model.generate``.

    Returns:
        The summary. ``text`` itself is returned unchanged when it is empty
        or has fewer than 10 words. ``""`` is returned when summarization
        fails; the error is reported with ``st.error`` rather than raised.
    """
    if not text or len(text.split()) < 10:
        return text  # Return original if too short

    # The two length settings are chosen independently by callers; keep the
    # minimum strictly below the maximum so the constraints are satisfiable.
    min_length = max(0, min(min_length, max_length - 1))
    generation_options = {
        "max_length": max_length,
        "min_length": min_length,
        "do_sample": do_sample,
        "num_beams": num_beams,
    }

    try:
        # Split text into chunks if too long
        chunks = split_into_chunks(text, max_chunk_length=1000)
        summaries = [
            _bart_generate(chunk, model, tokenizer, device, **generation_options)
            for chunk in chunks
        ]

        # Combine chunk summaries
        combined_summary = " ".join(summaries)

        # If combined summary is still too long, summarize it again
        if len(combined_summary.split()) > 200:
            return _bart_generate(
                combined_summary, model, tokenizer, device, **generation_options
            )
        return combined_summary
    except Exception as e:
        st.error(f"Error during summarization: {str(e)}")
        return ""
| # Streamlit UI | |
# NOTE(review): the leading glyphs in several UI labels below look like
# mis-decoded emoji. They are kept exactly as found — confirm the intended
# characters before changing them.
_INPUT_DIRECT = "π Direct Text"
_INPUT_FILE = "π Upload File"
_INPUT_URL = "π Website URL"

# Session-state keys used to keep input alive across Streamlit reruns.
_KEY_INPUT_METHOD = "input_method"
_KEY_DIRECT_TEXT = "direct_text"
_KEY_FETCHED = "fetched_url_text"

_CUSTOM_CSS = """
<style>
.main-header {
    font-size: 2.5rem;
    color: #4A90E2;
    text-align: center;
    margin-bottom: 1rem;
}
.sub-header {
    font-size: 1.2rem;
    color: #666;
    text-align: center;
    margin-bottom: 2rem;
}
.stats-card {
    background-color: #f0f2f6;
    padding: 1rem;
    border-radius: 10px;
    margin: 0.5rem 0;
}
.summary-box {
    border: 2px solid #4A90E2;
    border-radius: 10px;
    padding: 1rem;
    background-color: #f8f9fa;
}
</style>
"""

_EXAMPLE_TEXT = """
Artificial Intelligence (AI) is transforming industries across the globe.
From healthcare to finance, AI algorithms are being deployed to solve complex problems,
automate processes, and generate insights from massive datasets. Machine learning,
a subset of AI, enables computers to learn from data without being explicitly programmed.
Deep learning, powered by neural networks, has achieved remarkable success in areas like
image recognition, natural language processing, and autonomous vehicles.
However, AI also raises important ethical considerations around bias, privacy,
and job displacement. As AI continues to evolve, it's crucial to develop responsible
AI frameworks that ensure these technologies benefit society while mitigating potential risks.
The future of AI holds tremendous promise, but requires careful stewardship and collaboration
between technologists, policymakers, and the public.
""".strip()


def _load_example_text() -> None:
    """Button callback: switch to direct input and pre-fill the example."""
    # Widget state may only be assigned before the widget is created in a
    # run; a callback executes ahead of the rerun, so this is the safe place.
    st.session_state[_KEY_INPUT_METHOD] = _INPUT_DIRECT
    st.session_state[_KEY_DIRECT_TEXT] = _EXAMPLE_TEXT


def _render_sidebar() -> dict:
    """Draw the configuration sidebar and return the selected settings."""
    with st.sidebar:
        st.header("βοΈ Configuration")

        # Model selection
        model_option = st.selectbox(
            "Choose BART model:",
            [
                "facebook/bart-large-cnn",
                "facebook/bart-large-xsum",
                "sshleifer/distilbart-cnn-12-6"
            ],
            help="BART-large-cnn is best for general summarization"
        )

        # Summary length
        st.subheader("Summary Settings")
        max_length = st.slider(
            "Maximum summary length (words)",
            min_value=50,
            max_value=500,
            value=150,
            step=10
        )
        min_length = st.slider(
            "Minimum summary length (words)",
            min_value=10,
            max_value=100,
            value=40,
            step=5
        )

        # Advanced options
        with st.expander("Advanced Options"):
            do_sample = st.checkbox(
                "Use sampling (more creative)",
                value=False,
                help="When enabled, uses sampling instead of beam search"
            )
            num_beams = st.slider(
                "Number of beams",
                min_value=1,
                max_value=8,
                value=4,
                help="Higher values produce better results but are slower"
            )

        st.markdown("---")

        # Model info
        st.info(
            "**Model Information:**\n"
            "- BART-large-CNN: Fine-tuned on CNN/Daily Mail\n"
            "- Parameters: 400 million\n"
            "- Best for: Article summarization"
        )

        # Load model button
        if st.button("π Load Model", type="secondary"):
            with st.spinner("Loading BART model..."):
                model, _tokenizer, device = load_bart_model(model_option)
                if model is not None:
                    st.success(f"Model loaded successfully on {device}!")

    return {
        "model_option": model_option,
        "max_length": max_length,
        "min_length": min_length,
        "do_sample": do_sample,
        "num_beams": num_beams,
    }


def _read_uploaded_file() -> str:
    """Show the file uploader and return the text of the uploaded file."""
    # Only formats this function can actually read are offered.
    uploaded_file = st.file_uploader(
        "Upload a file",
        type=['txt', 'pdf'],
        help="Supports TXT and PDF files"
    )
    if not uploaded_file:
        return ""

    file_ext = uploaded_file.name.split('.')[-1].lower()
    if file_ext == 'pdf':
        with st.spinner("Extracting text from PDF..."):
            return extract_text_from_pdf(uploaded_file)
    if file_ext == 'txt':
        # Replace undecodable bytes instead of crashing on non-UTF-8 files.
        return uploaded_file.getvalue().decode("utf-8", errors="replace")
    st.warning("Please upload a PDF or TXT file")
    return ""


def _read_url_text() -> str:
    """Show the URL field and return the text fetched for the current URL."""
    url = st.text_input(
        "Enter URL:",
        placeholder="https://en.wikipedia.org/wiki/...",
        help="Supports Wikipedia and other websites"
    )
    if url and st.button("Fetch Content", type="secondary"):
        with st.spinner("Fetching content from URL..."):
            # Remember the result: the button is only True on the run in
            # which it was clicked, and the next click elsewhere reruns the
            # script with it False.
            st.session_state[_KEY_FETCHED] = (url, extract_text_from_url(url))

    fetched_url, fetched_text = st.session_state.get(_KEY_FETCHED, ("", ""))
    return fetched_text if url and fetched_url == url else ""


def _collect_input_text() -> str:
    """Render the input-method selector and return the text to summarize."""
    input_method = st.radio(
        "Choose input method:",
        [_INPUT_DIRECT, _INPUT_FILE, _INPUT_URL],
        horizontal=True,
        key=_KEY_INPUT_METHOD
    )
    if input_method == _INPUT_DIRECT:
        return st.text_area(
            "Enter your text here:",
            height=300,
            placeholder="Paste or type your text here...",
            help="Minimum 100 words for best results",
            key=_KEY_DIRECT_TEXT
        )
    if input_method == _INPUT_FILE:
        return _read_uploaded_file()
    return _read_url_text()


def _render_text_stats(text_input: str) -> None:
    """Show word/sentence/character counts and a preview of the input."""
    if not text_input:
        return

    words = text_input.split()
    sentences = sent_tokenize(text_input)
    with st.expander("π Text Statistics", expanded=True):
        col_a, col_b, col_c = st.columns(3)
        with col_a:
            st.metric("Words", len(words))
        with col_b:
            st.metric("Sentences", len(sentences))
        with col_c:
            st.metric("Characters", len(text_input))

    # Preview
    with st.expander("π Preview Original Text"):
        preview_length = min(500, len(text_input))
        st.text(text_input[:preview_length] + "..." if len(text_input) > preview_length else text_input)


def _generate_and_show_summary(text_input: str, settings: dict) -> None:
    """Load the model, summarize ``text_input`` and render the results."""
    import html
    import inspect

    model, tokenizer, device = load_bart_model(settings["model_option"])
    if model is None or tokenizer is None:
        st.error("Failed to load model. Please check your internet connection.")
        return

    generation_kwargs = {
        "max_length": settings["max_length"],
        "min_length": settings["min_length"],
        "do_sample": settings["do_sample"],
    }
    # Forward the beam count only when summarize_with_bart exposes such a
    # parameter, so the slider takes effect without breaking the call.
    if "num_beams" in inspect.signature(summarize_with_bart).parameters:
        generation_kwargs["num_beams"] = settings["num_beams"]

    start_time = time.perf_counter()
    summary = summarize_with_bart(text_input, model, tokenizer, device, **generation_kwargs)
    processing_time = time.perf_counter() - start_time

    if not summary:
        st.error("Failed to generate summary. Please try again.")
        return

    # Render the box in one markdown call: separate calls produce separate
    # elements, so an opening and closing <div> would not wrap the text.
    # The summary is escaped because it derives from user-supplied input.
    st.markdown(
        f'<div class="summary-box">{html.escape(summary)}</div>',
        unsafe_allow_html=True
    )

    original_words = text_input.split()
    original_count = len(original_words)
    summary_count = len(summary.split())

    # Summary stats
    st.markdown("### π Summary Statistics")
    col1_stat, col2_stat, col3_stat, col4_stat = st.columns(4)
    with col1_stat:
        # A numeric delta carries its own sign.
        st.metric("Summary Words", summary_count, delta=summary_count - original_count)
    with col2_stat:
        reduction = (original_count - summary_count) / original_count * 100
        st.metric("Reduction", f"{reduction:.1f}%")
    with col3_stat:
        st.metric("Processing Time", f"{processing_time:.2f}s")
    with col4_stat:
        st.metric(
            "Compression Ratio",
            f"1:{original_count // summary_count if summary_count else 0}"
        )

    # Download button
    st.download_button(
        label="π₯ Download Summary",
        data=summary,
        file_name="bart_summary.txt",
        mime="text/plain",
        use_container_width=True
    )

    # Show sample comparison
    with st.expander("π Compare Original vs Summary"):
        col_orig, col_sum = st.columns(2)
        with col_orig:
            st.write("**Original (first 200 words):**")
            excerpt = " ".join(original_words[:200])
            # Mark the excerpt as truncated only when words were cut off.
            st.write(excerpt + "..." if original_count > 200 else excerpt)
        with col_sum:
            st.write("**Summary:**")
            st.write(summary)


def _render_summary_panel(text_input: str, settings: dict) -> None:
    """Render the right-hand column: generate button, hints and example."""
    st.subheader("π€ Generated Summary")

    if text_input and len(text_input.split()) >= 10:
        if st.button("π Generate Summary", type="primary", use_container_width=True):
            with st.spinner("Generating summary with BART..."):
                _generate_and_show_summary(text_input, settings)
    elif text_input:
        st.warning("Please enter at least 10 words for summarization")
    else:
        st.info("π Enter text on the left to generate a summary")

        # Example text
        with st.expander("π Try with Example Text"):
            st.button("Load Example Text", on_click=_load_example_text)


def _render_footer() -> None:
    """Render the attribution footer."""
    st.markdown("---")
    col_footer1, col_footer2, col_footer3 = st.columns(3)
    with col_footer1:
        st.markdown("**Powered by:**")
        # The link previously had empty text and rendered as nothing.
        st.markdown("[Hugging Face](https://huggingface.co)")
    with col_footer2:
        st.markdown("**Model:**")
        st.markdown("[BART-large-CNN](https://huggingface.co/facebook/bart-large-cnn)")
    with col_footer3:
        st.markdown("**Built with:**")
        st.markdown("[Streamlit](https://streamlit.io) | [PyTorch](https://pytorch.org)")
    st.caption("Β© 2024 BART Summarizer | Deploy your own on [Hugging Face Spaces](https://huggingface.co/spaces)")


def main():
    """Build the Streamlit page: header, sidebar, input column, summary column, footer."""
    st.set_page_config(
        page_title="BART Text Summarizer",
        page_icon="π€",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Custom CSS
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

    # Header
    st.markdown('<h1 class="main-header">π€ BART Text Summarizer</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Powered by Facebook\'s BART-large-CNN model from Hugging Face</p>', unsafe_allow_html=True)
    st.markdown("---")

    # Sidebar
    settings = _render_sidebar()

    # Main content
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("π₯ Input Text")
        text_input = _collect_input_text()
        # Display text stats
        _render_text_stats(text_input)
    with col2:
        _render_summary_panel(text_input, settings)

    # Footer
    _render_footer()
# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()