# Hugging Face Spaces page residue (kept as a comment so the file parses):
# "mojad121's picture — Update src/streamlit_app.py — e48c184 verified"
import streamlit as st
import os
import time
from datetime import datetime
from datasets import load_dataset
from PyPDF2 import PdfReader
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import faiss
import numpy as np
# Create writable cache/config directories before Streamlit starts.
# NOTE(review): the absolute paths suggest a containerized host such as a
# Hugging Face Space, where only specific paths are writable — confirm.
# (The duplicate `import os` that was here is unnecessary: os is already
# imported at the top of the file.)
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/.streamlit", exist_ok=True)

st.title("Document Summarization")
@st.cache_resource
def load_models():
    """Build and cache the embedding and summarization models.

    Cached with ``st.cache_resource`` so the heavy model downloads happen
    once per process instead of on every Streamlit rerun.

    Returns:
        A ``(embedder, summarizer)`` pair: a SentenceTransformer encoder
        and a BART-large-CNN summarization pipeline.
    """
    sentence_encoder = SentenceTransformer("all-MiniLM-L6-v2")
    bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return sentence_encoder, bart_summarizer
# Loaded once per process (cached); reused across Streamlit reruns.
embedder, summarizer = load_models()
# Overlapping 500-character chunks preserve context across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, length_function=len)
def process_file(uploaded_file):
    """Extract plain text from an uploaded PDF, TXT, or Markdown file.

    Args:
        uploaded_file: A Streamlit ``UploadedFile``-like object exposing a
            ``.name`` attribute and a file-like read interface.

    Returns:
        The extracted text as a single string.

    Raises:
        ValueError: If the filename extension is not .pdf, .txt, or .md.
    """
    # Compare extensions case-insensitively so "REPORT.PDF" is accepted too
    # (the original check rejected upper-case extensions).
    name = uploaded_file.name.lower()
    if name.endswith('.pdf'):
        reader = PdfReader(uploaded_file)
        # extract_text() can return None for pages without a text layer.
        text = "".join(page.extract_text() or "" for page in reader.pages)
    elif name.endswith(('.txt', '.md')):
        text = uploaded_file.read().decode("utf-8")
    else:
        raise ValueError("Unsupported file format. Use PDF, TXT, or Markdown.")
    return text
st.header("View Summaries of 3 Dataset Documents")
# Demo: summarize the first three CNN/DailyMail articles and compare against
# the dataset's reference highlights.
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_samples = dataset["train"].select(range(3))
for i, sample in enumerate(train_samples, 1):
    context = sample["article"]
    reference = sample["highlights"]
    with st.expander(f"Document {i}"):
        start_time = time.time()
        # truncation=True: CNN/DailyMail articles frequently exceed BART's
        # 1024-token input limit; without it the pipeline raises on long inputs.
        summary = summarizer(context, max_length=150, truncation=True)[0]['summary_text']
        latency = time.time() - start_time
        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Reference Summary")
        st.write(reference)
        # "tokens" here are whitespace-split words, not model tokens.
        st.text(f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} | Output tokens: {len(summary.split())}")
st.header("Upload Your Own Document")
input_text = st.text_area("Or paste your text here")
uploaded_file = st.file_uploader("Upload a PDF, TXT, or Markdown file")
if st.button("Summarize"):
    # Validate BEFORE entering the try-block: st.stop() raises Streamlit's
    # StopException (an Exception subclass), so inside the try it would be
    # swallowed by `except Exception` and shown to the user as an error.
    if not uploaded_file and not input_text.strip():
        st.warning("Please provide text input or upload a file.")
        st.stop()
    try:
        text = process_file(uploaded_file) if uploaded_file else input_text.strip()
        chunks = text_splitter.split_text(text)
        if not chunks:
            raise ValueError("No text could be extracted from the input.")
        # faiss requires float32, C-contiguous arrays.
        chunk_embeddings = np.ascontiguousarray(embedder.encode(chunks), dtype="float32")
        index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
        index.add(chunk_embeddings)
        query_embedding = np.ascontiguousarray(
            embedder.encode(["Summarize this document"]), dtype="float32"
        )
        # Never request more neighbours than there are chunks: faiss pads
        # missing results with index -1, which would wrap to chunks[-1].
        k = min(3, len(chunks))
        _, indices = index.search(query_embedding, k)
        retrieved_chunks = [chunks[i] for i in indices[0]]
        context = " ".join(retrieved_chunks)
        start_time = time.time()
        # truncation=True keeps long contexts within BART's input limit.
        summary = summarizer(context, max_length=150, truncation=True)[0]['summary_text']
        latency = time.time() - start_time
        st.subheader("Generated Summary")
        st.write(summary)
        st.subheader("Context")
        st.write(context)
        # "tokens" here are whitespace-split words, not model tokens.
        st.text(f"Latency: {latency:.2f}s | Input tokens: {len(context.split())} | Output tokens: {len(summary.split())}")
    except Exception as e:
        # Top-level boundary for the app: surface any failure to the user.
        st.error(f"Error: {str(e)}")