# --- Hugging Face Spaces page residue (scrape artifact: Space status, blob
# --- hashes, file-viewer line ruler). Not part of the application code.
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from pydantic import BaseModel, Field
from typing import List
from datetime import datetime
import PyPDF2
from fpdf import FPDF
from docx import Document
import io
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LCDocument
from langchain.embeddings import HuggingFaceEmbeddings
import time
# === Load summarization model ===
# BART fine-tuned on CNN/DailyMail for abstractive summarization; loaded once
# at module import (Streamlit re-runs the script, so this is re-executed on
# each rerun unless cached upstream).
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
# === Load QA pipeline ===
# NOTE(review): facebook/bart-large-cnn is a summarization checkpoint, not an
# extractive-QA model; routing it through the "question-answering" pipeline
# will likely yield poor answers — consider a SQuAD-style QA model. TODO confirm.
qa_pipeline = pipeline("question-answering", model="facebook/bart-large-cnn", tokenizer=tokenizer)
# === Setup BGE Embedding model ===
# BGE large (English) sentence embeddings, used to build the FAISS index below.
# NOTE(review): langchain.embeddings.HuggingFaceEmbeddings is a deprecated
# import path in recent LangChain — verify against the pinned version.
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
# === Data models ===
# One key point extracted from the document (a single sentence of the summary).
class KeyPoint(BaseModel):
    point: str = Field(description="A key point extracted from the document.")
# Wrapper for the generated abstractive summary text.
class Summary(BaseModel):
    summary: str = Field(description="A brief summary of the document content.")
# Full analysis result: the summary plus the key points derived from it.
class DocumentAnalysis(BaseModel):
    key_points: List[KeyPoint]
    summary: Summary
def extract_text_from_pdf(pdf_file):
    """Extract and concatenate the text of every page in *pdf_file*.

    Args:
        pdf_file: A path or binary file-like object accepted by PyPDF2.PdfReader
            (the Streamlit UploadedFile works directly).

    Returns:
        str: All page text joined together; pages with no extractable text
        contribute nothing.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # BUG FIX: extract_text() returns None for image-only/empty pages, which
    # made "".join raise TypeError. Coerce None to "" so such pages are skipped.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
def analyze_text_structured(text):
    """Summarize *text* with BART and derive key points from the summary.

    The input is truncated to the model's 1024-token window, summarized with
    beam search, and the summary is split on ". " to produce one KeyPoint per
    sentence fragment.

    Returns:
        DocumentAnalysis: summary plus its sentence-level key points.
    """
    encoded = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    generated_ids = model.generate(
        encoded["input_ids"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # One key point per non-empty ". "-delimited fragment of the summary.
    points = []
    for fragment in summary_text.split(". "):
        fragment = fragment.strip()
        if fragment:
            points.append(KeyPoint(point=fragment))
    return DocumentAnalysis(summary=Summary(summary=summary_text), key_points=points)
def json_to_text(analysis):
    """Render a DocumentAnalysis as a plain-text report.

    Produces a "=== Summary ===" section followed by a numbered
    "=== Key Points ===" list, one line per key point.
    """
    parts = [f"=== Summary ===\n{analysis.summary.summary}\n\n=== Key Points ===\n"]
    parts.extend(
        f"{index}. {key_point.point}\n"
        for index, key_point in enumerate(analysis.key_points, start=1)
    )
    return "".join(parts)
def create_pdf_report(analysis):
    """Render the analysis as a one-page PDF report.

    Args:
        analysis: DocumentAnalysis to render (via json_to_text).

    Returns:
        bytes: The complete PDF document.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font('Helvetica', '', 12)
    pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    # NOTE(review): core fonts are latin-1 only; non-latin-1 document text will
    # raise here — confirm whether a Unicode font should be registered.
    pdf.multi_cell(0, 10, txt=json_to_text(analysis))
    # BUG FIX: the original called pdf.output(pdf_bytes, dest='S'). With classic
    # PyFPDF, dest='S' RETURNS the document and ignores the stream argument, so
    # the BytesIO stayed empty and the download button served 0 bytes. Use the
    # return value directly; PyFPDF returns str (latin-1), fpdf2 a bytearray.
    raw = pdf.output(dest='S')
    if isinstance(raw, str):
        raw = raw.encode('latin-1')
    return bytes(raw)
def create_word_report(analysis):
    """Build a Word (.docx) report for *analysis* and return it as raw bytes."""
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    report.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    # Serialize to an in-memory buffer; getvalue() returns the whole document
    # regardless of stream position.
    buffer = io.BytesIO()
    report.save(buffer)
    return buffer.getvalue()
# === Streamlit UI ===
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")
# Seed every session-state slot exactly once per session: the chat history is
# a list, everything else starts empty.
_state_defaults = {
    "current_file": None,
    "pdf_summary": None,
    "analysis_time": None,
    "pdf_report": None,
    "word_report": None,
    "vectorstore": None,
    "messages": [],
}
for _slot, _initial in _state_defaults.items():
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file is not None:
    # A different filename means a new document: drop all derived state so
    # stale summaries/vectors/reports from the previous file are not shown.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []
    text = extract_text_from_pdf(uploaded_file)
    if st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            # Summarize, then build the retrieval index over 1000-char chunks
            # (200-char overlap) and pre-render both downloadable reports.
            analysis = analyze_text_structured(text)
            st.session_state.pdf_summary = analysis
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = text_splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]
            st.session_state.vectorstore = FAISS.from_documents(docs, embedding_function)
            st.session_state.pdf_report = create_pdf_report(analysis)
            st.session_state.word_report = create_word_report(analysis)
        st.session_state.analysis_time = time.time() - start_time
        st.subheader("Analysis Results")
        st.text(json_to_text(analysis))
        # Side-by-side download buttons for the two report formats.
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )
# Chat is available only after "Analyze Text" has built the vector store.
if st.session_state.vectorstore is not None:
    st.subheader("Chat with the Document")
    # Replay the stored conversation on every Streamlit rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("Ask a question about the document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                # Retrieve the 3 most similar chunks and answer from that context.
                docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                context = "\n".join([doc.page_content for doc in docs])
                answer = qa_pipeline({"question": prompt, "context": context})["answer"]
                st.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})
# Footer: shown only once an analysis has been timed.
if st.session_state.analysis_time is not None:
    elapsed = st.session_state.analysis_time
    footer_html = (
        '<div style="text-align:center; margin-top:2rem; color:gray;">'
        f'Analysis Time: {elapsed:.1f}s | Embedding: BGE Large v1.5</div>'
    )
    st.markdown(footer_html, unsafe_allow_html=True)