theerasin committed on
Commit
2e80e49
·
verified ·
1 Parent(s): 5a30f3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -89
app.py CHANGED
@@ -1,101 +1,159 @@
1
- # app.py
2
-
3
- from fastapi import FastAPI, UploadFile, File
4
- from pydantic import BaseModel
5
  from typing import List
6
- import fitz # PyMuPDF
7
- from transformers import pipeline
8
- from sentence_transformers import SentenceTransformer
9
- from langchain.vectorstores import FAISS
10
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
- from langchain.schema import Document
13
- from langchain.chains.question_answering import load_qa_chain
14
- from langchain.llms import HuggingFacePipeline
15
- from langchain_core.documents import Document as LangchainDocument
16
-
17
- # --- Init FastAPI ---
18
- app = FastAPI()
19
 
20
- # --- Summarizer ---
21
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
22
 
23
- # --- Question Answering ---
24
- qa_pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
25
 
26
- # --- Embedding model ---
27
- embedding_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")
 
28
 
29
- # --- Text Splitter ---
30
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
 
31
 
32
- # --- Pydantic schemas ---
33
  class Summary(BaseModel):
34
- summary: str
35
-
36
- class KeyPoint(BaseModel):
37
- point: str
38
 
39
  class DocumentAnalysis(BaseModel):
40
- summary: Summary
41
  key_points: List[KeyPoint]
 
42
 
