# NuExtract-large / app.py
# amrita06's picture
# modified filea
# fd622e0
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from pymongo import MongoClient
from neo4j import GraphDatabase
from transformers import BertTokenizer, BertModel
import torch
import fitz # PyMuPDF
import uuid
import tempfile
import os
import logging
from pydantic import BaseModel
from typing import List
app = FastAPI()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Get environment variables
MONGO_DB_URL=os.getenv('MONGO_DB_URL')
NEO_DB_HOST = os.getenv('NEO_DB_HOST')
NEO_DB_USER = os.getenv('NEO_DB_USER')
NEO_DB_PASSWORD = os.getenv('NEO_DB_PASSWORD')
# MongoDB setup
mongo_client = MongoClient(MONGO_DB_URL)
db = mongo_client["pdf_db"]
chunks_collection = db["chunks"]
# Neo4j setup
neo4j_driver = GraphDatabase.driver(NEO_DB_HOST, auth=(NEO_DB_USER, NEO_DB_PASSWORD))
# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
class ChunkEmbedding(BaseModel):
chunk_id: str
text: str
embedding: List[float]
doc_id: str
# Utility function to create embeddings
def get_embeddings(text: str):
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
outputs = model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1).flatten().tolist()
return embeddings
# Utility function to save relationship in Neo4j
def save_relationship(doc_id: str, chunk_ids: List[str]):
with neo4j_driver.session() as session:
session.run("CREATE (d:Document {id: $doc_id})", doc_id=doc_id)
for chunk_id in chunk_ids:
session.run("CREATE (c:Chunk {id: $chunk_id})", chunk_id=chunk_id)
session.run("MATCH (d:Document {id: $doc_id}), (c:Chunk {id: $chunk_id}) "
"CREATE (d)-[:CONTAINS]->(c)", doc_id=doc_id, chunk_id=chunk_id)
# Utility function to extract text from a PDF file
def extract_text_from_pdf(file_path: str) -> str:
text = ""
pdf_document = fitz.open(file_path)
for page_num in range(len(pdf_document)):
page = pdf_document.load_page(page_num)
text += page.get_text()
return text
# Utility function to break text into chunks
def break_text_into_chunks(text: str, chunk_size: int = 512) -> List[str]:
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
@app.post("/upload/")
async def upload(file: UploadFile = File(...), chunk_size: int = 512):
try:
doc_id = str(uuid.uuid4())
chunk_ids = []
# Create a temporary file to handle the PDF
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
temp_file.write(await file.read())
temp_file_path = temp_file.name
# Extract text from PDF
text = extract_text_from_pdf(file_path=temp_file_path)
logger.info("Extracted text from PDF")
# Break text into chunks
chunks = break_text_into_chunks(text, chunk_size)
logger.info(f"Text broken into {len(chunks)} chunks")
# Insert chunks and embeddings into MongoDB
for chunk in chunks:
chunk_id = str(uuid.uuid4())
chunk_ids.append(chunk_id)
embedding = get_embeddings(chunk)
chunk_embedding = ChunkEmbedding(chunk_id=chunk_id, text=chunk, embedding=embedding, doc_id=doc_id)
chunks_collection.insert_one(chunk_embedding.dict())
logger.info(f"Inserted chunk {chunk_id} into MongoDB")
# Save relationships in Neo4j
save_relationship(doc_id, chunk_ids)
logger.info(f"Saved relationships in Neo4j for doc_id {doc_id}")
# Clean up temporary file
os.remove(temp_file_path)
return JSONResponse(content={"doc_id": doc_id, "chunk_ids": chunk_ids})
except Exception as e:
logger.error(f"Error processing file upload: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)