# NOTE: HuggingFace Spaces page scrape residue (status banner, file size,
# git blob hashes, line-number gutter) removed so the module parses.
# from annotated_types import doc
from flask import Flask, request, jsonify
from flask_cors import CORS
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo import ASCENDING
from pymongo.operations import SearchIndexModel
from fastapi.middleware.cors import CORSMiddleware
import fitz
from dotenv import load_dotenv
# import numpy as np
# import pytesseract
# from pdf2image import convert_from_bytes
import img2pdf
# from PIL import Image
from google import genai
# import time
import io
import os
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
load_dotenv()
app = Flask(__name__)
CORS(app)
ocr_model = ocr_predictor(pretrained=True)
print("model downloaded")
client = MongoClient(os.getenv("MONGO_URI"))
db = client["NaviQ"]
collection: Collection = db["rag_db"]
collection.create_index([("organization_id", ASCENDING)])
# The embedding model
api_key = os.getenv("GEMINI_API_KEY")
genai_client = genai.Client(api_key=api_key)
def get_embedding(data):
    """Return the embedding vector (list of floats) for *data*.

    Uses Gemini's ``gemini-embedding-001`` model via the module-level client.
    """
    response = genai_client.models.embed_content(
        model="gemini-embedding-001",
        contents=data,
    )
    return response.embeddings[0].values
def getChunks(text, chunk_size, overlap):
    """Split *text* (a list of langchain Documents) into overlapping chunks.

    Parameters:
        text: list of Document objects to split.
        chunk_size: target chunk size in characters.
        overlap: number of characters shared between adjacent chunks.

    Returns:
        list of Document chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )
    return splitter.split_documents(text)
def get_query_results(org_id, query):
    """Run a MongoDB Atlas $vectorSearch over one organization's chunks.

    Parameters:
        org_id: organization identifier used to pre-filter documents.
        query: natural-language query string; embedded before searching.

    Returns:
        list[dict]: up to 5 documents of shape ``{"text": ...}``.
    """
    query_embedding = get_embedding(query)
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                # exact=True requests an exhaustive (ENN) search — fine for
                # small collections; switch to ANN with numCandidates as the
                # corpus grows.
                "exact": True,
                "limit": 5,
                # Restrict hits to the caller's tenant.
                "filter": {"organization_id": org_id},
            }
        },
        {
            "$project": {
                "_id": 0,
                "text": 1,
            }
        },
    ]
    # Materialize the cursor directly instead of a manual append loop.
    return list(collection.aggregate(pipeline=pipeline))
def extract_text_from_doctr(result):
    """Flatten a docTR OCR result into plain text.

    Each OCR line becomes one output line: its word values joined by single
    spaces, terminated by a newline.

    Parameters:
        result: docTR document result exposing ``export()`` -> nested dict
            of pages -> blocks -> lines -> words ``[{"value": str}, ...]``.

    Returns:
        str: recognized text, one OCR line per text line ("" for no pages).
    """
    export = result.export()
    # Collect per-line strings and join once at the end; repeated string
    # concatenation in a loop is quadratic in the worst case.
    lines = [
        " ".join(word["value"] for word in line["words"])
        for page in export["pages"]
        for block in page["blocks"]
        for line in block["lines"]
    ]
    # A trailing "\n" after every line matches the original behavior.
    return "".join(line + "\n" for line in lines)
@app.route("/upload", methods=["POST"])
def upload_file():
organization_id = request.form.get("organization_id")
file = request.files.get("file")
if not file or not organization_id:
return jsonify({"error": "Missing file or organization_id"}), 400
contents = file.read()
doc = fitz.open(stream=contents, filetype="pdf")
print(doc)
text = ""
# Here the case 1
if file.filename.lower().endswith(".pdf"):
for page in doc:
text += page.get_text()
if text.strip() == "":
# Here I will use OCR
ocr_doc = DocumentFile.from_pdf(io.BytesIO(contents))
result = ocr_model(ocr_doc)
text = extract_text_from_doctr(result)
else:
pdf_bytes = img2pdf.convert(contents)
ocr_doc = DocumentFile.from_pdf(io.BytesIO(pdf_bytes))
result = ocr_model(ocr_doc)
text = extract_text_from_doctr(result)
print(text)
# return text
doc_obj = [Document(page_content=text)]
documents = getChunks(doc_obj, 400, 20)
# print(documents)
docs_to_insert = [{
"organization_id": organization_id, # app-write id
"text": d.page_content,
"embedding": get_embedding(d.page_content)
} for d in documents]
collection.insert_many(docs_to_insert)
index_name="vector_index"
search_index_model = SearchIndexModel(
definition = {
"fields": [
{
"type": "vector",
"numDimensions": 3072,
"path": "embedding",
"similarity": "cosine"
},
{
"type": "filter",
"path": "organization_id"
}
]
},
name = index_name,
type = "vectorSearch"
)
# collection.create_search_index(model=search_index_model)
try:
collection.create_search_index(model=search_index_model)
except Exception:
pass
return {"message": "File uploaded successfully"}
@app.route("/query", methods=["GET"])
def query():
organization_id = request.args.get("organization_id")
question = request.args.get("question")
context_docs = get_query_results(organization_id, question)
context_string = " ".join([doc["text"] for doc in context_docs])
prompt = f"""Use the following pieces of context to answer the question at the end.
{context_string}
Question: {question}
"""
response = genai_client.models.generate_content(model='gemini-2.5-flash', contents=prompt)
return jsonify({"answer": response.text})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860) |