vkumartr commited on
Commit
5c635bd
·
verified ·
1 Parent(s): 63eb4a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -27
app.py CHANGED
@@ -1,17 +1,23 @@
1
  import uvicorn
2
  from fastapi.staticfiles import StaticFiles
3
  import hashlib
 
4
  from fastapi import FastAPI, Header, Query, Depends, HTTPException
 
 
 
 
5
  from pymongo import MongoClient
 
6
  import boto3
7
  import openai
8
  import os
9
- import traceback
 
10
  import json
11
  from dotenv import load_dotenv
12
  import base64
13
  from bson.objectid import ObjectId
14
- import logging
15
 
16
  db_client = None
17
  load_dotenv()
@@ -20,26 +26,26 @@ load_dotenv()
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
23
- # Validate and Load Environment Variables
24
  MONGODB_URI = os.getenv("MONGODB_URI")
25
  DATABASE_NAME = os.getenv("DATABASE_NAME")
26
  COLLECTION_NAME = os.getenv("COLLECTION_NAME", "invoice_collection")
27
- AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
28
- AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
29
- S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
30
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
- API_KEY = os.getenv("API_KEY")
32
 
33
- if not all([MONGODB_URI, DATABASE_NAME, AWS_ACCESS_KEY, AWS_SECRET_KEY, S3_BUCKET_NAME, OPENAI_API_KEY, API_KEY]):
34
- raise ValueError("One or more required environment variables are missing.")
 
 
 
 
35
 
36
  # Initialize MongoDB Connection
37
  db_client = MongoClient(MONGODB_URI)
38
  db = db_client[DATABASE_NAME]
39
  invoice_collection = db[COLLECTION_NAME]
40
- openai.api_key = OPENAI_API_KEY
41
 
42
  app = FastAPI(docs_url='/')
 
 
43
 
44
  @app.on_event("startup")
45
  def startup_db():
@@ -49,6 +55,15 @@ def startup_db():
49
  except Exception as e:
50
  logger.error(f"MongoDB connection failed: {str(e)}")
51
 
 
 
 
 
 
 
 
 
 
52
  # S3 Client
