"""Streamlit app: summarize CNN/DailyMail samples and user-uploaded documents.

For uploads it uses a RAG-style flow: split the text into chunks, embed the
chunks, retrieve the chunks closest to a generic "summarize" query via FAISS,
then summarize the retrieved context with BART.
"""

import os
import time

import faiss
import numpy as np
import streamlit as st
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Ensure writable cache dirs exist (needed in containerized deployments).
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/.streamlit", exist_ok=True)

st.title("Document Summarization")


@st.cache_resource
def load_models():
    """Load and cache the embedding and summarization models once per process."""
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return embedder, summarizer


embedder, summarizer = load_models()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=50, length_function=len
)


def process_file(uploaded_file):
    """Extract plain text from an uploaded PDF, TXT, or Markdown file.

    Args:
        uploaded_file: a Streamlit UploadedFile (file-like, has ``.name``).

    Returns:
        The extracted text as a single string.

    Raises:
        ValueError: if the file extension is not .pdf, .txt, or .md.
    """
    if uploaded_file.name.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # extract_text() may return None for image-only pages; treat as "".
        text = "".join(page.extract_text() or "" for page in reader.pages)
    elif uploaded_file.name.endswith((".txt", ".md")):
        text = uploaded_file.read().decode("utf-8")
    else:
        # Fix: the original message string contained a raw newline, which is
        # a SyntaxError in a plain Python string literal.
        raise ValueError("Unsupported file format. Use PDF, TXT, or Markdown.")
    return text


st.header("View Summaries of 3 Dataset Documents")

dataset = load_dataset("cnn_dailymail", "3.0.0")
train_samples = dataset["train"].select(range(3))

for i, sample in enumerate(train_samples, 1):
    context = sample["article"]
    reference = sample["highlights"]
    with st.expander(f"Document {i}"):
        start_time = time.time()
        # truncation=True: BART's encoder accepts at most 1024 tokens and
        # full news articles routinely exceed that.
        summary = summarizer(context, max_length=150, truncation=True)[0]["summary_text"]
        latency = time.time() - start_time
        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Reference Summary")
        st.write(reference)
        # NOTE: len(str.split()) counts whitespace words, not model tokens.
        st.text(
            f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} "
            f"| Output tokens: {len(summary.split())}"
        )

st.header("Upload Your Own Document")

input_text = st.text_area("Or paste your text here")
uploaded_file = st.file_uploader("Upload a PDF, TXT, or Markdown file")

if st.button("Summarize"):
    try:
        if uploaded_file:
            text = process_file(uploaded_file)
        elif input_text.strip():
            text = input_text.strip()
        else:
            st.warning("Please provide text input or upload a file.")
            st.stop()

        chunks = text_splitter.split_text(text)
        # FAISS requires contiguous float32; coerce rather than rely on the
        # encoder's default output dtype.
        chunk_embeddings = np.asarray(embedder.encode(chunks), dtype="float32")
        index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
        index.add(chunk_embeddings)

        query_embedding = np.asarray(
            embedder.encode(["Summarize this document"]), dtype="float32"
        )
        # Fix: searching for k=3 when fewer than 3 chunks exist makes FAISS
        # pad with -1 indices, which chunks[-1] would silently resolve to a
        # duplicate of the last chunk. Clamp k to the number of chunks.
        top_k = min(3, len(chunks))
        _, indices = index.search(query_embedding, top_k)
        retrieved_chunks = [chunks[j] for j in indices[0]]
        context = " ".join(retrieved_chunks)

        start_time = time.time()
        summary = summarizer(context, max_length=150, truncation=True)[0]["summary_text"]
        latency = time.time() - start_time

        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Context")
        st.write(context)
        st.text(
            f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} "
            f"| Output tokens: {len(summary.split())}"
        )
    except Exception as e:
        st.error(f"Error: {str(e)}")