|
|
import streamlit as st |
|
|
import os |
|
|
import time |
|
|
from datetime import datetime |
|
|
from datasets import load_dataset |
|
|
from PyPDF2 import PdfReader |
|
|
from transformers import pipeline |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
import faiss |
|
|
import numpy as np |
|
|
import os |
|
|
# Ensure writable cache/config directories exist (required when running in a
# container where the home directory may be read-only).
for _cache_dir in ("/app/cache", "/.streamlit"):
    os.makedirs(_cache_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
# Page title rendered at the top of the Streamlit app.
st.title("Document Summarization")
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the ML models once and cache them across Streamlit reruns.

    Returns:
        tuple: ``(embedder, summarizer)`` where ``embedder`` is a
        SentenceTransformer for chunk retrieval and ``summarizer`` is a
        transformers summarization pipeline.
    """
    # BART fine-tuned on CNN/DailyMail for abstractive summarization.
    summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
    # Lightweight MiniLM sentence embedder used for similarity search.
    sentence_embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return sentence_embedder, summarization_pipeline
|
|
|
|
|
# Shared model instances; @st.cache_resource makes these effectively
# singletons across reruns and sessions.
embedder, summarizer = load_models()

# Character-based splitter: ~500-char chunks with a 50-char overlap so that
# sentences cut at a boundary still appear whole in one of the chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, length_function=len)
|
|
|
|
|
def process_file(uploaded_file):
    """Extract plain text from an uploaded document.

    Args:
        uploaded_file: File-like object exposing a ``name`` attribute and a
            ``read()`` method (e.g. a Streamlit ``UploadedFile``).

    Returns:
        str: The extracted text; may be empty for image-only PDFs.

    Raises:
        ValueError: If the file extension is not .pdf, .txt, or .md.
    """
    # Compare the extension case-insensitively so "REPORT.PDF" is accepted.
    suffix = os.path.splitext(uploaded_file.name)[1].lower()
    if suffix == ".pdf":
        reader = PdfReader(uploaded_file)
        # extract_text() may return None for image-only pages; treat as "".
        text = "".join(page.extract_text() or "" for page in reader.pages)
    elif suffix in (".txt", ".md"):
        text = uploaded_file.read().decode("utf-8")
    else:
        raise ValueError("Unsupported file format. Use PDF, TXT, or Markdown.")
    return text
|
|
|
|
|
st.header("View Summaries of 3 Dataset Documents")

# CNN/DailyMail v3.0.0: news articles paired with human-written highlight
# summaries. Only the first three training examples are used for the demo.
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_samples = dataset["train"].select(range(3))
|
|
|
|
|
# Summarize each demo article and show it next to the human-written
# reference summary, with basic latency/size metrics.
for i, sample in enumerate(train_samples, 1):
    context = sample["article"]
    reference = sample["highlights"]
    with st.expander(f"Document {i}"):
        start_time = time.time()
        # truncation=True: CNN/DailyMail articles frequently exceed BART's
        # 1024-token input limit; without it the pipeline raises on long
        # inputs instead of summarizing the truncated prefix.
        summary = summarizer(context, max_length=150, truncation=True)[0]['summary_text']
        latency = time.time() - start_time

        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Reference Summary")
        st.write(reference)
        # "Tokens" here are whitespace-split words — a rough proxy, not the
        # model tokenizer's count.
        st.text(f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} | Output tokens: {len(summary.split())}")
|
|
|
|
|
st.header("Upload Your Own Document")

# Two input paths: pasted text or an uploaded file. Both are read by the
# Summarize handler below (the file takes precedence if both are given).
input_text = st.text_area("Or paste your text here")

uploaded_file = st.file_uploader("Upload a PDF, TXT, or Markdown file")
|
|
|
|
|
if st.button("Summarize"):
    # Validate input OUTSIDE the try block: st.stop() works by raising
    # Streamlit's StopException (an Exception subclass), which the broad
    # `except Exception` below would otherwise swallow and misreport as an
    # error instead of halting the script.
    if not uploaded_file and not input_text.strip():
        st.warning("Please provide text input or upload a file.")
        st.stop()
    try:
        # The uploaded file takes precedence over pasted text.
        text = process_file(uploaded_file) if uploaded_file else input_text.strip()

        # Embed fixed-size chunks and index them for similarity search.
        # faiss expects a contiguous float32 matrix, so cast explicitly.
        chunks = text_splitter.split_text(text)
        chunk_embeddings = np.asarray(embedder.encode(chunks), dtype="float32")
        index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
        index.add(chunk_embeddings)

        # Retrieve the chunks closest to a generic summarization query.
        query_embedding = np.asarray(embedder.encode(["Summarize this document"]), dtype="float32")
        # Never request more neighbors than there are chunks: faiss pads
        # missing results with -1, which would silently index chunks[-1]
        # and duplicate it in the context.
        top_k = min(3, len(chunks))
        _, indices = index.search(query_embedding, top_k)
        retrieved_chunks = [chunks[i] for i in indices[0]]
        context = " ".join(retrieved_chunks)

        start_time = time.time()
        # truncation=True guards against contexts longer than BART's
        # 1024-token input limit.
        summary = summarizer(context, max_length=150, truncation=True)[0]['summary_text']
        latency = time.time() - start_time

        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Context")
        st.write(context)
        # Word counts are a whitespace-split approximation of token counts.
        st.text(f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} | Output tokens: {len(summary.split())}")

    except Exception as e:
        # Top-level UI boundary: surface any failure to the user rather
        # than crashing the app.
        st.error(f"Error: {str(e)}")
|
|
|