53
  s3_client = boto3.client(
54
  's3',
@@ -56,13 +71,164 @@ s3_client = boto3.client(
56
  aws_secret_access_key=AWS_SECRET_KEY
57
  )
58
 
 
59
  def fetch_file_from_s3(file_key):
60
  try:
61
  response = s3_client.get_object(Bucket=S3_BUCKET_NAME, Key=file_key)
62
- return response['Body'].read(), response['ContentType']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  except Exception as e:
64
- logger.error(f"Failed to fetch file from S3: {e}")
65
- raise HTTPException(status_code=500, detail=f"S3 fetch error: {str(e)}")
66
 
67
  def extract_text_from_s3(file_key, content_type):
68
  return "Extracted text from file", 1 # Placeholder for real extraction logic
@@ -73,42 +239,88 @@ def convert_to_base64(file_key):
73
  def generate_summary(extracted_text):
74
  return "Summarized text" # Placeholder
75
 
 
 
 
 
 
 
 
 
 
 
76
  def verify_api_key(api_key: str = Header(...)):
77
  if api_key != API_KEY:
78
  raise HTTPException(status_code=401, detail="Invalid API Key")
79
 
 
 
 
 
80
  @app.get("/ocr/extraction")
81
  def extract_text_from_file(
82
  api_key: str = Depends(verify_api_key),
83
- file_key: str = Query(..., description="S3 file key"),
84
- document_type: str = Query(..., description="Document type"),
85
  entity_ref_key: str = Query(..., description="Entity Reference Key")
86
  ):
 
87
  try:
88
  existing_document = invoice_collection.find_one({"entityrefkey": entity_ref_key})
 
89
  if existing_document:
90
  existing_document["_id"] = str(existing_document["_id"])
91
- return {"message": "Document Retrieved from MongoDB.", "document": existing_document}
92
-
93
- file_data, content_type = fetch_file_from_s3(file_key)
 
 
 
 
 
 
94
  extracted_text, num_pages = extract_text_from_s3(file_key, content_type)
 
 
95
  base64DataResp = None
96
  summary = None
97
  if num_pages <= 2:
98
- base64DataResp = convert_to_base64(file_key)
99
- summary = generate_summary(extracted_text)
100
-
 
101
  document = {
102
  "file_key": file_key,
103
  "file_type": content_type,
104
  "document_type": document_type,
105
  "entityrefkey": entity_ref_key,
106
  "num_pages": num_pages,
107
- "base64DataResp": base64DataResp,
108
- "extracted_text": extracted_text if num_pages <= 2 else "Skipped",
109
- "summary": summary,
110
  }
 
111
  inserted_doc = invoice_collection.insert_one(document)
112
- return {"message": "Document stored in MongoDB", "document_id": str(inserted_doc.inserted_id)}
 
 
 
 
 
 
 
 
 
113
  except Exception as e:
114
- return {"error": {"type": type(e).__name__, "message": str(e), "traceback": traceback.format_exc()}}
 
 
 
 
 
 
 
 
 
 
 
 
1
  import uvicorn
2
  from fastapi.staticfiles import StaticFiles
3
  import hashlib
4
+ from enum import Enum
5
  from fastapi import FastAPI, Header, Query, Depends, HTTPException
6
+ from PIL import Image
7
+ import io
8
+ import fitz # PyMuPDF for PDF handling
9
+ import logging
10
  from pymongo import MongoClient
11
+
12
  import boto3
13
  import openai
14
  import os
15
+ import traceback # For detailed traceback of errors
16
+ import re
17
  import json
18
  from dotenv import load_dotenv
19
  import base64
20
  from bson.objectid import ObjectId
 
21
 
22
  db_client = None
23
  load_dotenv()
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
+ # MongoDB Configuration
30
  MONGODB_URI = os.getenv("MONGODB_URI")
31
  DATABASE_NAME = os.getenv("DATABASE_NAME")
32
  COLLECTION_NAME = os.getenv("COLLECTION_NAME", "invoice_collection")
 
 
 
 
 
33
 
34
+ # use_gpu = False
35
+ # output_dir = 'output'
36
+
37
+ # Check if environment variables are set
38
+ if not MONGODB_URI:
39
+ raise ValueError("MONGODB_URL is not set. Please add it to Hugging Face secrets.")
40
 
41
  # Initialize MongoDB Connection
42
  db_client = MongoClient(MONGODB_URI)
43
  db = db_client[DATABASE_NAME]
44
  invoice_collection = db[COLLECTION_NAME]
 
45
 
46
  app = FastAPI(docs_url='/')
47
+ use_gpu = False
48
+ output_dir = 'output'
49
 
50
  @app.on_event("startup")
51
  def startup_db():
 
55
  except Exception as e:
56
  logger.error(f"MongoDB connection failed: {str(e)}")
57
 
58
+ # AWS S3 Configuration
59
+ API_KEY = os.getenv("API_KEY")
60
+ AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
61
+ AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
62
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
63
+
64
+ # OpenAI Configuration
65
+ openai.api_key = os.getenv("OPENAI_API_KEY")
66
+
67
  # S3 Client
68
  s3_client = boto3.client(
69
  's3',
 
71
  aws_secret_access_key=AWS_SECRET_KEY
72
  )
73
 
74
+ # Function to fetch file from S3
75
  def fetch_file_from_s3(file_key):
76
  try:
77
  response = s3_client.get_object(Bucket=S3_BUCKET_NAME, Key=file_key)
78
+ content_type = response['ContentType'] # Retrieve MIME type
79
+ file_data = response['Body'].read()
80
+ return file_data, content_type # Return file data as BytesIO
81
+ except Exception as e:
82
+ raise Exception(f"Failed to fetch file from S3: {str(e)}")
83
+
84
+ # Function to summarize text using OpenAI GPT
85
+ def extract_invoice_data(file_data, content_type):
86
+ system_prompt = "You are an expert in document data extraction."
87
+
88
+ # Convert file to Base64
89
+ base64_encoded = base64.b64encode(file_data).decode('utf-8')
90
+
91
+ # Determine the correct MIME type for OpenAI
92
+ if content_type.startswith("image/"):
93
+ mime_type = content_type # e.g., image/png, image/jpeg
94
+ elif content_type == "application/pdf":
95
+ mime_type = "application/pdf"
96
+ else:
97
+ raise ValueError(f"Unsupported content type: {content_type}")
98
+
99
+ try:
100
+ response = openai.ChatCompletion.create(
101
+ model="gpt-4o-mini",
102
+ messages=[
103
+ {"role": "system", "content": system_prompt},
104
+ {
105
+ "role": "user",
106
+ "content": [
107
+ {
108
+ "type": "image_url",
109
+ "image_url": {
110
+ "url": f"data:{mime_type};base64,{base64_encoded}"
111
+ }
112
+ }
113
+ ]
114
+ }
115
+ ],
116
+ response_format={
117
+ "type": "json_schema",
118
+ "json_schema": {
119
+ "name": "invoice",
120
+ "strict": True,
121
+ "schema": {
122
+ "type": "object",
123
+ "title": "Invoice Information Extractor",
124
+ "$schema": "http://json-schema.org/draft-07/schema#",
125
+ "properties": {
126
+ "LineItems": {
127
+ "type": "array",
128
+ "items": {
129
+ "type": "object",
130
+ "required": [
131
+ "ProductCode",
132
+ "Description",
133
+ "Amount"
134
+ ],
135
+ "properties": {
136
+ "ProductCode": {
137
+ "type": "string",
138
+ "title": "Product Code",
139
+ "description": "The code of the product"
140
+ },
141
+ "Description": {
142
+ "type": "string",
143
+ "title": "Description",
144
+ "description": "Description of the product"
145
+ },
146
+ "Amount": {
147
+ "type": "number",
148
+ "title": "Amount",
149
+ "description": "The amount of the product"
150
+ }
151
+ },
152
+ "additionalProperties": False
153
+ },
154
+ "title": "Line Items",
155
+ "description": "List of line items on the invoice"
156
+ },
157
+ "TaxAmount": {
158
+ "type": "number",
159
+ "title": "Tax Amount",
160
+ "description": "The tax amount on the invoice"
161
+ },
162
+ "VendorGST": {
163
+ "type": "string",
164
+ "title": "Vendor GST",
165
+ "description": "The GST number of the vendor"
166
+ },
167
+ "VendorName": {
168
+ "type": "string",
169
+ "title": "Vendor Name",
170
+ "description": "The name of the vendor"
171
+ },
172
+ "InvoiceDate": {
173
+ "type": "string",
174
+ "title": "Invoice Date",
175
+ "description": "The date of the invoice (format: dd-MMM-yyyy)"
176
+ },
177
+ "TotalAmount": {
178
+ "type": "number",
179
+ "title": "Total Amount",
180
+ "description": "The total amount on the invoice"
181
+ },
182
+ "InvoiceNumber": {
183
+ "type": "string",
184
+ "title": "Invoice Number",
185
+ "description": "The number of the invoice"
186
+ },
187
+ "VendorAddress": {
188
+ "type": "string",
189
+ "title": "Vendor Address",
190
+ "description": "The address of the vendor"
191
+ },
192
+ "InvoiceCurrency": {
193
+ "type": "string",
194
+ "title": "Invoice Currency",
195
+ "description": "The currency used in the invoice (e.g., USD, INR, AUD)"
196
+ }
197
+ },
198
+ "required": [
199
+ "LineItems",
200
+ "TaxAmount",
201
+ "VendorGST",
202
+ "VendorName",
203
+ "InvoiceDate",
204
+ "TotalAmount",
205
+ "InvoiceNumber",
206
+ "VendorAddress",
207
+ "InvoiceCurrency"
208
+ ],
209
+ "additionalProperties": False,
210
+ "description": "Schema for extracting structured invoice data"
211
+ }
212
+ }
213
+ },
214
+ temperature=0.5,
215
+ max_tokens=16384
216
+ )
217
+
218
+ # Clean and parse JSON output
219
+ content = response.choices[0].message.content.strip()
220
+ cleaned_content = content.strip().strip('```json').strip('```')
221
+
222
+ try:
223
+ parsed_content = json.loads(cleaned_content)
224
+ return parsed_content
225
+ except json.JSONDecodeError as e:
226
+ logger.error(f"JSON Parse Error: {e}")
227
+ return None
228
+
229
  except Exception as e:
230
+ logger.error(f"Error in data extraction: {e}")
231
+ return {"error": str(e)}
232
 
233
  def extract_text_from_s3(file_key, content_type):
234
  return "Extracted text from file", 1 # Placeholder for real extraction logic
 
239
  def generate_summary(extracted_text):
240
  return "Summarized text" # Placeholder
241
 
242
+ def get_content_type_from_s3(file_key):
243
+ """Fetch the content type (MIME type) of a file stored in S3."""
244
+ try:
245
+ response = s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=file_key)
246
+ return response.get('ContentType', 'application/octet-stream') # Default to binary if not found
247
+ except Exception as e:
248
+ raise Exception(f"Failed to get content type from S3: {str(e)}")
249
+
250
+
251
+ # Dependency to check API Key
252
  def verify_api_key(api_key: str = Header(...)):
