Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pymongo import MongoClient | |
| from neo4j import GraphDatabase | |
| from transformers import BertTokenizer, BertModel | |
| import torch | |
| import fitz # PyMuPDF | |
| import uuid | |
| import tempfile | |
| import os | |
| import logging | |
| from pydantic import BaseModel | |
| from typing import List | |
# FastAPI application object; it is handed to uvicorn at the bottom of the file.
app = FastAPI()

# Logging: INFO level on the root logger, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Connection settings come from the environment. A variable that is not set
# yields None here; nothing in this module checks for that.
MONGO_DB_URL = os.environ.get("MONGO_DB_URL")
NEO_DB_HOST = os.environ.get("NEO_DB_HOST")
NEO_DB_USER = os.environ.get("NEO_DB_USER")
NEO_DB_PASSWORD = os.environ.get("NEO_DB_PASSWORD")

# MongoDB: chunk records are written to the "chunks" collection of "pdf_db".
mongo_client = MongoClient(MONGO_DB_URL)
db = mongo_client["pdf_db"]
chunks_collection = db["chunks"]

# Neo4j: one driver for the whole process, authenticated with user/password.
neo4j_driver = GraphDatabase.driver(
    NEO_DB_HOST,
    auth=(NEO_DB_USER, NEO_DB_PASSWORD),
)

# Pre-trained BERT tokenizer and model, loaded once at import time and shared
# by every embedding call.
_BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
model = BertModel.from_pretrained(_BERT_CHECKPOINT)
class ChunkEmbedding(BaseModel):
    """Record describing one text chunk together with its embedding."""

    # Identifier of this chunk.
    chunk_id: str
    # Text content of the chunk.
    text: str
    # Embedding vector for the text, as plain floats.
    embedding: List[float]
    # Identifier of the document the chunk belongs to.
    doc_id: str
def get_embeddings(text: str) -> List[float]:
    """Return a mean-pooled BERT embedding for *text*.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    at most 512 tokens) and run through the module-level ``model``. The last
    hidden state is averaged over dimension 1 and flattened.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a flat list of floats.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    # Inference only: without no_grad() every call records an autograd graph
    # that is never used, which costs memory and time for each chunk.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).flatten().tolist()
def save_relationship(doc_id: str, chunk_ids: List[str]) -> None:
    """Create a Document node, its Chunk nodes and the CONTAINS edges in Neo4j.

    Everything is written by a single Cypher statement, so the document and
    all of its chunks are created together in one round-trip instead of one
    auto-commit query per node and per relationship.

    Args:
        doc_id: Identifier stored as the ``id`` property of the Document node.
        chunk_ids: Identifiers stored as the ``id`` property of each Chunk
            node. An empty list still creates the Document node.
    """
    # UNWIND turns the list parameter into one row per chunk id; the Document
    # node is created once beforehand and carried along with WITH.
    query = (
        "CREATE (d:Document {id: $doc_id}) "
        "WITH d "
        "UNWIND $chunk_ids AS chunk_id "
        "CREATE (d)-[:CONTAINS]->(:Chunk {id: chunk_id})"
    )
    with neo4j_driver.session() as session:
        result = session.run(query, doc_id=doc_id, chunk_ids=list(chunk_ids))
        # Consume explicitly so a server-side failure is raised here, to the
        # caller, rather than being dealt with when the session closes.
        result.consume()
def extract_text_from_pdf(file_path: str) -> str:
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        The text of all pages joined together, with no separator added
        between pages.
    """
    # The context manager closes the document even if a page fails to parse;
    # an open handle would otherwise outlive this call and can block deleting
    # the file on platforms that lock open files.
    with fitz.open(file_path) as pdf_document:
        return "".join(page.get_text() for page in pdf_document)
def break_text_into_chunks(text: str, chunk_size: int = 512) -> List[str]:
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be positive.

    Returns:
        The chunks in order. Every chunk except possibly the last has exactly
        ``chunk_size`` characters; empty text yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    # A negative step would make range() empty and silently drop all text,
    # and a zero step raises an unrelated-looking range() error.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
# NOTE(review): the handler was previously defined without ever being
# registered on `app`, so the service exposed no route. The "/upload" path is
# an assumption taken from the function name -- confirm against the clients.
@app.post("/upload")
async def upload(file: UploadFile = File(...), chunk_size: int = 512):
    """Ingest an uploaded PDF: chunk its text, embed each chunk and store it.

    The upload is written to a temporary file and its text is extracted and
    split into chunks of ``chunk_size`` characters. Each chunk is embedded and
    inserted into MongoDB, then the document-to-chunk relationships are saved
    in Neo4j.

    Args:
        file: The uploaded PDF.
        chunk_size: Maximum number of characters per chunk; must be positive.

    Returns:
        A JSON response with the generated ``doc_id`` and the list of
        ``chunk_ids``.

    Raises:
        HTTPException: 400 if ``chunk_size`` is not positive; 500 if any
            processing step fails.
    """
    # Reject bad input up front, outside the try block, so it is reported as
    # a client error instead of being converted into a 500 below.
    if chunk_size <= 0:
        raise HTTPException(status_code=400, detail="chunk_size must be a positive integer")

    doc_id = str(uuid.uuid4())
    chunk_ids = []
    temp_file_path = None
    try:
        # Create a temporary file to handle the PDF. The path is recorded
        # before writing so the cleanup below also covers a failed write.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Extract text from PDF
        text = extract_text_from_pdf(file_path=temp_file_path)
        logger.info("Extracted text from PDF")

        # Break text into chunks
        chunks = break_text_into_chunks(text, chunk_size)
        logger.info(f"Text broken into {len(chunks)} chunks")

        # Insert chunks and embeddings into MongoDB
        for chunk in chunks:
            chunk_id = str(uuid.uuid4())
            chunk_ids.append(chunk_id)
            embedding = get_embeddings(chunk)
            chunk_embedding = ChunkEmbedding(chunk_id=chunk_id, text=chunk, embedding=embedding, doc_id=doc_id)
            chunks_collection.insert_one(chunk_embedding.dict())
            logger.info(f"Inserted chunk {chunk_id} into MongoDB")

        # Save relationships in Neo4j
        save_relationship(doc_id, chunk_ids)
        logger.info(f"Saved relationships in Neo4j for doc_id {doc_id}")

        return JSONResponse(content={"doc_id": doc_id, "chunk_ids": chunk_ids})
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error processing file upload: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up the temporary file on success and on failure alike; it was
        # previously left behind whenever any step above raised.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                logger.warning(f"Could not remove temporary file {temp_file_path}")
# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)