"""Invoice extraction API.

Fetches a PDF or image from S3, asks an OpenAI chat model to extract data that
matches a per-document-type JSON schema stored in MongoDB, and caches the
result in MongoDB keyed by ``entityrefkey``.
"""
import base64
import hashlib
import hmac
import io
import json
import logging
import os
import re
import traceback  # For detailed traceback of errors
from enum import Enum

import boto3
import fitz  # PyMuPDF for PDF handling
import openai
import uvicorn
from bson.objectid import ObjectId
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Header, HTTPException, Query
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pymongo import MongoClient

db_client = None

load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# MongoDB Configuration
MONGODB_URI = os.getenv("MONGODB_URI")
DATABASE_NAME = os.getenv("DATABASE_NAME")
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
SCHEMA = os.getenv("SCHEMA")

# Check that the required environment variables are set. The database and
# collection names are checked too: indexing the client with None would
# otherwise fail later with a much less helpful error.
if not MONGODB_URI:
    raise ValueError("MONGODB_URI is not set. Please add it to your secrets.")
for _env_name, _env_value in (
    ("DATABASE_NAME", DATABASE_NAME),
    ("COLLECTION_NAME", COLLECTION_NAME),
    ("SCHEMA", SCHEMA),
):
    if not _env_value:
        raise ValueError(f"{_env_name} is not set. Please add it to your secrets.")

# Initialize MongoDB Connection
db_client = MongoClient(MONGODB_URI)
db = db_client[DATABASE_NAME]
invoice_collection = db[COLLECTION_NAME]
schema_collection = db[SCHEMA]

# NOTE(review): with docs_url='/', the Swagger UI route for "/" is registered
# when the app is constructed, i.e. before read_root below, so read_root
# appears to be shadowed. Confirm which handler should own "/".
app = FastAPI(docs_url='/')
use_gpu = False
output_dir = 'output'

# Maximum number of PDF pages / image frames accepted for one extraction.
MAX_PAGES = 2

# Pillow modes that PNG can store directly. Other modes (e.g. CMYK JPEGs) are
# converted to RGB first; otherwise Image.save(format="PNG") rejects them.
_PNG_SAVE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

# A model reply that is entirely wrapped in a Markdown code fence
# (```json ... ``` or ``` ... ```). Group 1 is the fenced payload.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z]*\s*(.*?)\s*```\s*$", re.DOTALL)


@app.on_event("startup")
def startup_db():
    """Ping MongoDB at startup; a failure is logged but does not stop the app."""
    try:
        db_client.server_info()
        logger.info("MongoDB connection successful")
    except Exception as e:
        logger.error(f"MongoDB connection failed: {str(e)}")


# AWS S3 Configuration
API_KEY = os.getenv("API_KEY")
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")

# OpenAI Configuration
openai.api_key = os.getenv("OPENAI_API_KEY")

# S3 Client
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
)


def fetch_file_from_s3(file_key):
    """Download an object from the configured S3 bucket.

    Args:
        file_key: S3 object key inside ``S3_BUCKET_NAME``.

    Returns:
        A ``(file_bytes, content_type)`` tuple. ``content_type`` defaults to
        ``'application/octet-stream'`` when S3 reports none, matching
        ``get_content_type_from_s3``.

    Raises:
        Exception: wrapping any S3/boto3 failure.
    """
    try:
        response = s3_client.get_object(Bucket=S3_BUCKET_NAME, Key=file_key)
        content_type = response.get('ContentType', 'application/octet-stream')
        file_data = response['Body'].read()
        return file_data, content_type  # raw bytes, not a BytesIO
    except Exception as e:
        raise Exception(f"Failed to fetch file from S3: {str(e)}") from e


def _pdf_to_text(file_data):
    """Return the embedded text of a PDF with at most MAX_PAGES pages (no OCR)."""
    # The context manager closes the PyMuPDF document; the original leaked it.
    with fitz.open(stream=file_data, filetype="pdf") as doc:
        if doc.page_count > MAX_PAGES:
            raise ValueError(
                f"The PDF contains more than {MAX_PAGES} pages, extraction not supported."
            )
        return "".join(page.get_text() for page in doc)