253
  if api_key != API_KEY:
254
  raise HTTPException(status_code=401, detail="Invalid API Key")
255
 
256
+ @app.get("/")
257
+ def read_root():
258
+ return {"message": "Welcome to the Invoice Summarization API!"}
259
+
260
  @app.get("/ocr/extraction")
261
  def extract_text_from_file(
262
  api_key: str = Depends(verify_api_key),
263
+ file_key: str = Query(..., description="S3 file key for the file"),
264
+ document_type: str = Query(..., description="Type of document"),
265
  entity_ref_key: str = Query(..., description="Entity Reference Key")
266
  ):
267
+ """Extract text from a PDF or Image stored in S3 and process it based on document size."""
268
  try:
269
  existing_document = invoice_collection.find_one({"entityrefkey": entity_ref_key})
270
+
271
  if existing_document:
272
  existing_document["_id"] = str(existing_document["_id"])
273
+ return {
274
+ "message": "Document Retrieved from MongoDB.",
275
+ "document": existing_document
276
+ }
277
+
278
+ # Retrieve file from S3 and determine content type (Ensure this step is implemented)
279
+ content_type = get_content_type_from_s3(file_key) # Implement this function
280
+
281
+ # Extract text (Ensure Extraction function is implemented)
282
  extracted_text, num_pages = extract_text_from_s3(file_key, content_type)
