# ChatWithPDF / app.py
# theerasin's picture
# Update app.py
# 2e80e49 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from pydantic import BaseModel, Field
from typing import List
from datetime import datetime
import PyPDF2
from fpdf import FPDF
from docx import Document
import io
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LCDocument
from langchain.embeddings import HuggingFaceEmbeddings
import time
# === Load models (cached) ===
# Streamlit re-executes this script from the top on every widget interaction.
# Wrapping the loaders in st.cache_resource keeps a single copy of each model
# per server process instead of re-instantiating them on every rerun.
# show_spinner=False keeps the loaders from rendering anything themselves,
# because st.set_page_config() is called further down in this script.


@st.cache_resource(show_spinner=False)
def _load_summarizer():
    """Return the (tokenizer, model) pair for facebook/bart-large-cnn."""
    checkpoint = "facebook/bart-large-cnn"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


@st.cache_resource(show_spinner=False)
def _load_qa_pipeline():
    """Return the question-answering pipeline, sharing the cached tokenizer."""
    shared_tokenizer, _ = _load_summarizer()
    # NOTE(review): facebook/bart-large-cnn is a summarization checkpoint.
    # Confirm it provides a trained question-answering head; if it does not,
    # a SQuAD-style checkpoint would be needed here for meaningful answers.
    return pipeline(
        "question-answering",
        model="facebook/bart-large-cnn",
        tokenizer=shared_tokenizer,
    )


@st.cache_resource(show_spinner=False)
def _load_embeddings(model_name):
    """Return the HuggingFaceEmbeddings wrapper for *model_name*."""
    return HuggingFaceEmbeddings(model_name=model_name)


# === Load summarization model ===
tokenizer, model = _load_summarizer()
# === Load QA pipeline ===
qa_pipeline = _load_qa_pipeline()
# === Setup BGE Embedding model ===
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = _load_embeddings(embedding_model_name)
# === Data models ===
# A single key point taken from the analysed document.
class KeyPoint(BaseModel):
    # Required field: the text of the key point.
    point: str = Field(
        ...,
        description="A key point extracted from the document.",
    )
# Wrapper model holding the document summary text.
class Summary(BaseModel):
    # Required field: the summary itself.
    summary: str = Field(
        ...,
        description="A brief summary of the document content.",
    )
# Structured analysis result: the key points plus the summary they came from.
class DocumentAnalysis(BaseModel):
    # Both fields are required; `...` states that explicitly.
    key_points: List[KeyPoint] = Field(...)
    summary: Summary = Field(...)
def extract_text_from_pdf(pdf_file):
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: Whatever ``PyPDF2.PdfReader`` accepts as its source; the
            app passes the uploaded file object.

    Returns:
        The text of all pages, joined with a newline between consecutive
        pages. A page that yields no text contributes an empty segment.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # A newline separator keeps the last word of one page from fusing with the
    # first word of the next. `or ""` treats a falsy page result (e.g. None)
    # as an empty page so the join cannot fail.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
def analyze_text_structured(text):
    """Summarise *text* and derive key points from the summary.

    The text is tokenized with truncation at 1024 tokens, so only the leading
    part of a long document reaches the model. Key points are the summary
    split on ". ", with blank fragments dropped.

    Args:
        text: The document text to analyse.

    Returns:
        A ``DocumentAnalysis``. For empty or whitespace-only input the summary
        is an empty string and there are no key points.
    """
    # With nothing to summarise, skip generation entirely: min_length=50 would
    # otherwise force the model to emit text that is not based on any input.
    if not text or not text.strip():
        return DocumentAnalysis(summary=Summary(summary=""), key_points=[])

    inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    summary_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    fragments = (part.strip() for part in summary.split(". "))
    key_points = [KeyPoint(point=fragment) for fragment in fragments if fragment]
    return DocumentAnalysis(summary=Summary(summary=summary), key_points=key_points)
def json_to_text(analysis):
    """Render an analysis as plain text.

    The output is a "Summary" section followed by a "Key Points" section in
    which each key point is numbered from 1 and placed on its own line.

    Args:
        analysis: Object exposing ``summary.summary`` and an iterable
            ``key_points`` whose items expose ``point``.

    Returns:
        The formatted report text, ending with a newline.
    """
    numbered_points = [
        f"{position}. {entry.point}\n"
        for position, entry in enumerate(analysis.key_points, start=1)
    ]
    sections = [
        "=== Summary ===\n",
        f"{analysis.summary.summary}\n\n",
        "=== Key Points ===\n",
        *numbered_points,
    ]
    return "".join(sections)
