theerasin committed on
Commit
3d4a265
·
verified ·
1 Parent(s): 5aefeb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -145
app.py CHANGED
@@ -1,166 +1,110 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
3
- from pydantic import BaseModel, Field
4
- from typing import List
5
- from datetime import datetime
6
- import PyPDF2
7
- from fpdf import FPDF
8
- from docx import Document
9
- import io
10
- from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import FAISS
 
 
12
  from langchain_core.documents import Document as LCDocument
13
- import time
 
 
 
 
 
14
 
15
# === Summarization model ===
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# === QA model ===
qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)

# === Embedding model ===
from sentence_transformers import SentenceTransformer
from langchain.embeddings import HuggingFaceEmbeddings
embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5")
# BUG FIX: HuggingFaceEmbeddings expects the *name* of a model via its
# `model_name` parameter, not an instantiated SentenceTransformer object;
# passing `model=embedding_model` fails pydantic validation at runtime.
embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
28
 
29
# === Data models ===
class KeyPoint(BaseModel):
    """A single noteworthy point pulled out of the document text."""

    point: str = Field(description="A key point extracted from the document.")


class Summary(BaseModel):
    """Wrapper around the generated summary string."""

    summary: str = Field(description="A brief summary of the document content.")


class DocumentAnalysis(BaseModel):
    """Combined result of an analysis run: key points plus overall summary."""

    key_points: List[KeyPoint]
    summary: Summary
 
39
 
40
def extract_text_from_pdf(pdf_file):
    """Extract and concatenate the text of every page of a PDF.

    Args:
        pdf_file: A file-like object (or path) accepted by ``PyPDF2.PdfReader``.

    Returns:
        str: The concatenated text of all pages. Pages for which PyPDF2
        cannot extract text (e.g. scanned images) contribute ``""``.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # BUG FIX: extract_text() may return None for image-only pages, which
    # would make "".join(...) raise TypeError; coerce None to "".
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
def analyze_text_structured(text):
    """Summarize *text* chunk by chunk and derive key points.

    The text is split into overlapping windows, each window is summarized
    independently, and the per-window summaries are joined into one full
    summary whose sentences become the key points.

    Args:
        text: Raw document text.

    Returns:
        DocumentAnalysis: Structured summary plus extracted key points.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)

    partial_summaries = []
    for piece in chunks:
        # Best-effort: a chunk the model rejects contributes nothing rather
        # than aborting the whole analysis.
        try:
            output = summarizer(piece, max_length=200, min_length=50, do_sample=False)
            partial_summaries.append(output[0]["summary_text"])
        except Exception:
            partial_summaries.append("")

    full_summary = " ".join(partial_summaries)
    key_points = [
        KeyPoint(point=sentence.strip())
        for sentence in full_summary.split(". ")
        if sentence.strip()
    ]
    return DocumentAnalysis(summary=Summary(summary=full_summary), key_points=key_points)
59
-
60
def json_to_text(analysis):
    """Render an analysis result as a plain-text report.

    Args:
        analysis: Object exposing ``summary.summary`` (str) and
            ``key_points``, an iterable of objects with a ``point`` attribute.

    Returns:
        str: A report with a summary section followed by numbered key points.
    """
    lines = [f"=== Summary ===\n{analysis.summary.summary}\n", "=== Key Points ==="]
    lines.extend(
        f"{idx}. {kp.point}" for idx, kp in enumerate(analysis.key_points, start=1)
    )
    return "\n".join(lines) + "\n"
66
-
67
def create_pdf_report(analysis):
    """Build a PDF report for *analysis* and return it as raw bytes."""
    report = FPDF()
    report.add_page()
    report.set_font('Helvetica', '', 12)
    # Centered title and generation timestamp header.
    report.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    report.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    report.multi_cell(0, 10, txt=json_to_text(analysis))
    # NOTE(review): passing a BytesIO together with dest='S' mixes two fpdf
    # output modes — verify against the installed fpdf version.
    buffer = io.BytesIO()
    report.output(buffer, dest='S')
    buffer.seek(0)
    return buffer.getvalue()
