# Source page header: Spaces — status: Sleeping
| import uvicorn | |
| from fastapi.staticfiles import StaticFiles | |
| import hashlib | |
| from enum import Enum | |
| from fastapi import FastAPI, Header, Query, Depends, HTTPException | |
| from PIL import Image | |
| import io | |
| import fitz # PyMuPDF for PDF handling | |
| import logging | |
| from pymongo import MongoClient | |
| import boto3 | |
| import openai | |
| import os | |
| import traceback # For detailed traceback of errors | |
| import re | |
| import json | |
| from dotenv import load_dotenv | |
| import base64 | |
| from bson.objectid import ObjectId | |
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# MongoDB Configuration
MONGODB_URI = os.getenv("MONGODB_URI")
DATABASE_NAME = os.getenv("DATABASE_NAME")
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
SCHEMA = os.getenv("SCHEMA")  # name of the collection that holds per-document-type JSON schemas

# Fail fast with a clear message. Without these checks a missing name would
# reach ``db_client[None]`` / ``db[None]`` and pymongo would raise an opaque
# TypeError (or InvalidName for an empty string).
if not MONGODB_URI:
    raise ValueError("MONGODB_URI is not set. Please add it to your secrets.")
for _env_name, _env_value in (
    ("DATABASE_NAME", DATABASE_NAME),
    ("COLLECTION_NAME", COLLECTION_NAME),
    ("SCHEMA", SCHEMA),
):
    if not _env_value:
        raise ValueError(f"{_env_name} is not set. Please add it to your secrets.")

# Initialize MongoDB Connection
db_client = MongoClient(MONGODB_URI)
db = db_client[DATABASE_NAME]
invoice_collection = db[COLLECTION_NAME]
schema_collection = db[SCHEMA]

app = FastAPI(docs_url='/')  # interactive API docs are served at the root path
use_gpu = False  # NOTE(review): not referenced in the visible code
output_dir = 'output'  # same directory name as the /output static-files mount
def startup_db():
    """Probe the MongoDB server once and log whether it is reachable.

    Errors are logged, not raised, so a failed probe never propagates to the
    caller.
    """
    try:
        db_client.server_info()
    except Exception as exc:
        logger.error(f"MongoDB connection failed: {str(exc)}")
    else:
        logger.info("MongoDB connection successful")
# AWS S3 Configuration
# API_KEY is this service's own shared secret (checked by verify_api_key).
API_KEY = os.getenv("API_KEY")
AWS_ACCESS_KEY, AWS_SECRET_KEY, S3_BUCKET_NAME = (
    os.getenv("AWS_ACCESS_KEY"),
    os.getenv("AWS_SECRET_KEY"),
    os.getenv("S3_BUCKET_NAME"),
)

# OpenAI Configuration
openai.api_key = os.getenv("OPENAI_API_KEY")

# S3 Client — credentials are passed through to boto3 as read (possibly None).
s3_client = boto3.client(
    service_name="s3",
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
)
# Function to fetch file from S3
def fetch_file_from_s3(file_key):
    """Download an object from the configured S3 bucket.

    Args:
        file_key: Key of the object inside ``S3_BUCKET_NAME``.

    Returns:
        tuple[bytes, str]: The object's raw bytes and its MIME type. The MIME
        type defaults to ``"application/octet-stream"`` when S3 reports none,
        the same default ``get_content_type_from_s3`` uses.

    Raises:
        Exception: Wraps any S3/botocore failure; the original error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        response = s3_client.get_object(Bucket=S3_BUCKET_NAME, Key=file_key)
        body = response['Body']
        try:
            file_data = body.read()
        finally:
            # Release the underlying HTTP connection even if the read fails.
            body.close()
    except Exception as e:
        raise Exception(f"Failed to fetch file from S3: {str(e)}") from e
    content_type = response.get('ContentType', 'application/octet-stream')
    return file_data, content_type
# Extraction helpers: PDFs are sent to OpenAI as their embedded text, images as
# image content parts.

# Matches a reply wrapped in a Markdown code fence (```json ... ``` or ``` ... ```).
_CODE_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.DOTALL | re.IGNORECASE)


def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if present."""
    match = _CODE_FENCE_RE.match(text)
    return match.group(1) if match else text.strip()


def _pdf_to_text(file_data, max_pages):
    """Return the embedded text of an in-memory PDF (no OCR).

    Raises ValueError when the PDF has more than ``max_pages`` pages.
    """
    # The context manager closes the document even when we raise below.
    with fitz.open(stream=file_data, filetype="pdf") as doc:
        if doc.page_count > max_pages:
            raise ValueError(
                f"The PDF contains more than {max_pages} pages, extraction not supported."
            )
        return "".join(page.get_text() for page in doc)


