File size: 6,556 Bytes
2e80e49
 
 
5a30f3b
2e80e49
 
 
 
 
3d4a265
2e80e49
 
 
 
3d4a265
2e80e49
 
 
02e74b1
2e80e49
 
02e74b1
2e80e49
 
 
3201029
2e80e49
 
 
02e74b1
 
2e80e49
02e74b1
 
3d4a265
2e80e49
02e74b1
2e80e49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from pydantic import BaseModel, Field
from typing import List
from datetime import datetime
import PyPDF2
from fpdf import FPDF
from docx import Document
import io
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LCDocument
from langchain.embeddings import HuggingFaceEmbeddings
import time

# === Load models ===
# Streamlit re-executes this whole script on every user interaction. Each loader
# is therefore wrapped in st.cache_resource, so the weights are loaded once per
# server process and the same objects are reused across reruns and sessions.
# show_spinner=False keeps these calls from emitting any UI element ahead of
# st.set_page_config(), which older Streamlit releases require to come first.

SUMMARIZATION_MODEL_NAME = "facebook/bart-large-cnn"

# bart-large-cnn is a summarization checkpoint: it has no trained extractive-QA
# head. A "question-answering" pipeline built from it gets a freshly initialised
# span head and returns essentially random spans, so use a SQuAD 2.0 checkpoint.
QA_MODEL_NAME = "deepset/roberta-base-squad2"


@st.cache_resource(show_spinner=False)
def _load_summarizer(model_name):
    """Return the ``(tokenizer, model)`` pair for seq2seq checkpoint *model_name*."""
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
    )


@st.cache_resource(show_spinner=False)
def _load_qa_pipeline(model_name):
    """Return an extractive question-answering pipeline for *model_name*."""
    # Given a model name, the pipeline loads the tokenizer that matches it.
    return pipeline("question-answering", model=model_name)


@st.cache_resource(show_spinner=False)
def _load_embeddings(model_name):
    """Return the LangChain HuggingFace embedding wrapper for *model_name*."""
    return HuggingFaceEmbeddings(model_name=model_name)


# === Load summarization model ===
tokenizer, model = _load_summarizer(SUMMARIZATION_MODEL_NAME)

# === Load QA pipeline ===
qa_pipeline = _load_qa_pipeline(QA_MODEL_NAME)

# === Setup BGE Embedding model ===
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = _load_embeddings(embedding_model_name)

# === Data models ===
class KeyPoint(BaseModel):
    """A single key point derived from the document's summary."""

    point: str = Field(..., description="A key point extracted from the document.")

class Summary(BaseModel):
    """Wrapper holding the generated summary text of a document."""

    summary: str = Field(..., description="A brief summary of the document content.")

class DocumentAnalysis(BaseModel):
    """Structured analysis result: the key points plus the overall summary."""

    # Both fields are required; key_points may be an empty list.
    key_points: List[KeyPoint] = Field(...)
    summary: Summary = Field(...)

