"""Streamlit app: summarize an uploaded PDF and chat with its contents.

Summaries are generated with facebook/bart-large-cnn, document chunks are
indexed in FAISS with BGE embeddings, and chat answers are extracted from the
retrieved chunks by an extractive question-answering model.
"""

import io
import re
import time
from datetime import datetime
from typing import List

import PyPDF2
import streamlit as st
from docx import Document
from fpdf import FPDF
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LCDocument
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# set_page_config must be the first Streamlit command, which includes the
# spinners st.cache_resource shows while the models below load.
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")

SUMMARIZATION_MODEL_NAME = "facebook/bart-large-cnn"
# bart-large-cnn is a summarization checkpoint with no trained span-extraction
# head: loading it into a "question-answering" pipeline leaves that head
# randomly initialised and the answers meaningless. Use a SQuAD-trained model.
QA_MODEL_NAME = "deepset/roberta-base-squad2"
embedding_model_name = "BAAI/bge-large-en-v1.5"


# Streamlit re-executes this whole script on every interaction; caching keeps
# each heavy model loaded once per server process instead of once per rerun.
@st.cache_resource
def _load_summarizer():
    """Load the BART tokenizer and seq2seq model used for summaries."""
    tok = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL_NAME)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL_NAME)
    return tok, mdl


@st.cache_resource
def _load_qa_pipeline():
    """Load the extractive question-answering pipeline (with its own tokenizer)."""
    return pipeline("question-answering", model=QA_MODEL_NAME)


@st.cache_resource
def _load_embeddings(model_name):
    """Load the embedding model used to index document chunks."""
    return HuggingFaceEmbeddings(model_name=model_name)


# === Load models (cached across reruns) ===
tokenizer, model = _load_summarizer()
qa_pipeline = _load_qa_pipeline()
embedding_function = _load_embeddings(embedding_model_name)


# === Data models ===
class KeyPoint(BaseModel):
    point: str = Field(description="A key point extracted from the document.")


class Summary(BaseModel):
    summary: str = Field(description="A brief summary of the document content.")


class DocumentAnalysis(BaseModel):
    key_points: List[KeyPoint]
    summary: Summary