def _image_to_png_base64(file_data, max_pages):
    """Re-encode every frame of an image file as PNG; return the Base64 strings.

    Raises ValueError when the file has more than ``max_pages`` frames.
    """
    # Pillow modes the PNG encoder writes directly; others (e.g. CMYK JPEGs)
    # would make ``save(format="PNG")`` fail, so they are converted to RGB.
    png_modes = ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA")
    with Image.open(io.BytesIO(file_data)) as img:
        # ``n_frames`` is only defined by Pillow plugins that support
        # multi-frame images; single-frame formats such as JPEG lack it.
        num_frames = getattr(img, "n_frames", 1)
        if num_frames > max_pages:
            raise ValueError(
                f"The image file contains more than {max_pages} pages, extraction not supported."
            )
        encoded = []
        for frame_index in range(num_frames):
            img.seek(frame_index)
            frame = img if img.mode in png_modes else img.convert("RGB")
            buffer = io.BytesIO()
            frame.save(buffer, format="PNG")
            encoded.append(base64.b64encode(buffer.getvalue()).decode("utf-8"))
        return encoded


def extract_invoice_data(file_data, content_type, json_schema, max_pages=2):
    """Extract structured invoice data from a PDF or image using OpenAI.

    PDFs: the embedded text layer is read with PyMuPDF (no OCR) and sent to
    the model as plain text.
    Images: every frame is re-encoded as PNG and attached to the request as a
    Base64 ``image_url`` content part, which needs a vision-capable model.

    Args:
        file_data: Raw file bytes.
        content_type: MIME type such as ``"application/pdf"`` or
            ``"image/jpeg"``. Case and parameters (``; charset=...``) are ignored.
        json_schema: Schema the model's JSON output should follow. It is
            serialised into the prompt.
        max_pages: Maximum number of PDF pages or image frames accepted.

    Returns:
        On success, a dict with ``"text"`` (PDFs) or ``"base64_images"``
        (images), plus ``"extracted_json"`` holding the parsed model output.
        If the OpenAI call, an empty reply, or JSON parsing fails:
        ``{"error": <message>}``.

    Raises:
        ValueError: Unsupported content type, or too many pages/frames.
        Exception: Errors from opening or decoding the PDF/image propagate
            after being logged.
    """
    # Normalise e.g. "Image/PNG; charset=binary" -> "image/png"; tolerate None.
    mime = (content_type or "").split(";", 1)[0].strip().lower()
    if mime != "application/pdf" and not mime.startswith("image/"):
        raise ValueError(f"Unsupported content type: {content_type}")

    system_prompt = "You are an expert in document data extraction."
    schema_text = json.dumps(json_schema, indent=2)
    extracted_data = {}

    if mime == "application/pdf":
        try:
            extracted_text = _pdf_to_text(file_data, max_pages)
        except Exception as e:
            logger.error(f"Error extracting text from PDF: {e}")
            raise
        if not extracted_text.strip():
            # An image-only (e.g. scanned) PDF has no text layer to extract.
            logger.warning("PDF has no embedded text; it may be a scanned document.")
        extracted_data["text"] = extracted_text
        user_content = (
            f"Extract the invoice data from the following PDF text. "
            f"Return only valid JSON that adheres to this schema:\n\n{schema_text}\n\n"
            f"PDF Text:\n{extracted_text}"
        )
    else:
        try:
            base64_images = _image_to_png_base64(file_data, max_pages)
        except Exception as e:
            logger.error(f"Error handling images: {e}")
            raise
        extracted_data["base64_images"] = base64_images
        instructions = (
            "Extract the invoice data from the attached images. "
            f"Return only valid JSON that adheres to this schema:\n\n{schema_text}\n\n"
        )
        # Images must be sent as ``image_url`` content parts. Pasting Base64
        # into the text would only feed the model encoded characters. The
        # frames were re-encoded as PNG, so the data URL is always image/png.
        user_content = [{"type": "text", "text": instructions}] + [
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
            for b64 in base64_images
        ]

    try:
        # NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 SDK interface
        # and was removed in openai>=1.0 — confirm the pinned SDK version.
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            temperature=0.5,
            max_tokens=16384,
        )
        content = response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error in data extraction: {e}")
        return {"error": str(e)}

    if not content:
        logger.error("OpenAI returned an empty response")
        return {"error": "OpenAI returned an empty response"}

    try:
        extracted_data["extracted_json"] = json.loads(_strip_code_fence(content))
    except json.JSONDecodeError as e:
        logger.error(f"JSON Parse Error: {e}")
        return {"error": f"JSON Parse Error: {str(e)}"}
    return extracted_data
def get_content_type_from_s3(file_key):
    """Fetch the content type (MIME type) of a file stored in S3.

    Issues a HeadObject request, so the object body is not downloaded.

    Args:
        file_key: Key of the object inside ``S3_BUCKET_NAME``.

    Returns:
        str: The recorded ContentType, or ``"application/octet-stream"`` when
        the response carries none.

    Raises:
        Exception: Wraps any S3/botocore failure; the original error is
            chained as ``__cause__``.
    """
    try:
        response = s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=file_key)
    except Exception as e:
        raise Exception(f"Failed to get content type from S3: {str(e)}") from e
    return response.get('ContentType', 'application/octet-stream')  # Default to binary if not found