78
-
79
def create_word_report(analysis):
    """Build a .docx report for *analysis* and return it as raw bytes."""
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    report.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    buffer = io.BytesIO()
    report.save(buffer)
    buffer.seek(0)
    return buffer.getvalue()
89
 
90
# === Streamlit UI ===
# NOTE(review): the nesting below is reconstructed from a diff rendering
# that lost indentation — confirm against the original file.
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE-small Embeddings + RoBERTa QA")

# Initialize every session key once; "messages" is the chat log, so it
# starts as an empty list rather than None.
for state_key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
    if state_key not in st.session_state:
        st.session_state[state_key] = [] if state_key == "messages" else None

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # A different upload invalidates every cached analysis artifact.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for state_key in ["pdf_summary", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[state_key] = [] if state_key == "messages" else None

    text = extract_text_from_pdf(uploaded_file)

    if st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            analysis = analyze_text_structured(text)
            st.session_state.pdf_summary = analysis

            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]
            st.session_state.vectorstore = FAISS.from_documents(docs, embedding_function)

            st.session_state.pdf_report = create_pdf_report(analysis)
            st.session_state.word_report = create_word_report(analysis)

        st.session_state.analysis_time = time.time() - start_time
        st.subheader("Analysis Results")
        st.text(json_to_text(analysis))

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )

    if st.session_state.vectorstore is not None:
        st.subheader("Chat with the Document")

        # Replay the running conversation on every rerun.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])

        if prompt := st.chat_input("Ask a question about the document"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Searching..."):
                    docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                    context = "\n".join(doc.page_content for doc in docs)
                    answer = qa_pipeline({"question": prompt, "context": context})["answer"]
                    st.markdown(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})

if st.session_state.analysis_time is not None:
    st.markdown(
        f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE-small v1.5 | QA: RoBERTa-SQuAD2</div>',
        unsafe_allow_html=True
    )
 
1
  import streamlit as st
2
+ from transformers import pipeline
 
 
 
 
 
 
 
 
3
  from langchain_community.vectorstores import FAISS
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
  from langchain_core.documents import Document as LCDocument
7
+ import PyPDF2
8
+ from docx import Document as DocxDocument
9
+ import io
10
+ from typing import List
11
+ from pydantic import BaseModel
12
+ import tempfile
13
 
14
+
15
# === Model loading ===
# Streamlit reruns the whole script on every widget interaction; without
# caching, all three models would be reloaded each time. st.cache_resource
# memoizes them for the lifetime of the server process. The module-level
# names are unchanged so the rest of the file works as before.

