vkumartr committed on
Commit
2125a91
·
verified ·
1 Parent(s): c4f2ca9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -33
app.py CHANGED
@@ -46,6 +46,7 @@ app = FastAPI(docs_url='/')
46
  use_gpu = False
47
  output_dir = 'output'
48
 
 
49
  @app.on_event("startup")
50
  def startup_db():
51
  try:
@@ -54,6 +55,7 @@ def startup_db():
54
  except Exception as e:
55
  logger.error(f"MongoDB connection failed: {str(e)}")
56
 
 
57
  # AWS S3 Configuration
58
  API_KEY = os.getenv("API_KEY")
59
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
@@ -70,6 +72,7 @@ s3_client = boto3.client(
70
  aws_secret_access_key=AWS_SECRET_KEY
71
  )
72
 
 
73
  # Function to fetch file from S3
74
  def fetch_file_from_s3(file_key):
75
  try:
@@ -80,47 +83,84 @@ def fetch_file_from_s3(file_key):
80
  except Exception as e:
81
  raise Exception(f"Failed to fetch file from S3: {str(e)}")
82
 
83
- # Function to summarize text using OpenAI GPT
 
84
  def extract_invoice_data(file_data, content_type, json_schema):
 
 
 
 
85
  system_prompt = "You are an expert in document data extraction."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Convert file to Base64
88
- base64_encoded = base64.b64encode(file_data).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # Determine the correct MIME type for OpenAI
91
- if content_type.startswith("image/"):
92
- mime_type = content_type # e.g., image/png, image/jpeg
93
- elif content_type == "application/pdf":
94
- mime_type = "application/pdf"
95
  else:
96
  raise ValueError(f"Unsupported content type: {content_type}")
97
 
 
98
  try:
99
  response = openai.ChatCompletion.create(
100
  model="gpt-4o-mini",
101
  messages=[
102
  {"role": "system", "content": system_prompt},
103
- {
104
- "role": "user",
105
- "content": [
106
- {
107
- "type": "image_url",
108
- "image_url": {
109
- "url": f"data:{mime_type};base64,{base64_encoded}"
110
- }
111
- }
112
- ]
113
- }
114
  ],
115
- response_format={
116
- "type": "json_schema",
117
- "json_schema": json_schema
118
- },
119
  temperature=0.5,
120
  max_tokens=16384
121
  )
122
 
123
- # Clean and parse JSON output
124
  content = response.choices[0].message.content.strip()
125
  cleaned_content = content.strip().strip('```json').strip('```')
126
 
@@ -129,12 +169,13 @@ def extract_invoice_data(file_data, content_type, json_schema):
129
  return parsed_content
130
  except json.JSONDecodeError as e:
131
  logger.error(f"JSON Parse Error: {e}")
132
- return None
133
 
134
  except Exception as e:
135
  logger.error(f"Error in data extraction: {e}")
136
  return {"error": str(e)}
137
 
 
138
  def get_content_type_from_s3(file_key):
139
  """Fetch the content type (MIME type) of a file stored in S3."""
140
  try:
@@ -143,21 +184,24 @@ def get_content_type_from_s3(file_key):
143
  except Exception as e:
144
  raise Exception(f"Failed to get content type from S3: {str(e)}")
145
 
 
146
  # Dependency to check API Key
147
  def verify_api_key(api_key: str = Header(...)):
148
  if api_key != API_KEY:
149
  raise HTTPException(status_code=401, detail="Invalid API Key")
150
 
 
151
  @app.get("/")
152
  def read_root():
153
  return {"message": "Welcome to the Invoice Summarization API!"}
154
 
 
155
  @app.get("/ocr/extraction")
156
  def extract_text_from_file(
157
- api_key: str = Depends(verify_api_key),
158
- file_key: str = Query(..., description="S3 file key for the file"),
159
- document_type: str = Query(..., description="Type of document"),
160
- entity_ref_key: str = Query(..., description="Entity Reference Key")
161
  ):
162
  """Extract text from a PDF or Image stored in S3 and process it based on document size."""
163
  try:
@@ -175,9 +219,9 @@ def extract_text_from_file(
175
 
176
  json_schema = schema_doc.get("json_schema")
177
  if not json_schema:
178
- raise ValueError("Schema is empty or not properly defined.")
179
-
180
- # Retrieve file from S3 and determine content type
181
  content_type = get_content_type_from_s3(file_key)
182
  file_data, _ = fetch_file_from_s3(file_key)
183
  extracted_data = extract_invoice_data(file_data, content_type, json_schema)
@@ -213,7 +257,8 @@ def extract_text_from_file(
213
  "traceback": traceback.format_exc()
214
  }
215
  return {"error": error_details}
216
-
 
217
  # Serve the output folder as static files
218
  app.mount("/output", StaticFiles(directory="output", follow_symlink=True, html=True), name="output")
219
 
 
46
  use_gpu = False
47
  output_dir = 'output'
48
 
49
+
50
  @app.on_event("startup")
51
  def startup_db():
52
  try:
 
55
  except Exception as e:
56
  logger.error(f"MongoDB connection failed: {str(e)}")
57
 
58
+
59
  # AWS S3 Configuration
60
  API_KEY = os.getenv("API_KEY")
61
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
 
72
  aws_secret_access_key=AWS_SECRET_KEY
73
  )
74
 
75
+
76
  # Function to fetch file from S3
77
  def fetch_file_from_s3(file_key):
78
  try:
 
83
  except Exception as e:
84
  raise Exception(f"Failed to fetch file from S3: {str(e)}")
85
 