def _image_to_base64_pngs(file_data):
    """Re-encode each frame of an image file (at most MAX_PAGES) as a Base64 PNG string."""
    encoded_frames = []
    with Image.open(io.BytesIO(file_data)) as img:
        # n_frames is only defined by multi-frame-capable plugins (TIFF, GIF,
        # ...); single-frame formats such as plain JPEG lack it.
        num_frames = getattr(img, "n_frames", 1)
        if num_frames > MAX_PAGES:
            raise ValueError(
                f"The image file contains more than {MAX_PAGES} pages, extraction not supported."
            )
        for frame_index in range(num_frames):
            img.seek(frame_index)
            frame = img if img.mode in _PNG_SAVE_MODES else img.convert("RGB")
            buffer = io.BytesIO()
            frame.save(buffer, format="PNG")
            encoded_frames.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))
    return encoded_frames


def extract_invoice_data(file_data, content_type, json_schema):
    """Extract structured invoice data from a PDF or image using OpenAI.

    For PDFs, the embedded text is read with PyMuPDF (no OCR) and sent as
    text. For images, every frame is re-encoded as PNG and sent as an image
    content part, which requires a vision-capable model.

    Args:
        file_data: Raw file bytes.
        content_type: MIME type of ``file_data``. Parameters such as
            ``; charset=...`` and letter case are ignored.
        json_schema: JSON-serialisable schema the model's answer must follow.

    Returns:
        On success, a dict holding ``"text"`` (PDF) or ``"base64_images"``
        (image), plus ``"extracted_json"``. If the OpenAI call fails or its
        reply is not valid JSON, ``{"error": <message>}`` instead.

    Raises:
        ValueError: unsupported content type, or more than MAX_PAGES
            pages/frames.
        Exception: re-raised from PDF/image decoding failures.
    """
    system_prompt = "You are an expert in document data extraction."
    extracted_data = {}
    mime_type = (content_type or "").split(";", 1)[0].strip().lower()
    schema_text = json.dumps(json_schema, indent=2)

    if mime_type == "application/pdf":
        try:
            extracted_text = _pdf_to_text(file_data)
        except Exception as e:
            logger.error(f"Error extracting text from PDF: {e}")
            raise
        extracted_data["text"] = extracted_text
        user_content = (
            "Extract the invoice data from the following PDF text. "
            f"Return only valid JSON that adheres to this schema:\n\n{schema_text}\n\n"
            f"PDF Text:\n{extracted_text}"
        )
    elif mime_type.startswith("image/"):
        try:
            base64_encoded_images = _image_to_base64_pngs(file_data)
        except Exception as e:
            logger.error(f"Error handling images: {e}")
            raise
        # NOTE(review): these Base64 images are also persisted to MongoDB by
        # the caller; large scans may approach the 16 MB BSON document limit.
        extracted_data["base64_images"] = base64_encoded_images
        prompt_text = (
            "Extract the invoice data from the following images. "
            f"Return only valid JSON that adheres to this schema:\n\n{schema_text}\n\n"
        )
        # Images go in as image_url parts so the model actually sees them,
        # rather than as megabytes of Base64 text inside the prompt. The
        # frames were re-encoded as PNG, so the data URL says image/png
        # whatever the source format was.
        user_content = [{"type": "text", "text": prompt_text}] + [
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
            for b64 in base64_encoded_images
        ]
    else:
        raise ValueError(f"Unsupported content type: {content_type}")

    # Send request to OpenAI for data extraction
    try:
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            temperature=0.5,
            max_tokens=16384,
        )
        content = (response.choices[0].message.content or "").strip()
        # Remove a surrounding ```json fence if present. str.strip('```json')
        # would strip a *set of characters* rather than a prefix.
        fence = _CODE_FENCE_RE.match(content)
        cleaned_content = fence.group(1) if fence else content
        try:
            extracted_data["extracted_json"] = json.loads(cleaned_content)
        except json.JSONDecodeError as e:
            logger.error(f"JSON Parse Error: {e}")
            return {"error": f"JSON Parse Error: {str(e)}"}
        return extracted_data
    except Exception as e:
        logger.error(f"Error in data extraction: {e}")
        return {"error": str(e)}