def _latin1_safe(text):
    """Replace characters that latin-1 cannot encode with '?'."""
    return text.encode("latin-1", "replace").decode("latin-1")


def create_pdf_report(analysis):
    """Build a one-page PDF report for *analysis*.

    The report has a centred title, a centred generation timestamp, and the
    plain-text rendering of the analysis produced by ``json_to_text``.

    Args:
        analysis: The analysis object to render.

    Returns:
        The PDF document as ``bytes``.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font('Helvetica', '', 12)
    pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    # The built-in Helvetica font only covers latin-1, so sanitise the
    # model-generated text instead of letting one odd character abort the report.
    pdf.multi_cell(0, 10, txt=_latin1_safe(json_to_text(analysis)))

    # Ask for the document as a return value rather than passing a buffer:
    # fpdf2 returns a bytearray here, while classic PyFPDF returns a latin-1
    # str and would leave a buffer argument untouched.
    rendered = pdf.output(dest='S')
    if isinstance(rendered, str):
        rendered = rendered.encode("latin-1")
    return bytes(rendered)
def create_word_report(analysis):
    """Build a Word (.docx) report for *analysis*.

    The document has a title heading, a generation timestamp, an "Analysis"
    heading, and the plain-text rendering produced by ``json_to_text``.

    Args:
        analysis: The analysis object to render.

    Returns:
        The .docx file content as ``bytes``.
    """
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    report.add_paragraph(f'Generated on: {generated_at}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))

    # Serialise into memory and hand back the raw bytes.
    with io.BytesIO() as buffer:
        report.save(buffer)
        return buffer.getvalue()
# === Streamlit UI ===
st.set_page_config(page_icon="📄", page_title="Chat With PDF (BART + BGE)")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# Seed the session state once per session; keys that already exist are left
# untouched so values survive reruns.
session_defaults = {
    "current_file": None,
    "pdf_summary": None,
    "analysis_time": None,
    "pdf_report": None,
    "word_report": None,
    "vectorstore": None,
    "messages": [],
}
for key, default in session_defaults.items():
    if key not in st.session_state:
        st.session_state[key] = default
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file is not None:
    # A different file name means every stored artefact belongs to the previous
    # document. analysis_time is reset as well so the footer never shows the
    # timing of an earlier file.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []

    if st.button("Analyze Text"):
        # Parse the PDF only when an analysis is requested, not on every rerun.
        text = extract_text_from_pdf(uploaded_file)
        if not text.strip():
            # Nothing to summarise or index (e.g. a scanned, image-only PDF).
            st.warning("No extractable text was found in this PDF.")
        else:
            start_time = time.time()
            with st.spinner("Analyzing..."):
                analysis = analyze_text_structured(text)
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
                chunks = text_splitter.split_text(text)
                docs = [LCDocument(page_content=chunk) for chunk in chunks]
                vectorstore = FAISS.from_documents(docs, embedding_function)
                pdf_report = create_pdf_report(analysis)
                word_report = create_word_report(analysis)
            # Commit everything together, only after every artefact was built,
            # so a failure above cannot leave half-updated session state.
            st.session_state.pdf_summary = analysis
            st.session_state.vectorstore = vectorstore
            st.session_state.pdf_report = pdf_report
            st.session_state.word_report = word_report
            st.session_state.analysis_time = time.time() - start_time

    # Render from session state rather than inside the button branch:
    # st.button() is only True for the rerun triggered by the click, so results
    # shown there would disappear on the next interaction (chat, download).
    if st.session_state.pdf_summary is not None:
        st.subheader("Analysis Results")
        st.text(json_to_text(st.session_state.pdf_summary))
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf",
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            )

    if st.session_state.vectorstore is not None:
        st.subheader("Chat with the Document")
        # Replay the conversation so far; it is kept in session state.
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if prompt := st.chat_input("Ask a question about the document"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Searching..."):
                    # Use the three most similar chunks as the QA context.
                    docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                    context = "\n".join(doc.page_content for doc in docs)
                    answer = qa_pipeline(question=prompt, context=context)["answer"]
                st.markdown(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})
# Footer: show how long the last analysis took, once one has been run.
elapsed_seconds = st.session_state.analysis_time
if elapsed_seconds is not None:
    footer_html = (
        f'<div style="text-align:center; margin-top:2rem; color:gray;">'
        f'Analysis Time: {elapsed_seconds:.1f}s | Embedding: BGE Large v1.5</div>'
    )
    st.markdown(footer_html, unsafe_allow_html=True)