43
- class QARequest(BaseModel):
44
- question: str
45
- context: str
46
-
47
- class QAResponse(BaseModel):
48
- answer: str
49
-
50
- # --- PDF Text Extractor ---
51
- def extract_text_from_pdf(pdf_file: UploadFile) -> str:
52
- text = ""
53
- with fitz.open(stream=pdf_file.file.read(), filetype="pdf") as doc:
54
- for page in doc:
55
- text += page.get_text()
56
- return text
57
-
58
- # --- Analyze Text (summarization) ---
59
- def analyze_text_structured(text: str) -> DocumentAnalysis:
60
- chunks = text_splitter.split_text(text)
61
- summaries = []
62
- for chunk in chunks:
63
- result = summarizer(chunk, max_length=200, min_length=50, do_sample=False)
64
- if result:
65
- summaries.append(result[0]["summary_text"])
66
-
67
- full_summary = " ".join(summaries)
68
- key_points = [KeyPoint(point=line.strip()) for line in full_summary.split(". ") if line.strip()]
69
- return DocumentAnalysis(summary=Summary(summary=full_summary), key_points=key_points)
70
-
71
- # --- Question Answering ---
72
- def answer_question(question: str, context: str) -> str:
73
- result = qa_pipe(question=question, context=context)
74
- return result["answer"]
75
-
76
- # --- PDF Upload + Analysis Route ---
77
- @app.post("/analyze-pdf", response_model=DocumentAnalysis)
78
- async def analyze_pdf(file: UploadFile = File(...)):
79
- text = extract_text_from_pdf(file)
80
- analysis = analyze_text_structured(text)
81
- return analysis
82
-
83
- # --- Question Answering Route ---
84
- @app.post("/qa", response_model=QAResponse)
85
- async def ask_question(qa_request: QARequest):
86
- answer = answer_question(qa_request.question, qa_request.context)
87
- return QAResponse(answer=answer)
88
-
89
- # --- Embedding Search (FAISS) Demo ---
90
- @app.post("/search-chunks")
91
- async def search_chunks(file: UploadFile = File(...), query: str = ""):
92
- text = extract_text_from_pdf(file)
93
- chunks = text_splitter.split_text(text)
94
- documents = [LangchainDocument(page_content=chunk) for chunk in chunks]
95
-
96
- # Create FAISS vector store
97
- db = FAISS.from_documents(documents, embedding_model)
98
-
99
- # Similarity search
100
- results = db.similarity_search(query, k=3)
101
- return {"results": [doc.page_content for doc in results]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ from pydantic import BaseModel, Field
 
4
  from typing import List
5
+ from datetime import datetime
6
+ import PyPDF2
7
+ from fpdf import FPDF
8
+ from docx import Document
9
+ import io
10
  from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_core.documents import Document as LCDocument
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ import time
 
 
 
15
 
16
# === Load summarization model ===
# BART fine-tuned on CNN/DailyMail, used for abstractive summarization.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# === Load QA pipeline ===
# BUG FIX: facebook/bart-large-cnn is a summarization checkpoint with no
# question-answering head, so the "question-answering" pipeline cannot use
# it (and pairing it with the BART tokenizer compounds the mismatch).
# Use an extractive QA model instead — the SQuAD2-tuned RoBERTa checkpoint
# this app used previously; the pipeline loads its matching tokenizer.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

# === Setup BGE Embedding model ===
embedding_model_name = "BAAI/bge-large-en-v1.5"
embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
26
 
27
# === Data models ===
class KeyPoint(BaseModel):
    # One bullet-style takeaway; produced by splitting the generated summary
    # into sentences in analyze_text_structured().
    point: str = Field(description="A key point extracted from the document.")


class Summary(BaseModel):
    # The abstractive summary text generated by the BART model.
    summary: str = Field(description="A brief summary of the document content.")


class DocumentAnalysis(BaseModel):
    # Full analysis result returned by analyze_text_structured():
    # the per-sentence key points plus the overall summary.
    key_points: List[KeyPoint]
    summary: Summary
37
 
38
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page of *pdf_file*.

    *pdf_file* is a binary file-like object (e.g. a Streamlit UploadedFile).

    BUG FIX: ``PageObject.extract_text`` can return ``None`` for pages with
    no extractable text (e.g. scanned images); fall back to "" for those
    pages so the join does not raise ``TypeError``.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
41
+
42
def analyze_text_structured(text):
    """Summarize *text* with BART and derive key points from the summary.

    Returns a DocumentAnalysis whose key points are the individual
    sentences of the generated summary.
    NOTE(review): BART's encoder caps input at 1024 tokens, so only the
    beginning of a long document is summarized — confirm this is intended.
    """
    encoded = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    generated_ids = model.generate(
        encoded["input_ids"],
        num_beams=4,
        length_penalty=2.0,
        max_length=200,
        min_length=50,
        early_stopping=True,
    )
    summary_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # One KeyPoint per non-empty sentence of the summary.
    points = [
        KeyPoint(point=sentence.strip())
        for sentence in summary_text.split(". ")
        if sentence.strip()
    ]
    return DocumentAnalysis(summary=Summary(summary=summary_text), key_points=points)
51
+
52
def json_to_text(analysis):
    """Render a DocumentAnalysis as a plain-text report.

    Layout: a "=== Summary ===" section, a blank line, then a
    "=== Key Points ===" section with one numbered point per line.
    The returned string always ends with a newline.
    """
    lines = [
        "=== Summary ===",
        analysis.summary.summary,
        "",
        "=== Key Points ===",
    ]
    lines.extend(
        f"{number}. {key_point.point}"
        for number, key_point in enumerate(analysis.key_points, start=1)
    )
    return "\n".join(lines) + "\n"
58
+
59
def create_pdf_report(analysis):
    """Build a PDF version of the analysis report and return it as bytes."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font('Helvetica', '', 12)
    pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
    pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
    # BUG FIX: FPDF's built-in Helvetica font only covers latin-1; model
    # output may contain characters outside that range, which makes
    # multi_cell raise. Replace unencodable characters instead of crashing.
    body = json_to_text(analysis).encode('latin-1', 'replace').decode('latin-1')
    pdf.multi_cell(0, 10, txt=body)
    pdf_bytes = io.BytesIO()
    # NOTE(review): dest='S' is legacy PyFPDF API; fpdf2 writes to the
    # stream passed as the first argument and deprecates dest — confirm
    # which fpdf package is installed.
    pdf.output(pdf_bytes, dest='S')
    pdf_bytes.seek(0)
    return pdf_bytes.getvalue()
70
+
71
def create_word_report(analysis):
    """Build a Word (.docx) version of the analysis report, returned as bytes."""
    report = Document()
    report.add_heading('PDF Analysis Report', 0)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    report.add_paragraph(f'Generated on: {timestamp}')
    report.add_heading('Analysis', level=1)
    report.add_paragraph(json_to_text(analysis))
    # Serialize the document into an in-memory buffer and hand back raw bytes.
    buffer = io.BytesIO()
    report.save(buffer)
    return buffer.getvalue()
81
+
82
# === Streamlit UI ===
# Flat Streamlit script: upload a PDF, analyze it (summary + key points +
# FAISS index + downloadable reports), then chat against the index.
# NOTE(review): indentation reconstructed from control flow — the original
# paste carried no whitespace; confirm nesting against the repository file.
st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
st.title("📄 Chat With PDF")
st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")

# Initialize session state: "messages" (chat history) starts as a list,
# every other key starts as None.
for key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
    if key not in st.session_state:
        st.session_state[key] = None if key != "messages" else []

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # A different file was uploaded: drop all per-document state.
    if st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        for key in ["pdf_summary", "pdf_report", "word_report", "vectorstore", "messages"]:
            st.session_state[key] = None if key != "messages" else []

    text = extract_text_from_pdf(uploaded_file)

    if st.button("Analyze Text"):
        start_time = time.time()
        with st.spinner("Analyzing..."):
            # Summarize and extract key points.
            analysis = analyze_text_structured(text)
            st.session_state.pdf_summary = analysis

            # Chunk the raw text and build the FAISS index used for chat.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            chunks = text_splitter.split_text(text)
            docs = [LCDocument(page_content=chunk) for chunk in chunks]

            st.session_state.vectorstore = FAISS.from_documents(docs, embedding_function)

            # Pre-render both downloadable reports.
            st.session_state.pdf_report = create_pdf_report(analysis)
            st.session_state.word_report = create_word_report(analysis)

        st.session_state.analysis_time = time.time() - start_time
        st.subheader("Analysis Results")
        st.text(json_to_text(analysis))

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="Download PDF Report",
                data=st.session_state.pdf_report,
                file_name="analysis_report.pdf",
                mime="application/pdf"
            )
        with col2:
            st.download_button(
                label="Download Word Report",
                data=st.session_state.word_report,
                file_name="analysis_report.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )

# Chat section — shown only once a vector index exists for the document.
if st.session_state.vectorstore is not None:
    st.subheader("Chat with the Document")

    # Replay prior conversation turns on every rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question about the document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                # Retrieve the 3 most similar chunks and answer over them.
                docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
                context = "\n".join([doc.page_content for doc in docs])
                answer = qa_pipeline({"question": prompt, "context": context})["answer"]
                st.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})

if st.session_state.analysis_time is not None:
    st.markdown(
        f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Large v1.5</div>',
        unsafe_allow_html=True
    )