theerasin committed on
Commit
5a30f3b
·
verified ·
1 Parent(s): 92722f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -85
app.py CHANGED
@@ -1,28 +1,35 @@
1
- import streamlit as st
 
 
 
 
 
2
  from transformers import pipeline
3
- from langchain_community.vectorstores import FAISS
4
- from langchain.embeddings import HuggingFaceEmbeddings
 
5
  from langchain_text_splitters import RecursiveCharacterTextSplitter
6
- from langchain_core.documents import Document as LCDocument
7
- import PyPDF2
8
- from docx import Document as DocxDocument
9
- import io
10
- from typing import List
11
- from pydantic import BaseModel
12
- import tempfile
13
 
 
 
14
 
15
- # === Summarizer ===
16
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
17
 
18
- # === QA Model ===
19
- qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
20
 
21
- # === Embedding model ===
22
- embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
23
 
 
 
24
 
25
- # === Pydantic Models ===
26
  class Summary(BaseModel):
27
  summary: str
28
 
@@ -33,78 +40,62 @@ class DocumentAnalysis(BaseModel):
33
  summary: Summary
34
  key_points: List[KeyPoint]
35
 
 
 
 
 
 
 
36
 
37
- # === Loaders ===
38
- def load_pdf(file):
39
- reader = PyPDF2.PdfReader(file)
40
  text = ""
41
- for page in reader.pages:
42
- text += page.extract_text()
 
43
  return text
44
 
45
- def load_docx(file):
46
- doc = DocxDocument(file)
47
- return "\n".join([para.text for para in doc.paragraphs])
48
-
49
-
50
- # === Analysis ===
51
- def analyze_text_structured(text):
52
- result = summarizer(text, max_length=200, min_length=50, do_sample=False)[0]["summary_text"]
53
- key_points = [KeyPoint(point=line.strip()) for line in result.split(". ") if line.strip()]
54
- return DocumentAnalysis(summary=Summary(summary=result), key_points=key_points)
55
-
56
- # === Embedding & Retrieval ===
57
- def get_vectorstore_from_text(text):
58
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
59
- chunks = splitter.split_text(text)
60
- docs = [LCDocument(page_content=chunk) for chunk in chunks]
61
- return FAISS.from_documents(docs, embedding_function)
62
-
63
- def answer_question(vectorstore, question):
64
- retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
65
- docs = retriever.get_relevant_documents(question)
66
- context = "\n".join([doc.page_content for doc in docs])
67
- result = qa_pipeline(question=question, context=context)
68
  return result["answer"]
69
 
70
-
71
- # === Streamlit UI ===
72
- st.title("📄 AI Document Analyzer")
73
-
74
- uploaded_file = st.file_uploader("Upload a document (PDF or DOCX)", type=["pdf", "docx"])
75
- input_text = st.text_area("Or paste your text here", height=200)
76
-
77
- if st.button("Analyze"):
78
- if uploaded_file:
79
- file_bytes = uploaded_file.read()
80
- file_ext = uploaded_file.name.split(".")[-1]
81
- if file_ext == "pdf":
82
- text = load_pdf(io.BytesIO(file_bytes))
83
- elif file_ext == "docx":
84
- text = load_docx(io.BytesIO(file_bytes))
85
- else:
86
- st.error("Unsupported file format.")
87
- st.stop()
88
- elif input_text:
89
- text = input_text
90
- else:
91
- st.warning("Please upload a file or paste text.")
92
- st.stop()
93
-
94
- with st.spinner("Analyzing..."):
95
- analysis = analyze_text_structured(text)
96
- vectorstore = get_vectorstore_from_text(text)
97
-
98
- st.subheader("🔍 Summary")
99
- st.write(analysis.summary.summary)
100
-
101
- st.subheader("📌 Key Points")
102
- for point in analysis.key_points:
103
- st.markdown(f"- {point.point}")
104
-
105
- st.subheader("❓ Ask a Question")
106
- user_question = st.text_input("What do you want to know?")
107
- if user_question:
108
- with st.spinner("Searching for an answer..."):
109
- answer = answer_question(vectorstore, user_question)
110
- st.success(f"💬 Answer: {answer}")
 
1
+ # app.py
2
+
3
+ from fastapi import FastAPI, UploadFile, File
4
+ from pydantic import BaseModel
5
+ from typing import List
6
+ import fitz # PyMuPDF
7
  from transformers import pipeline