283
+
284
+ # Define values for small/large files
285
  base64DataResp = None
286
  summary = None
287
  if num_pages <= 2:
288
+ base64DataResp = convert_to_base64(file_key) # Implement this function
289
+ summary = generate_summary(extracted_text) # Implement this function
290
+
291
+ # Store extracted data in MongoDB
292
  document = {
293
  "file_key": file_key,
294
  "file_type": content_type,
295
  "document_type": document_type,
296
  "entityrefkey": entity_ref_key,
297
  "num_pages": num_pages,
298
+ "base64DataResp": base64DataResp, # Only for small files
299
+ "extracted_text": extracted_text,
300
+ "summary": summary, # Only for small files
301
  }
302
+
303
  inserted_doc = invoice_collection.insert_one(document)
304
+ document_id = str(inserted_doc.inserted_id)
305
+
306
+ return {
307
+ "message": "Document successfully stored in MongoDB",
308
+ "document_id": document_id,
309
+ "file_key": file_key,
310
+ "num_pages": num_pages,
311
+ "summary": summary if summary else "Skipped for large documents"
312
+ }
313
+
314
  except Exception as e:
315
+ error_details = {
316
+ "error_type": type(e).__name__,
317
+ "error_message": str(e),
318
+ "traceback": traceback.format_exc()
319
+ }
320
+ return {"error": error_details}
321
+
322
+ # Serve the output folder as static files
323
+ app.mount("/output", StaticFiles(directory="output", follow_symlink=True, html=True), name="output")
324
+
325
+ if __name__ == '__main__':
326
+ uvicorn.run(app=app)