def get_content_type_from_s3(file_key):
    """Fetch the content type (MIME type) of a file stored in S3."""
    try:
        response = s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=file_key)
        return response.get('ContentType', 'application/octet-stream')  # Default to binary if not found
    except Exception as e:
        raise Exception(f"Failed to get content type from S3: {str(e)}") from e


def verify_api_key(api_key: str = Header(...)):
    """FastAPI dependency that rejects requests whose ``api-key`` header is wrong.

    The comparison runs in constant time, so response timing does not reveal
    how much of the key matched. When API_KEY is unset or empty, every request
    is rejected (fail closed).
    """
    if not API_KEY or not hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API Key")


@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Invoice Summarization API!"}


@app.get("/ocr/extraction")
def extract_text_from_file(
    api_key: str = Depends(verify_api_key),
    file_key: str = Query(..., description="S3 file key for the file"),
    document_type: str = Query(..., description="Type of document"),
    entity_ref_key: str = Query(..., description="Entity Reference Key"),
):
    """Extract data from a PDF or image in S3 and cache it in MongoDB.

    If a document with the same ``entityrefkey`` is already stored, it is
    returned unchanged. Otherwise the schema for ``document_type`` is loaded,
    the file is fetched and extracted, and only a successful result is stored.
    Failures are returned as ``{"error": {...}}``. The 500 raised when the
    MongoDB insert fails is propagated as a real HTTP error.
    """
    try:
        existing_document = invoice_collection.find_one({"entityrefkey": entity_ref_key})
        if existing_document:
            existing_document["_id"] = str(existing_document["_id"])
            return {
                "message": "Document Retrieved from MongoDB.",
                "document": existing_document,
            }

        # Fetch dynamic schema based on document type
        schema_doc = schema_collection.find_one({"document_type": document_type})
        if not schema_doc:
            raise ValueError("No schema found for the given document type")
        json_schema = schema_doc.get("json_schema")
        if not json_schema:
            raise ValueError("Schema is empty or not properly defined.")

        # A single GET returns both the bytes and their Content-Type, so no
        # separate HEAD request is needed.
        file_data, content_type = fetch_file_from_s3(file_key)
        extracted_data = extract_invoice_data(file_data, content_type, json_schema)

        # Do not cache failures. A stored error document would be served from
        # the cache above for this entityrefkey forever.
        if "error" in extracted_data:
            raise RuntimeError(f"Extraction failed: {extracted_data['error']}")

        # Build document for insertion
        document = {
            "file_key": file_key,
            "file_type": content_type,
            "document_type": document_type,
            "entityrefkey": entity_ref_key,
            "extracted_data": extracted_data,
        }
        try:
            inserted_doc = invoice_collection.insert_one(document)
            document_id = str(inserted_doc.inserted_id)
            logger.info(f"Document inserted with ID: {document_id}")
        except Exception as e:
            logger.error(f"Error inserting document: {str(e)}")
            raise HTTPException(status_code=500, detail="Error inserting document into MongoDB")

        return {
            "message": "Document successfully stored in MongoDB",
            "document_id": document_id,
            "entityrefkey": entity_ref_key,
            "extracted_data": extracted_data,
        }
    except HTTPException:
        # Let FastAPI turn deliberate HTTP errors (e.g. the 500 above) into
        # real error responses instead of a 200 carrying an error body.
        raise
    except Exception as e:
        logger.exception("Extraction request failed")
        # NOTE(review): returning the traceback exposes server internals to
        # API clients; consider keeping it in the logs only.
        error_details = {
            "error_type": type(e).__name__,
            "error_message": str(e),
            "traceback": traceback.format_exc(),
        }
        return {"error": error_details}


# Serve the output folder as static files. StaticFiles raises at construction
# if the directory does not exist, so create it first.
os.makedirs(output_dir, exist_ok=True)
app.mount(
    "/output",
    StaticFiles(directory=output_dir, follow_symlink=True, html=True),
    name="output",
)

if __name__ == '__main__':
    uvicorn.run(app=app)