86
+
87
+ # Updated extraction function that handles PDF and image files differently
88
  def extract_invoice_data(file_data, content_type, json_schema):
89
+ """
90
+ For PDFs: Extract the embedded text using PyMuPDF (no OCR involved)
91
+ For Images: Pass the Base64-encoded image to OpenAI (assuming a multimodal model)
92
+ """
93
  system_prompt = "You are an expert in document data extraction."
94
+ base64_encoded_images = [] # To store Base64-encoded image data
95
+
96
+ extracted_data = {}
97
+
98
+ if content_type == "application/pdf":
99
+ # Use PyMuPDF to extract text directly from the PDF
100
+ try:
101
+ doc = fitz.open(stream=file_data, filetype="pdf")
102
+ num_pages = doc.page_count
103
+
104
+ # Check if the number of pages exceeds 2
105
+ if num_pages > 2:
106
+ raise ValueError("The PDF contains more than 2 pages, extraction not supported.")
107
+
108
+ extracted_text = ""
109
+ for page in doc:
110
+ extracted_text += page.get_text()
111
+
112
+ except Exception as e:
113
+ logger.error(f"Error extracting text from PDF: {e}")
114
+ raise
115
+
116
+ # Build a prompt containing the extracted text and the schema
117
+ prompt = (
118
+ f"Extract the invoice data from the following PDF text. "
119
+ f"Return only valid JSON that adheres to this schema:\n\n{json.dumps(json_schema, indent=2)}\n\n"
120
+ f"PDF Text:\n{extracted_text}"
121
+ )
122
 
123
+ elif content_type.startswith("image/"):
124
+ # For images, determine if more than 2 images are provided
125
+ try:
126
+ img = Image.open(io.BytesIO(file_data)) # Open the image file
127
+ num_images = img.n_frames # Get number of images (pages in the image file)
128
+
129
+ if num_images > 2:
130
+ raise ValueError("The image file contains more than 2 pages, extraction not supported.")
131
+
132
+ # Process each image page if there are 1 or 2 pages
133
+ for page_num in range(num_images):
134
+ img.seek(page_num) # Move to the current page
135
+ img_bytes = io.BytesIO()
136
+ img.save(img_bytes, format="PNG") # Save each page as a PNG image in memory
137
+ base64_encoded = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
138
+ base64_encoded_images.append(base64_encoded)
139
+
140
+ # Build a prompt containing the image data for OpenAI
141
+ prompt = f"Extract the invoice data from the following images (Base64 encoded). Return only valid JSON that adheres to this schema:\n\n{json.dumps(json_schema, indent=2)}\n\n"
142
+ for base64_image in base64_encoded_images:
143
+ prompt += f"Image Data URL: data:{content_type};base64,{base64_image}\n"
144
+
145
+ except Exception as e:
146
+ logger.error(f"Error handling images: {e}")
147
+ raise
148
 
 
 
 
 
 
149
  else:
150
  raise ValueError(f"Unsupported content type: {content_type}")
151
 
152
+ # Send request to OpenAI for data extraction
153
  try:
154
  response = openai.ChatCompletion.create(
155
  model="gpt-4o-mini",
156
  messages=[
157
  {"role": "system", "content": system_prompt},
158
+ {"role": "user", "content": prompt},
 
 
 
 
 
 
 
 
 
 
159
  ],
 
 
 
 
160
  temperature=0.5,
161
  max_tokens=16384
162
  )
163
 
 
164
  content = response.choices[0].message.content.strip()
165
  cleaned_content = content.strip().strip('```json').strip('```')
166
 
 
169
  return parsed_content
170
  except json.JSONDecodeError as e:
171
  logger.error(f"JSON Parse Error: {e}")
172
+ return {"error": f"JSON Parse Error: {str(e)}"}
173
 
174
  except Exception as e:
175
  logger.error(f"Error in data extraction: {e}")
176
  return {"error": str(e)}
177
 
178
+
179
  def get_content_type_from_s3(file_key):
180
  """Fetch the content type (MIME type) of a file stored in S3."""
181
  try:
 
184
  except Exception as e:
185
  raise Exception(f"Failed to get content type from S3: {str(e)}")
186
 
187
+
188
  # Dependency to check API Key
189
  def verify_api_key(api_key: str = Header(...)):
190
  if api_key != API_KEY:
191
  raise HTTPException(status_code=401, detail="Invalid API Key")
192
 
193
+
194
  @app.get("/")
195
  def read_root():
196
  return {"message": "Welcome to the Invoice Summarization API!"}
197
 
198
+
199
  @app.get("/ocr/extraction")
200
  def extract_text_from_file(
201
+ api_key: str = Depends(verify_api_key),
202
+ file_key: str = Query(..., description="S3 file key for the file"),
203
+ document_type: str = Query(..., description="Type of document"),
204
+ entity_ref_key: str = Query(..., description="Entity Reference Key")
205
  ):
206
  """Extract text from a PDF or Image stored in S3 and process it based on document size."""
207
  try:
 
219
 
220
  json_schema = schema_doc.get("json_schema")
221
  if not json_schema:
222
+ raise ValueError("Schema is empty or not properly defined.")
223
+
224
+ # Retrieve file from S3 and determine content type
225
  content_type = get_content_type_from_s3(file_key)
226
  file_data, _ = fetch_file_from_s3(file_key)
227
  extracted_data = extract_invoice_data(file_data, content_type, json_schema)
 
257
  "traceback": traceback.format_exc()
258
  }
259
  return {"error": error_details}
260
+
261
+
262
  # Serve the output folder as static files
263
  app.mount("/output", StaticFiles(directory="output", follow_symlink=True, html=True), name="output")
264