theerasin committed on
Commit
02e74b1
·
verified ·
1 Parent(s): 3a7caf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -126
app.py CHANGED
@@ -1,163 +1,159 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- from sentence_transformers import SentenceTransformer
4
- import os
5
- import time
6
  from datetime import datetime
7
  import PyPDF2
8
  from fpdf import FPDF
9
  from docx import Document
10
  import io
11
- from langchain.vectorstores import FAISS
12
- from langchain.docstore.document import Document as LCDocument
13
- from langchain.text_splitter import RecursiveCharacterTextSplitter
14
-
15
- # === Load Models ===
16
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
17
- embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
18
-
19
- # === Streamlit Setup ===
20
- st.set_page_config(page_title="Chat With PDF", page_icon="📄")
21
-
22
- # === Session States ===
23
- if "current_file" not in st.session_state:
24
- st.session_state.current_file = None
25
- if "summary_text" not in st.session_state:
26
- st.session_state.summary_text = ""
27
- if "analysis_time" not in st.session_state:
28
- st.session_state.analysis_time = 0
29
- if "pdf_report" not in st.session_state:
30
- st.session_state.pdf_report = None
31
- if "word_report" not in st.session_state:
32
- st.session_state.word_report = None
33
- if "vectorstore" not in st.session_state:
34
- st.session_state.vectorstore = None
35
- if "messages" not in st.session_state:
36
- st.session_state.messages = []
37
-
38
- # === Utility Functions ===
39
- def extract_text_from_pdf(file):
40
- reader = PyPDF2.PdfReader(file)
41
- return "".join(page.extract_text() or "" for page in reader.pages)
42
-
43
- def summarize_text(text):
44
- chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
45
- summary = ""
46
- for chunk in chunks:
47
- summary_chunk = summarizer(chunk, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
48
- summary += summary_chunk + "\n"
49
- return summary.strip()
50
-
51
- def generate_keypoints(text, top_k=5):
52
- sentences = text.split(". ")
53
- sentences = list(filter(lambda x: len(x) > 20, sentences))
54
- embeddings = embedding_model.encode(sentences)
55
- doc_embedding = embedding_model.encode([text])[0]
56
-
57
- similarities = [(sentences[i], float(embedding_model.similarity([doc_embedding], [embeddings[i]])[0][0]))
58
- for i in range(len(sentences))]
59
-
60
- sorted_similarities = sorted(similarities, key=lambda x: x[1], reverse=True)
61
- keypoints = [f"{i+1}. {s[0]}" for i, s in enumerate(sorted_similarities[:top_k])]
62
- return "\n".join(keypoints)
63
-
64
- def create_pdf_report(summary, keypoints):
65
  pdf = FPDF()
66
  pdf.add_page()
67
- # เพิ่มฟอนต์ที่รองรับ Unicode
68
- font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
69
- pdf.add_font("DejaVu", "", font_path, uni=True)
70
- pdf.set_font("DejaVu", size=12)
71
  pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
72
  pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
73
- pdf.ln(10)
74
- pdf.multi_cell(0, 10, txt=f"Summary:\n{summary}\n\nKey Points:\n{keypoints}")
75
- buffer = io.BytesIO()
76
- pdf.output(buffer)
77
- buffer.seek(0)
78
- return buffer.read()
79
-
80
- def create_word_report(summary, keypoints):
81
  doc = Document()
82
  doc.add_heading('PDF Analysis Report', 0)
83
  doc.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
84
- doc.add_heading('Summary', level=1)
85
- doc.add_paragraph(summary)
86
- doc.add_heading('Key Points', level=1)
87
- for point in keypoints.split("\n"):
88
- doc.add_paragraph(point, style='List Bullet')
89
- buffer = io.BytesIO()
90
- doc.save(buffer)
91
- buffer.seek(0)
92
- return buffer.read()
93
-
94
- # === UI ===
95
  st.title("📄 Chat With PDF")
96
- uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
 
 
 
 
97
 
98
- if uploaded_file:
 
 
99
  if st.session_state.current_file != uploaded_file.name:
100
  st.session_state.current_file = uploaded_file.name
101
- st.session_state.summary_text = ""
102
- st.session_state.vectorstore = None
103
- st.session_state.messages = []
104
 
105
  text = extract_text_from_pdf(uploaded_file)
106
 
107
- if st.button("Analyze Document"):
108
  start_time = time.time()
109
- with st.spinner("Summarizing..."):
110
- summary = summarize_text(text)
111
- keypoints = generate_keypoints(text)
112
- st.session_state.summary_text = summary
113
- st.session_state.pdf_report = create_pdf_report(summary, keypoints)
114
- st.session_state.word_report = create_word_report(summary, keypoints)
115
-
116
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
117
- docs = [LCDocument(page_content=chunk) for chunk in splitter.split_text(text)]
118
- vectors = embedding_model.encode([doc.page_content for doc in docs])
119
- text_embeddings = [(doc.page_content, vector) for doc, vector in zip(docs, vectors)]
120
- st.session_state.vectorstore = FAISS.from_embeddings(text_embeddings, embedding_model)
121
 
122
  st.session_state.analysis_time = time.time() - start_time
123
- st.subheader("Summary")
124
- st.write(summary)
125
- st.subheader("Key Points")
126
- st.write(keypoints)
127
 
128
  col1, col2 = st.columns(2)
129
  with col1:
130
- st.download_button("Download PDF Report", data=st.session_state.pdf_report,
131
- file_name="analysis_report.pdf", mime="application/pdf")
 
 
 
 
132
  with col2:
133
- st.download_button("Download Word Report", data=st.session_state.word_report,
134
- file_name="analysis_report.docx",
135
- mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
 
 
 
136
 
137
- # === Chat Section ===
138
- if st.session_state.vectorstore:
139
- st.subheader("Chat with Document")
140
 
141
- for msg in st.session_state.messages:
142
- with st.chat_message(msg["role"]):
143
- st.markdown(msg["content"])
144
 
145
  if prompt := st.chat_input("Ask a question about the document"):
146
  st.session_state.messages.append({"role": "user", "content": prompt})
147
  with st.chat_message("user"):
148
  st.markdown(prompt)
149
-
150
  with st.chat_message("assistant"):
151
- with st.spinner("Thinking..."):
152
  docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
153
  context = "\n".join([doc.page_content for doc in docs])
154
- question = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer based only on the context above."
155
- result = summarizer(question, max_length=150, do_sample=False)[0]["summary_text"]
156
- st.markdown(result)
157
- st.session_state.messages.append({"role": "assistant", "content": result})
158
-
159
- # === Footer ===
160
- st.markdown(f"""
161
- <hr>
162
- <div style="text-align: center;">⏱️ Analysis Time: {st.session_state.analysis_time:.1f}s | Powered by BART + BGE-Large</div>
163
- """, unsafe_allow_html=True)
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ from pydantic import BaseModel, Field
4
+ from typing import List
 
5
  from datetime import datetime
6
  import PyPDF2
7
  from fpdf import FPDF
8
  from docx import Document
9
  import io
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_core.documents import Document as LCDocument
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ import time
15
+
16
# === Load summarization model ===
# BART fine-tuned for summarization; used by analyze_text_structured() below.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# === Load QA pipeline ===
# NOTE(review): facebook/bart-large-cnn is a seq2seq summarization checkpoint,
# not an extractive question-answering model. Loading it under the
# "question-answering" task may fail or produce poor answers — confirm, and
# consider a SQuAD-style checkpoint instead.
qa_pipeline = pipeline("question-answering", model="facebook/bart-large-cnn", tokenizer=tokenizer)

# === Setup BGE Embedding model ===
# Embeddings back the FAISS vector store used by the chat section of the UI.
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
26
+
27
+ # === Data models ===
28
class KeyPoint(BaseModel):
    # One bullet-style takeaway sentence; Field(description=...) also feeds
    # the generated JSON schema, so the text is part of the model's contract.
    point: str = Field(description="A key point extracted from the document.")
30
+
31
class Summary(BaseModel):
    # Single free-text summary of the whole document.
    summary: str = Field(description="A brief summary of the document content.")
33
+
34
class DocumentAnalysis(BaseModel):
    # Combined result returned by analyze_text_structured(): the key points
    # derived from the summary, plus the summary itself.
    key_points: List[KeyPoint]
    summary: Summary
37
+
38
def extract_text_from_pdf(pdf_file):
    """Concatenate the extractable text of every page in *pdf_file*.

    Args:
        pdf_file: A file-like object readable by PyPDF2 (e.g. a Streamlit
            UploadedFile).

    Returns:
        str: All page text joined together; pages with no extractable text
        contribute the empty string.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # extract_text() may return None for image-only / scanned pages; coerce
    # to "" so str.join() cannot raise TypeError.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
41
+
42
def analyze_text_structured(text):
    """Summarize *text* with BART and derive key points from the summary.

    The input is truncated to the model's 1024-token limit, summarized with
    beam search, and each sentence of the resulting summary becomes one
    key point.

    Args:
        text: Full plain-text content of the uploaded document.

    Returns:
        DocumentAnalysis: the summary plus one KeyPoint per summary sentence.
    """
    encoded = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    generated_ids = model.generate(
        encoded["input_ids"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Treat each sentence of the summary as one key point, skipping blanks.
    points = []
    for sentence in summary_text.split(". "):
        cleaned = sentence.strip()
        if cleaned:
            points.append(KeyPoint(point=cleaned))

    return DocumentAnalysis(summary=Summary(summary=summary_text), key_points=points)
51
+
52
def json_to_text(analysis):
    """Render a DocumentAnalysis as a plain-text report.

    Args:
        analysis: Object exposing ``summary.summary`` (str) and
            ``key_points`` (iterable of objects with a ``point`` attribute).

    Returns:
        str: A "=== Summary ===" section followed by a numbered
        "=== Key Points ===" list, ending with a trailing newline.
    """
    lines = ["=== Summary ===", analysis.summary.summary, "", "=== Key Points ==="]
    lines.extend(f"{idx}. {kp.point}" for idx, kp in enumerate(analysis.key_points, start=1))
    return "\n".join(lines) + "\n"
58
+
59
def create_pdf_report(analysis):
    """Render the analysis as a PDF and return its serialized bytes.

    Args:
        analysis: DocumentAnalysis produced by analyze_text_structured().

    Returns:
        bytes: The PDF document content.
    """
    report = FPDF()
    report.add_page()
    # NOTE(review): Helvetica is a latin-1 core font; non-latin text in the
    # summary (e.g. Thai) may fail to encode — confirm against expected input.
    report.set_font('Helvetica', '', 12)
    report.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    generated_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    report.cell(200, 10, txt=f"Generated on: {generated_at}", ln=True, align='C')
    report.multi_cell(0, 10, txt=json_to_text(analysis))
    buffer = io.BytesIO()
    report.output(buffer, dest='S')
    buffer.seek(0)
    return buffer.getvalue()
70
+
71
def create_word_report(analysis):
    """Render the analysis as a Word (.docx) document and return its bytes.

    Args:
        analysis: DocumentAnalysis produced by analyze_text_structured().

    Returns:
        bytes: The .docx document content.
    """
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    report.add_paragraph(f'Generated on: {generated_at}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    buffer = io.BytesIO()
    report.save(buffer)
    buffer.seek(0)
    return buffer.getvalue()
81
+
82
# === Streamlit UI ===
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")

st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# Initialise every session-state slot once per browser session; "messages"
# holds the chat history (list), every other slot starts as None.
for key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
    if key not in st.session_state:
        st.session_state[key] = None if key != "messages" else []

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # A new file name means a new document: drop the cached analysis, reports,
    # vector index and chat history. (NOTE(review): analysis_time is not
    # reset here, so the footer may show the previous file's timing.)
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []

    # Re-extracted on every Streamlit rerun while a file is uploaded.
    text = extract_text_from_pdf(uploaded_file)

    if st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            analysis = analyze_text_structured(text)
            st.session_state.pdf_summary = analysis

            # Chunk the document and build a FAISS index for chat retrieval.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = text_splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]

            st.session_state.vectorstore = FAISS.from_documents(docs, embedding_function)

            # Pre-render both report formats so the download buttons are instant.
            st.session_state.pdf_report = create_pdf_report(analysis)
            st.session_state.word_report = create_word_report(analysis)

        st.session_state.analysis_time = time.time() - start_time
        st.subheader("Analysis Results")
        st.text(json_to_text(analysis))

        # NOTE(review): everything below only renders on the rerun triggered by
        # the "Analyze Text" click; the results and download buttons disappear
        # on later reruns even though the reports stay cached in session_state.
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )

# Chat section — available once a vector index exists for the current file.
if st.session_state.vectorstore is not None:
    st.subheader("Chat with the Document")

    # Replay the stored conversation on every rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question about the document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                # Retrieve the 3 most similar chunks and answer from them only.
                docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                context = "\n".join([doc.page_content for doc in docs])
                answer = qa_pipeline({"question": prompt, "context": context})["answer"]
                st.markdown(answer)
                st.session_state.messages.append({"role": "assistant", "content": answer})

# Footer — only shown after at least one analysis has completed.
if st.session_state.analysis_time is not None:
    st.markdown(
        f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Large v1.5</div>',
        unsafe_allow_html=True
    )