def extract_text_from_pdf(pdf_file):
    """Return the text of every page of *pdf_file*, pages separated by newlines.

    *pdf_file* is anything ``PyPDF2.PdfReader`` accepts; the UI passes a
    Streamlit ``UploadedFile``.

    Pages are joined with a newline so the last word of one page is not glued
    onto the first word of the next. A page whose extraction yields nothing
    (``""`` or ``None``) contributes an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

def _split_key_points(summary):
    """Split *summary* into stripped, non-empty sentences, keeping end punctuation."""
    import re

    # Split only after sentence-ending punctuation. Unlike str.split(". "), this
    # keeps the period on every sentence and also splits on "!" and "?".
    pieces = re.split(r"(?<=[.!?])\s+", summary)
    return [piece.strip() for piece in pieces if piece.strip()]


def analyze_text_structured(text):
    """Summarize *text* with BART and derive key points from the summary.

    Only the first 1024 tokens of *text* are summarized (the tokenizer truncates).
    Generation uses 4-beam search and targets 50-200 tokens. Each sentence of the
    resulting summary becomes one ``KeyPoint``.

    Returns:
        DocumentAnalysis: the summary plus its sentence-level key points. For
        empty or whitespace-only *text*, the summary is ``""`` and there are no
        key points; the model is not run, since min_length would force it to
        invent content.
    """
    if not text or not text.strip():
        return DocumentAnalysis(summary=Summary(summary=""), key_points=[])

    inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    key_points = [KeyPoint(point=sentence) for sentence in _split_key_points(summary)]
    return DocumentAnalysis(summary=Summary(summary=summary), key_points=key_points)

def json_to_text(analysis):
    """Render *analysis* as a plain-text report.

    Layout: a "=== Summary ===" header, the summary text, a blank line, a
    "=== Key Points ===" header, then one numbered line per key point
    (numbering starts at 1). The result always ends with a newline.
    """
    lines = ["=== Summary ===", f"{analysis.summary.summary}", "", "=== Key Points ==="]
    lines.extend(
        f"{number}. {item.point}"
        for number, item in enumerate(analysis.key_points, start=1)
    )
    return "\n".join(lines) + "\n"

def _to_latin1(text):
    """Return *text* restricted to Latin-1; unencodable characters become "?"."""
    return text.encode("latin-1", "replace").decode("latin-1")


def create_pdf_report(analysis):
    """Render *analysis* as a one-document PDF report and return its bytes.

    The report has a centred title, a centred generation timestamp, and the
    plain-text rendering produced by ``json_to_text``.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font('Helvetica', '', 12)
    # Width 0 extends each cell to the right margin, so align='C' centres on the
    # printable area instead of on an over-wide 200 mm cell.
    pdf.cell(0, 10, txt="PDF Analysis Report", ln=True, align='C')
    pdf.cell(0, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    # The built-in Helvetica font only covers Latin-1. Characters outside it
    # (curly quotes, dashes, non-Latin scripts) would make FPDF raise, so they
    # are replaced with "?".
    pdf.multi_cell(0, 10, txt=_to_latin1(json_to_text(analysis)))
    # output(dest='S') returns the document on both PyFPDF 1.7 (a latin-1 str)
    # and fpdf2 (a bytearray). PyFPDF ignores a file object passed together with
    # dest='S', which previously produced an empty report.
    data = pdf.output(dest='S')
    if isinstance(data, str):
        data = data.encode('latin-1')
    return bytes(data)

def create_word_report(analysis):
    """Build a .docx report for *analysis* and return the file's bytes.

    The document has a level-0 title, a "Generated on" timestamp line, an
    "Analysis" heading, and one paragraph holding ``json_to_text(analysis)``.
    """
    generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    report = Document()
    report.add_heading('PDF Analysis Report', level=0)
    report.add_paragraph(f'Generated on: {generated_at}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    with io.BytesIO() as sink:
        report.save(sink)
        return sink.getvalue()

# === Streamlit UI ===
# set_page_config must be the first Streamlit command that renders anything.
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# Seed per-session state on the first run only. "messages" is the chat history
# (a list); every other slot starts as None until a document is analysed.
_STATE_KEYS = ("current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages")
for state_key in _STATE_KEYS:
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = [] if state_key == "messages" else None

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # A different file invalidates everything derived from the previous one,
    # including the timing shown in the footer.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []

    if st.button("Analyze Text"):
        # Extract only when analysing. Chat messages rerun this script, and the
        # chat only needs the vector store already kept in session state.
        text = extract_text_from_pdf(uploaded_file)
        if not text.strip():
            # Nothing to summarise, and FAISS cannot be built from zero chunks.
            st.warning("No extractable text was found in this PDF. Scanned (image-only) PDFs are not supported.")
        else:
            start_time = time.time()
            with st.spinner("Analyzing..."):
                analysis = analyze_text_structured(text)

                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
                chunks = text_splitter.split_text(text)
                docs = [LCDocument(page_content=chunk) for chunk in chunks]
                vectorstore = FAISS.from_documents(docs, embedding_function)

                pdf_report = create_pdf_report(analysis)
                word_report = create_word_report(analysis)

            # Publish only after every artefact was built, so a failure part-way
            # through never leaves a summary stored without its matching reports.
            st.session_state.pdf_summary = analysis
            st.session_state.vectorstore = vectorstore
            st.session_state.pdf_report = pdf_report
            st.session_state.word_report = word_report
            st.session_state.analysis_time = time.time() - start_time

    # Render from session state rather than inside the button branch.
    # st.button is True only on the click rerun, so results and download buttons
    # would otherwise vanish after a download click or a chat message.
    if st.session_state.pdf_summary is not None:
        st.subheader("Analysis Results")
        st.text(json_to_text(st.session_state.pdf_summary))

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )

if st.session_state.vectorstore is not None:
    st.subheader("Chat with the Document")

    # Streamlit redraws the page on every rerun, so replay the stored history.
    for past in st.session_state.messages:
        st.chat_message(past["role"]).markdown(past["content"])

    question = st.chat_input("Ask a question about the document")
    if question:
        history = st.session_state.messages
        history.append({"role": "user", "content": question})
        st.chat_message("user").markdown(question)

        with st.chat_message("assistant"), st.spinner("Searching..."):
            # Retrieve the 3 most similar chunks and use them as the QA context.
            hits = st.session_state.vectorstore.similarity_search(question, k=3)
            context = "\n".join(hit.page_content for hit in hits)
            reply = qa_pipeline({"question": question, "context": context})["answer"]
            st.markdown(reply)
        history.append({"role": "assistant", "content": reply})

# Footer: show how long the last analysis took, once one has run.
elapsed = st.session_state.analysis_time
if elapsed is not None:
    footer_html = (
        '<div style="text-align:center; margin-top:2rem; color:gray;">'
        f"Analysis Time: {elapsed:.1f}s | Embedding: BGE Large v1.5"
        "</div>"
    )
    st.markdown(footer_html, unsafe_allow_html=True)