8
+ from sentence_transformers import SentenceTransformer
9
+ from langchain.vectorstores import FAISS
10
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
+ from langchain.schema import Document
13
+ from langchain.chains.question_answering import load_qa_chain
14
+ from langchain.llms import HuggingFacePipeline
15
+ from langchain_core.documents import Document as LangchainDocument
 
 
 
16
 
17
# --- Init FastAPI ---
app = FastAPI()

# --- Summarizer: abstractive summarization pipeline (BART CNN checkpoint) ---
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# --- Question Answering: extractive span-prediction pipeline ---
qa_pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")

# --- Embedding model: BGE-small embeddings, used by the FAISS search route ---
embedding_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")

# --- Text Splitter: ~800-char chunks, 100-char overlap; shared by the
# summarization path and the embedding-search route ---
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
31
 
32
# --- Pydantic schemas ---
class Summary(BaseModel):
    """Summary text produced by `analyze_text_structured`."""
    # Concatenated summary of all document chunks.
    summary: str
35
 
 
40
  summary: Summary
41
  key_points: List[KeyPoint]
42
 
43
class QARequest(BaseModel):
    """Request body for the /qa endpoint."""
    # Natural-language question to answer.
    question: str
    # Free-text context the answer is extracted from.
    context: str
47
class QAResponse(BaseModel):
    """Response body for the /qa endpoint."""
    # Answer span extracted from the supplied context.
    answer: str
49
 
50
# --- PDF Text Extractor ---
def extract_text_from_pdf(pdf_file: UploadFile) -> str:
    """Read the uploaded PDF fully into memory and return the
    concatenated plain text of every page."""
    raw = pdf_file.file.read()
    with fitz.open(stream=raw, filetype="pdf") as doc:
        return "".join(page.get_text() for page in doc)
57
 
58
# --- Analyze Text (summarization) ---
def analyze_text_structured(text: str) -> DocumentAnalysis:
    """Summarize `text` chunk-by-chunk with the BART pipeline and derive
    sentence-level key points from the combined summary."""
    partial_summaries = []
    for segment in text_splitter.split_text(text):
        output = summarizer(segment, max_length=200, min_length=50, do_sample=False)
        # Guard against an empty pipeline result before indexing [0].
        if output:
            partial_summaries.append(output[0]["summary_text"])

    combined = " ".join(partial_summaries)
    # Naive sentence split on ". "; blank fragments are dropped.
    points = [
        KeyPoint(point=sentence.strip())
        for sentence in combined.split(". ")
        if sentence.strip()
    ]
    return DocumentAnalysis(summary=Summary(summary=combined), key_points=points)
70
+
71
# --- Question Answering ---
def answer_question(question: str, context: str) -> str:
    """Run extractive QA over `context` and return the predicted answer span."""
    return qa_pipe(question=question, context=context)["answer"]
75
 
76
# --- PDF Upload + Analysis Route ---
@app.post("/analyze-pdf", response_model=DocumentAnalysis)
async def analyze_pdf(file: UploadFile = File(...)):
    """Extract the uploaded PDF's text and return its structured analysis."""
    return analyze_text_structured(extract_text_from_pdf(file))
82
+
83
# --- Question Answering Route ---
@app.post("/qa", response_model=QAResponse)
async def ask_question(qa_request: QARequest):
    """Answer the request's question against its supplied context."""
    predicted = answer_question(qa_request.question, qa_request.context)
    return QAResponse(answer=predicted)
88
+
89
# --- Embedding Search (FAISS) Demo ---
@app.post("/search-chunks")
async def search_chunks(file: UploadFile = File(...), query: str = ""):
    """Split the uploaded PDF into chunks, index them in a throwaway FAISS
    store, and return the 3 chunks most similar to `query`.

    Returns a dict: {"results": [chunk_text, ...]} (empty for a blank query).
    """
    text = extract_text_from_pdf(file)
    chunks = text_splitter.split_text(text)

    # Fix: `query` defaults to "" — embedding an empty string just returns
    # 3 arbitrary chunks; short-circuit instead of doing pointless work.
    if not query.strip():
        return {"results": []}

    documents = [LangchainDocument(page_content=chunk) for chunk in chunks]

    # NOTE(review): the FAISS index is rebuilt on every request — fine for a
    # demo; cache per-document if this endpoint sees real traffic.
    db = FAISS.from_documents(documents, embedding_model)

    results = db.similarity_search(query, k=3)
    return {"results": [doc.page_content for doc in results]}