# Chat With PDF — Streamlit app (summarization + retrieval QA).
# (Removed Hugging Face Spaces page-scrape residue that preceded the code.)
# Standard library
import io
import time
from datetime import datetime
from typing import List

# Third-party
import PyPDF2
import streamlit as st
from docx import Document
from fpdf import FPDF
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# LangChain — import HuggingFaceEmbeddings from the community package for
# consistency with the other langchain_community imports; `langchain.embeddings`
# is a deprecated re-export of the same class.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LCDocument
from langchain_text_splitters import RecursiveCharacterTextSplitter
# === Model loading ===
# Cached with st.cache_resource so Streamlit reruns (every widget interaction)
# do not reload multi-GB model weights.

@st.cache_resource
def _load_summarizer():
    """Load the BART summarization tokenizer and model once per process."""
    tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    mdl = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    return tok, mdl

@st.cache_resource
def _load_qa_pipeline():
    """Load an extractive question-answering pipeline.

    BUGFIX: the original pointed this at facebook/bart-large-cnn — a
    summarization checkpoint with no trained QA head, so transformers would
    attach a randomly initialized head and the answers were meaningless.
    Use a checkpoint actually fine-tuned for extractive QA (SQuAD2).
    """
    return pipeline("question-answering", model="deepset/roberta-base-squad2")

@st.cache_resource
def _load_embeddings(model_name):
    """Load the sentence-embedding model backing the FAISS index."""
    return HuggingFaceEmbeddings(model_name=model_name)

tokenizer, model = _load_summarizer()
qa_pipeline = _load_qa_pipeline()
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = _load_embeddings(embedding_model_name)
| # === Data models === | |
class KeyPoint(BaseModel):
    """One key point extracted from the document (a sentence of the summary)."""
    point: str = Field(description="A key point extracted from the document.")
class Summary(BaseModel):
    """Wrapper for the generated summary text of the document."""
    summary: str = Field(description="A brief summary of the document content.")
class DocumentAnalysis(BaseModel):
    """Full analysis result: the summary plus its sentence-level key points."""
    key_points: List[KeyPoint]
    summary: Summary
def extract_text_from_pdf(pdf_file):
    """Concatenate the extractable text of every page of a PDF.

    Args:
        pdf_file: A binary file-like object containing the PDF.

    Returns:
        All page text joined into one string (empty for image-only pages).
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # BUGFIX: extract_text() returns None for pages with no extractable text
    # (e.g. scanned images); coalesce to "" so str.join() cannot raise.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
def analyze_text_structured(text):
    """Summarize *text* with BART and derive key points from the summary.

    Returns a DocumentAnalysis whose summary is the beam-search BART output
    and whose key points are the individual sentences of that summary.
    """
    encoded = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    generated = model.generate(
        encoded["input_ids"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Naive sentence split of the summary; blank fragments are dropped.
    points = []
    for sentence in summary_text.split(". "):
        sentence = sentence.strip()
        if sentence:
            points.append(KeyPoint(point=sentence))
    return DocumentAnalysis(summary=Summary(summary=summary_text), key_points=points)
def json_to_text(analysis):
    """Render a DocumentAnalysis as a plain-text report.

    The report is a summary section followed by a numbered key-point list;
    every line, including the last, ends with a newline.
    """
    parts = ["=== Summary ===", analysis.summary.summary, "", "=== Key Points ==="]
    parts.extend(
        f"{index}. {key_point.point}"
        for index, key_point in enumerate(analysis.key_points, start=1)
    )
    return "\n".join(parts) + "\n"
def create_pdf_report(analysis):
    """Build a PDF report of the analysis and return it as raw bytes."""
    pdf = FPDF()
    pdf.add_page()
    # NOTE(review): Helvetica is a latin-1 core font — non-ASCII characters in
    # the summary may raise during rendering; confirm which fpdf package is
    # installed (fpdf2 assumed, since output() is given a file-like object).
    pdf.set_font('Helvetica', '', 12)
    pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    pdf.multi_cell(0, 10, txt=json_to_text(analysis))
    pdf_bytes = io.BytesIO()
    # Writes the document into the buffer; `dest` is deprecated in fpdf2 but
    # harmless when a file object is supplied.
    pdf.output(pdf_bytes, dest='S')
    pdf_bytes.seek(0)
    return pdf_bytes.getvalue()
def create_word_report(analysis):
    """Build a Word (.docx) report of the analysis and return it as raw bytes."""
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    report.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    # Serialize into an in-memory buffer and hand back the bytes.
    buffer = io.BytesIO()
    report.save(buffer)
    buffer.seek(0)
    return buffer.getvalue()
# === Streamlit UI ===
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# Initialise every session-state slot exactly once per session:
# "messages" is the chat history (a list); every other slot starts as None.
for key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
    if key not in st.session_state:
        st.session_state[key] = None if key != "messages" else []

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
# NOTE(review): indentation below is reconstructed — the scraped source lost
# all leading whitespace; nesting inferred from Streamlit control flow.
if uploaded_file is not None:
    # A different file was uploaded: reset all derived state (summary,
    # reports, vector index, chat history) so stale results don't show.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []

    text = extract_text_from_pdf(uploaded_file)

    if st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            # Summarize the document and extract key points.
            analysis = analyze_text_structured(text)
            st.session_state.pdf_summary = analysis
            # Chunk the full text and build the FAISS index used by the chat.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = text_splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]
            st.session_state.vectorstore = FAISS.from_documents(docs, embedding_function)
            # Pre-render the downloadable reports from the same analysis.
            st.session_state.pdf_report = create_pdf_report(analysis)
            st.session_state.word_report = create_word_report(analysis)
        st.session_state.analysis_time = time.time() - start_time

        # NOTE(review): these results render only on the rerun triggered by
        # the button click; on later reruns the summary is still in
        # st.session_state.pdf_summary but is not redisplayed — confirm
        # whether that is intended.
        st.subheader("Analysis Results")
        st.text(json_to_text(analysis))
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )
# Chat section — only available once a vector index has been built.
# NOTE(review): nesting reconstructed from the scraped (indentation-free)
# source; this guard is self-contained, so top level is assumed.
if st.session_state.vectorstore is not None:
    st.subheader("Chat with the Document")
    # Replay the stored chat history on every rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("Ask a question about the document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                # Retrieve the 3 most similar chunks and run the QA pipeline
                # over their concatenation as context.
                docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                context = "\n".join([doc.page_content for doc in docs])
                answer = qa_pipeline({"question": prompt, "context": context})["answer"]
                st.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})
# Footer: show the elapsed analysis time once an analysis has completed.
if st.session_state.analysis_time is not None:
    st.markdown(
        f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Large v1.5</div>',
        unsafe_allow_html=True
    )