File size: 3,392 Bytes
5a30f3b
 
 
 
 
 
3d4a265
5a30f3b
 
 
3d4a265
5a30f3b
 
 
 
02e74b1
5a30f3b
 
3d4a265
5a30f3b
1c2fbbd
02e74b1
5a30f3b
 
02e74b1
5a30f3b
 
3201029
5a30f3b
 
02e74b1
5a30f3b
02e74b1
3d4a265
 
 
 
02e74b1
 
 
3d4a265
02e74b1
5a30f3b
 
 
 
 
 
02e74b1
5a30f3b
 
3d4a265
5a30f3b
 
 
3d4a265
 
5a30f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d4a265
6289bae
5a30f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# app.py

from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
from typing import List
import fitz  # PyMuPDF
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import HuggingFacePipeline
from langchain_core.documents import Document as LangchainDocument

# --- Init FastAPI ---
app = FastAPI()

# --- Summarizer ---
# Abstractive summarization pipeline, loaded once at import time so every
# request reuses the same in-memory model weights.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# --- Question Answering ---
# Extractive QA pipeline: picks an answer span out of a provided context.
qa_pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")

# --- Embedding model ---
# Sentence embeddings used to build the per-request FAISS index in /search-chunks.
embedding_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")

# --- Text Splitter ---
# 800-char chunks with 100-char overlap; shared by summarization and search.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)

# --- Pydantic schemas ---
class Summary(BaseModel):
    """Wrapper for a single free-text summary string."""
    summary: str

class KeyPoint(BaseModel):
    """One extracted key point (a sentence from the combined summary)."""
    point: str

class DocumentAnalysis(BaseModel):
    """Response body of /analyze-pdf: overall summary plus key points."""
    summary: Summary
    key_points: List[KeyPoint]

class QARequest(BaseModel):
    """Request body of /qa: a question and the context to answer it from."""
    question: str
    context: str

class QAResponse(BaseModel):
    """Response body of /qa: the extracted answer span."""
    answer: str

# --- PDF Text Extractor ---
def extract_text_from_pdf(pdf_file: UploadFile) -> str:
    """Extract the plain text of every page in an uploaded PDF.

    Args:
        pdf_file: The uploaded file. Its underlying stream is read in full,
            so the stream is consumed after this call.

    Returns:
        The concatenated text of all pages; empty string for a PDF with no
        extractable text (e.g. a scanned, image-only document).
    """
    # Read the upload once and let PyMuPDF parse it from memory.
    # "".join(...) replaces the original repeated `text +=` accumulation,
    # which is quadratic in the number of pages.
    with fitz.open(stream=pdf_file.file.read(), filetype="pdf") as doc:
        return "".join(page.get_text() for page in doc)

# --- Analyze Text (summarization) ---
def analyze_text_structured(text: str) -> DocumentAnalysis:
    """Build a structured analysis of *text*: an overall summary plus key points.

    The text is split into overlapping chunks, each chunk is summarized
    independently, and the chunk summaries are joined into one summary.
    Key points are the individual sentences of that combined summary.

    Args:
        text: Raw document text (may be empty, yielding an empty analysis).

    Returns:
        A DocumentAnalysis with the combined summary and its key points.
    """
    summaries: List[str] = []
    for piece in text_splitter.split_text(text):
        output = summarizer(piece, max_length=200, min_length=50, do_sample=False)
        if output:
            summaries.append(output[0]["summary_text"])

    combined = " ".join(summaries)

    # Treat each ". "-separated sentence of the combined summary as a key point.
    points: List[KeyPoint] = []
    for sentence in combined.split(". "):
        stripped = sentence.strip()
        if stripped:
            points.append(KeyPoint(point=stripped))

    return DocumentAnalysis(summary=Summary(summary=combined), key_points=points)

# --- Question Answering ---
def answer_question(question: str, context: str) -> str:
    """Return the extractive answer to *question* found inside *context*.

    Args:
        question: Natural-language question.
        context: Text the answer span is extracted from.

    Returns:
        The answer string selected by the QA pipeline.
    """
    return qa_pipe(question=question, context=context)["answer"]

# --- PDF Upload + Analysis Route ---
@app.post("/analyze-pdf", response_model=DocumentAnalysis)
async def analyze_pdf(file: UploadFile = File(...)):
    """Upload a PDF and receive its structured summary plus key points."""
    raw_text = extract_text_from_pdf(file)
    return analyze_text_structured(raw_text)

# --- Question Answering Route ---
@app.post("/qa", response_model=QAResponse)
async def ask_question(qa_request: QARequest):
    """Answer a question against the caller-supplied context string."""
    extracted = answer_question(qa_request.question, qa_request.context)
    return QAResponse(answer=extracted)

# --- Embedding Search (FAISS) Demo ---
@app.post("/search-chunks")
async def search_chunks(file: UploadFile = File(...), query: str = ""):
    """Split an uploaded PDF into chunks and return those most similar to *query*.

    A transient FAISS index is built per request (nothing is persisted), then
    queried for the top 3 chunks by embedding similarity.

    Args:
        file: The uploaded PDF.
        query: Free-text search query (defaults to the empty string).

    Returns:
        {"results": [chunk_text, ...]} with up to 3 matching chunks; an empty
        list when the PDF yields no extractable text.
    """
    text = extract_text_from_pdf(file)
    chunks = text_splitter.split_text(text)

    # Guard: FAISS.from_documents raises on an empty corpus, so an image-only
    # or empty PDF would previously turn into a 500. Return "no matches" instead.
    if not chunks:
        return {"results": []}

    documents = [LangchainDocument(page_content=chunk) for chunk in chunks]

    # Build the in-memory FAISS vector store for this request only.
    db = FAISS.from_documents(documents, embedding_model)

    # Top-3 nearest chunks by embedding similarity.
    results = db.similarity_search(query, k=3)
    return {"results": [doc.page_content for doc in results]}