@st.cache_resource
def _load_summarizer():
    """Load the BART summarization pipeline (cached across reruns)."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


@st.cache_resource
def _load_qa_pipeline():
    """Load the extractive QA pipeline (cached across reruns)."""
    return pipeline("question-answering", model="deepset/roberta-base-squad2")


@st.cache_resource
def _load_embeddings():
    """Load the sentence-embedding model used for FAISS retrieval."""
    return HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")


summarizer = _load_summarizer()
qa_pipeline = _load_qa_pipeline()
embedding_function = _load_embeddings()
 
 
 
23
 
 
 
 
24
 
25
# === Pydantic Models ===
class Summary(BaseModel):
    """Holds the generated summary text."""

    summary: str


class KeyPoint(BaseModel):
    """One noteworthy sentence extracted from the summary."""

    point: str


class DocumentAnalysis(BaseModel):
    """Full analysis result: the summary plus its derived key points."""

    summary: Summary
    key_points: List[KeyPoint]
35
 
 
 
 
36
 
37
# === Loaders ===
def load_pdf(file):
    """Extract the text of every page of a PDF.

    Args:
        file: A binary file-like object readable by ``PyPDF2.PdfReader``.

    Returns:
        str: Concatenated page text. Pages with no extractable text
        (e.g. scanned images) contribute ``""``.
    """
    reader = PyPDF2.PdfReader(file)
    # BUG FIX: extract_text() can return None for image-only pages, which
    # crashed the old `text += page.extract_text()` loop with TypeError;
    # `or ""` guards that, and join() also avoids quadratic string +=.
    return "".join(page.extract_text() or "" for page in reader.pages)
44
+
45
def load_docx(file):
    """Return the text of a .docx file, one paragraph per line."""
    document = DocxDocument(file)
    paragraphs = (para.text for para in document.paragraphs)
    return "\n".join(paragraphs)
48
+
49
+
50
# === Analysis ===
def analyze_text_structured(text):
    """Summarize *text* and derive key points from the summary.

    Long inputs are split into chunks before summarization because
    facebook/bart-large-cnn accepts roughly 1024 input tokens; feeding the
    raw text of a multi-page document in one call truncates it or fails.
    For short texts (a single chunk) the behavior is identical to a single
    summarizer call.

    Args:
        text: Raw document text.

    Returns:
        DocumentAnalysis: summary plus key points (the summary's sentences).
    """
    # ~2000 chars stays safely under the model's token limit.
    splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
    parts = []
    for chunk in splitter.split_text(text):
        out = summarizer(chunk, max_length=200, min_length=50, do_sample=False)
        parts.append(out[0]["summary_text"])
    result = " ".join(parts)
    key_points = [KeyPoint(point=line.strip()) for line in result.split(". ") if line.strip()]
    return DocumentAnalysis(summary=Summary(summary=result), key_points=key_points)
55
+
56
# === Embedding & Retrieval ===
def get_vectorstore_from_text(text):
    """Chunk *text* and index the chunks in an in-memory FAISS store."""
    chunker = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    pieces = chunker.split_text(text)
    documents = [LCDocument(page_content=piece) for piece in pieces]
    return FAISS.from_documents(documents, embedding_function)
62
+
63
def answer_question(vectorstore, question):
    """Answer *question* from the three most relevant indexed chunks.

    Args:
        vectorstore: FAISS store built by ``get_vectorstore_from_text``.
        question: The user's natural-language question.

    Returns:
        str: The extractive answer span chosen by the QA model.
    """
    # similarity_search replaces retriever.get_relevant_documents, which is
    # deprecated in recent LangChain releases; k=3 behavior is the same.
    docs = vectorstore.similarity_search(question, k=3)
    context = "\n".join(doc.page_content for doc in docs)
    result = qa_pipeline(question=question, context=context)
    return result["answer"]
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
# === Streamlit UI ===
st.title("📄 AI Document Analyzer")

uploaded_file = st.file_uploader("Upload a document (PDF or DOCX)", type=["pdf", "docx"])
input_text = st.text_area("Or paste your text here", height=200)

if st.button("Analyze"):
    if uploaded_file:
        file_bytes = uploaded_file.read()
        # lower() so "Report.PDF" is handled like "report.pdf".
        file_ext = uploaded_file.name.split(".")[-1].lower()
        if file_ext == "pdf":
            text = load_pdf(io.BytesIO(file_bytes))
        elif file_ext == "docx":
            text = load_docx(io.BytesIO(file_bytes))
        else:
            st.error("Unsupported file format.")
            st.stop()
    elif input_text:
        text = input_text
    else:
        st.warning("Please upload a file or paste text.")
        st.stop()

    with st.spinner("Analyzing..."):
        # BUG FIX: results must survive Streamlit's rerun on the next widget
        # interaction. Previously they were locals inside this button branch,
        # so typing a question re-ran the script with the button False and
        # the QA section (and its vectorstore) disappeared — QA never worked.
        st.session_state.analysis = analyze_text_structured(text)
        st.session_state.vectorstore = get_vectorstore_from_text(text)

# Render results (and the QA box) whenever an analysis exists, independent
# of the Analyze button's transient state.
if st.session_state.get("analysis") is not None:
    analysis = st.session_state.analysis

    st.subheader("🔍 Summary")
    st.write(analysis.summary.summary)

    st.subheader("📌 Key Points")
    for point in analysis.key_points:
        st.markdown(f"- {point.point}")

    st.subheader(" Ask a Question")
    user_question = st.text_input("What do you want to know?")
    if user_question:
        with st.spinner("Searching for an answer..."):
            answer = answer_question(st.session_state.vectorstore, user_question)
        st.success(f"💬 Answer: {answer}")