def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page of *pdf_file*.

    Args:
        pdf_file: anything ``PyPDF2.PdfReader`` accepts (a path or a binary
            file-like object such as Streamlit's ``UploadedFile``).

    Returns:
        The page texts joined by newlines; pages that yield no text
        contribute an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # Newline separator keeps the last word of one page from fusing with the
    # first word of the next; "or ''" guards versions that return None.
    return "\n".join((page.extract_text() or "") for page in pdf_reader.pages)


# Split after sentence-ending punctuation, keeping the punctuation.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def analyze_text_structured(text):
    """Summarize *text* with BART and split the summary into key points.

    Only the first 1024 tokens of *text* are summarized
    (``max_length=1024, truncation=True``); the rest is dropped.

    Returns:
        DocumentAnalysis whose ``summary`` holds the generated summary and
        whose ``key_points`` are the summary's individual sentences.
    """
    inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # Splitting on ". " stripped the period from every point but the last;
    # the lookbehind split keeps each sentence's terminator.
    key_points = [
        KeyPoint(point=sentence.strip())
        for sentence in _SENTENCE_BOUNDARY.split(summary)
        if sentence.strip()
    ]
    return DocumentAnalysis(summary=Summary(summary=summary), key_points=key_points)


def json_to_text(analysis):
    """Render a DocumentAnalysis as plain text: summary, then numbered key points."""
    lines = ["=== Summary ===", analysis.summary.summary, "", "=== Key Points ==="]
    lines += [f"{i}. {kp.point}" for i, kp in enumerate(analysis.key_points, start=1)]
    return "\n".join(lines) + "\n"


def _to_latin1(text):
    """Replace characters that Latin-1 cannot encode with '?'."""
    return text.encode("latin-1", "replace").decode("latin-1")


def create_pdf_report(analysis):
    """Build a PDF report of *analysis* and return it as ``bytes``."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Helvetica", "", 12)
    pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align="C")
    generated = f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    pdf.cell(200, 10, txt=generated, ln=True, align="C")
    # Core fonts such as Helvetica only cover Latin-1; other characters
    # (curly quotes, dashes, ...) would raise an encoding error.
    pdf.multi_cell(0, 10, txt=_to_latin1(json_to_text(analysis)))
    # output(dest="S") returns the document in memory on both PyFPDF 1.7
    # (latin-1 str) and fpdf2 (bytearray). The former output(BytesIO, dest="S")
    # call left the buffer empty on PyFPDF.
    raw = pdf.output(dest="S")
    if isinstance(raw, str):
        raw = raw.encode("latin-1")
    return bytes(raw)


def create_word_report(analysis):
    """Build a .docx report of *analysis* and return it as ``bytes``."""
    doc = Document()
    doc.add_heading("PDF Analysis Report", 0)
    doc.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    doc.add_heading("Analysis", level=1)
    doc.add_paragraph(json_to_text(analysis))
    docx_bytes = io.BytesIO()
    doc.save(docx_bytes)
    return docx_bytes.getvalue()


# === Streamlit UI ===
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# State tied to the currently loaded file; all of it is reset when a
# different file is uploaded ("messages" to [], everything else to None).
_PER_FILE_KEYS = [
    "pdf_text",
    "pdf_summary",
    "analysis_time",
    "pdf_report",
    "word_report",
    "vectorstore",
    "messages",
]

for key in ["current_file", *_PER_FILE_KEYS]:
    if key not in st.session_state:
        st.session_state[key] = [] if key == "messages" else None

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # The name alone cannot tell apart two different uploads of "report.pdf".
    file_key = f"{uploaded_file.name}:{uploaded_file.size}"
    if st.session_state.current_file != file_key:
        # Extract once per upload rather than on every rerun. Extraction runs
        # before current_file is updated so a failure is retried next rerun.
        extracted_text = extract_text_from_pdf(uploaded_file)
        for key in _PER_FILE_KEYS:
            st.session_state[key] = [] if key == "messages" else None
        st.session_state.pdf_text = extracted_text
        st.session_state.current_file = file_key

    text = st.session_state.pdf_text

    if not text.strip():
        # Without any text there is nothing to summarize or index, and
        # FAISS.from_documents would fail on an empty document list.
        st.warning("No extractable text was found in this PDF.")
    elif st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            analysis = analyze_text_structured(text)
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = text_splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]
            vectorstore = FAISS.from_documents(docs, embedding_function)
            pdf_report = create_pdf_report(analysis)
            word_report = create_word_report(analysis)
        # Publish results only after every step succeeded, so a failure
        # cannot leave a half-populated session state behind.
        st.session_state.pdf_summary = analysis
        st.session_state.vectorstore = vectorstore
        st.session_state.pdf_report = pdf_report
        st.session_state.word_report = word_report
        st.session_state.analysis_time = time.time() - start_time

    # Rendered from session state (not inside the button branch) so results
    # and download buttons survive reruns such as a download click or a chat
    # message.
    if st.session_state.pdf_summary is not None:
        st.subheader("Analysis Results")
        st.text(json_to_text(st.session_state.pdf_summary))

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf",
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            )

    if st.session_state.vectorstore is not None:
        st.subheader("Chat with the Document")
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if prompt := st.chat_input("Ask a question about the document"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Searching..."):
                    hits = st.session_state.vectorstore.similarity_search(prompt, k=3)
                    context = "\n".join(hit.page_content for hit in hits)
                    answer = qa_pipeline(question=prompt, context=context)["answer"]
                st.markdown(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})

    if st.session_state.analysis_time is not None:
        # NOTE(review): unsafe_allow_html is kept from the original although
        # this text contains no HTML; a styling wrapper may have been lost.
        st.markdown(
            f"Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Large v1.5",
            unsafe_allow_html=True,
        )