# Dependency to check API Key
def verify_api_key(api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the configured API key.

    The key is read from the ``api_key`` header parameter. FastAPI converts
    underscores to hyphens in header names by default, so clients send
    ``api-key``.

    Raises:
        HTTPException: 401 if ``API_KEY`` is unset/empty or the key does not match.
    """
    import hmac  # local import: constant-time comparison of secrets

    # Fail closed when no key is configured. Otherwise an empty API_KEY env
    # var would accept an empty header. The comparison is constant-time so
    # response timing does not reveal how much of the key matched.
    if not API_KEY or not hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API Key")
def read_root():
    """Return a static welcome message for the invoice API."""
    greeting = "Welcome to the Invoice Summarization API!"
    return dict(message=greeting)
def extract_text_from_file(
    api_key: str = Depends(verify_api_key),
    file_key: str = Query(..., description="S3 file key for the file"),
    document_type: str = Query(..., description="Type of document"),
    entity_ref_key: str = Query(..., description="Entity Reference Key")
):
    """Extract structured data from a PDF or image in S3 and store it in MongoDB.

    If a document with the same ``entityrefkey`` already exists, it is
    returned as-is. Otherwise:
    1. The JSON schema for ``document_type`` is loaded from the schema
       collection.
    2. The file is downloaded from S3 and its data extracted.
    3. The result is inserted into the invoice collection.

    Returns:
        ``{"message", "document"}`` for an existing document.
        ``{"message", "document_id", "entityrefkey", "extracted_data"}``
        after a successful insert.
        ``{"error": {"error_type", "error_message", "traceback"}}`` for any
        other failure, including a failed extraction (which is NOT stored).

    Raises:
        HTTPException: 500 when inserting the document into MongoDB fails.
    """
    try:
        existing_document = invoice_collection.find_one({"entityrefkey": entity_ref_key})
        if existing_document:
            existing_document["_id"] = str(existing_document["_id"])
            return {
                "message": "Document Retrieved from MongoDB.",
                "document": existing_document
            }

        # Fetch dynamic schema based on document type
        schema_doc = schema_collection.find_one({"document_type": document_type})
        if not schema_doc:
            raise ValueError("No schema found for the given document type")
        json_schema = schema_doc.get("json_schema")
        if not json_schema:
            raise ValueError("Schema is empty or not properly defined.")

        # A single GetObject call returns both the bytes and the Content-Type.
        # A separate HeadObject round trip (which could also race with an
        # object overwrite) is unnecessary.
        file_data, content_type = fetch_file_from_s3(file_key)
        extracted_data = extract_invoice_data(file_data, content_type, json_schema)

        # extract_invoice_data reports OpenAI/JSON failures as {"error": ...}.
        # Persisting that would make every later request for this entity
        # return the cached failure, so surface it instead of storing it.
        if "error" in extracted_data:
            raise RuntimeError(f"Extraction failed: {extracted_data['error']}")

        document = {
            "file_key": file_key,
            "file_type": content_type,
            "document_type": document_type,
            "entityrefkey": entity_ref_key,
            "extracted_data": extracted_data
        }
        try:
            inserted_doc = invoice_collection.insert_one(document)
        except Exception as e:
            logger.error(f"Error inserting document: {str(e)}")
            raise HTTPException(status_code=500, detail="Error inserting document into MongoDB") from e
        document_id = str(inserted_doc.inserted_id)
        logger.info(f"Document inserted with ID: {document_id}")

        return {
            "message": "Document successfully stored in MongoDB",
            "document_id": document_id,
            "entityrefkey": entity_ref_key,
            "extracted_data": extracted_data
        }
    except HTTPException:
        # Let FastAPI turn this into a real HTTP error response instead of
        # folding it into the 200 error payload below.
        raise
    except Exception as e:
        logger.exception("Error processing file %s", file_key)
        # NOTE(review): the full traceback is returned to the API client,
        # which exposes server internals; consider removing it from responses.
        error_details = {
            "error_type": type(e).__name__,
            "error_message": str(e),
            "traceback": traceback.format_exc()
        }
        return {"error": error_details}
# Serve the output folder as static files.
# Starlette's StaticFiles raises at construction when the directory does not
# exist (check_dir defaults to True), so create it first.
# NOTE(review): follow_symlink=True lets symlinks placed inside the folder
# expose files outside it — confirm that is intended for a public mount.
os.makedirs(output_dir, exist_ok=True)
app.mount("/output", StaticFiles(directory=output_dir, follow_symlink=True, html=True), name="output")

if __name__ == '__main__':
    uvicorn.run(app=app)