Alfonso Velasco committed on
Commit
1af4bc8
·
1 Parent(s): dd88d34
Files changed (1) hide show
  1. app.py +150 -185
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from typing import Dict, Any
4
  from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
5
  import torch
6
  from PIL import Image
@@ -29,7 +29,6 @@ try:
29
  model.to(device)
30
  except Exception as e:
31
  print(f"Error loading model: {e}")
32
- # Fallback to no OCR if there's an issue
33
  processor = LayoutLMv3Processor.from_pretrained(
34
  "microsoft/layoutlmv3-base",
35
  apply_ocr=False
@@ -44,6 +43,7 @@ except Exception as e:
44
  class DocumentRequest(BaseModel):
45
  pdf: str = None
46
  image: str = None
 
47
 
48
  @app.get("/")
49
  def home():
@@ -52,190 +52,25 @@ def home():
52
  @app.post("/extract")
53
  async def extract_document(request: DocumentRequest):
54
  try:
55
- # Determine input type
56
  file_data = request.pdf or request.image
57
  if not file_data:
58
  raise HTTPException(status_code=400, detail="No PDF or image provided")
59
 
60
- # Decode base64
61
  file_bytes = base64.b64decode(file_data)
62
 
63
- # Check if PDF or image
64
  if file_bytes.startswith(b'%PDF'):
65
- return process_pdf(file_bytes)
66
  else:
67
  return process_image(file_bytes)
68
 
69
  except Exception as e:
70
  raise HTTPException(status_code=500, detail=str(e))
71
 
72
- def process_pdf(pdf_bytes):
73
- """Process PDF document with proper coordinate scaling for any orientation"""
74
- all_results = []
75
-
76
- with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp_file:
77
- tmp_file.write(pdf_bytes)
78
- tmp_file.flush()
79
-
80
- pdf_document = fitz.open(tmp_file.name)
81
-
82
- # Define render scale
83
- RENDER_SCALE = 2.0
84
-
85
- for page_num in range(len(pdf_document)):
86
- page = pdf_document[page_num]
87
-
88
- # CRITICAL FIX: Get the actual page rectangle
89
- # This accounts for rotation and gives us the true page dimensions
90
- page_rect = page.rect
91
- page_width = page_rect.width
92
- page_height = page_rect.height
93
-
94
- print(f"Page {page_num + 1}: {page_width}x{page_height}, rotation={page.rotation}°")
95
-
96
- # Render page at consistent resolution
97
- # The matrix handles rotation automatically
98
- mat = fitz.Matrix(RENDER_SCALE, RENDER_SCALE)
99
- pix = page.get_pixmap(matrix=mat)
100
- img_data = pix.tobytes("png")
101
- image = Image.open(io.BytesIO(img_data)).convert("RGB")
102
-
103
- # Store rendered image dimensions
104
- img_width, img_height = image.size
105
-
106
- print(f"Rendered image: {img_width}x{img_height}")
107
-
108
- # CRITICAL: Verify the scaling is correct
109
- # The rendered image should be RENDER_SCALE times the page size
110
- expected_width = page_width * RENDER_SCALE
111
- expected_height = page_height * RENDER_SCALE
112
-
113
- if abs(img_width - expected_width) > 5 or abs(img_height - expected_height) > 5:
114
- print(f"WARNING: Image size mismatch! Expected {expected_width}x{expected_height}")
115
-
116
- try:
117
- # Try with OCR - increased max_length for wide documents
118
- encoding = processor(
119
- image,
120
- truncation=True,
121
- padding="max_length",
122
- max_length=1024, # Increased from 512 to handle wider documents
123
- return_tensors="pt"
124
- )
125
- except Exception as ocr_error:
126
- print(f"OCR failed: {ocr_error}, using fallback")
127
- # Fallback: process without OCR
128
- encoding = processor(
129
- image,
130
- text=[""] * 512, # Dummy text
131
- boxes=[[0, 0, 0, 0]] * 512, # Dummy boxes
132
- truncation=True,
133
- padding="max_length",
134
- max_length=512,
135
- return_tensors="pt"
136
- )
137
-
138
- encoding = {k: v.to(device) for k, v in encoding.items() if isinstance(v, torch.Tensor)}
139
-
140
- with torch.no_grad():
141
- outputs = model(**encoding)
142
-
143
- tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
144
- boxes = encoding["bbox"][0].tolist()
145
-
146
- page_results = []
147
- processed_boxes = set() # Track processed boxes to avoid duplicates
148
-
149
- for token, box in zip(tokens, boxes):
150
- if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
151
- # LayoutLMv3 returns normalized coordinates (0-1000)
152
- # These are normalized relative to the INPUT IMAGE dimensions
153
- x_norm = box[0]
154
- y_norm = box[1]
155
- x2_norm = box[2]
156
- y2_norm = box[3]
157
-
158
- # Skip invalid boxes
159
- if x_norm == 0 and y_norm == 0 and x2_norm == 0 and y2_norm == 0:
160
- continue
161
-
162
- # STEP 1: Convert normalized (0-1000) to rendered image pixel coordinates
163
- # CRITICAL: Use the ACTUAL rendered image dimensions
164
- x_img = (x_norm / 1000.0) * img_width
165
- y_img = (y_norm / 1000.0) * img_height
166
- x2_img = (x2_norm / 1000.0) * img_width
167
- y2_img = (y2_norm / 1000.0) * img_height
168
-
169
- # STEP 2: Scale back to PDF page coordinates
170
- # The image was rendered at RENDER_SCALE times the PDF size
171
- x = x_img / RENDER_SCALE
172
- y = y_img / RENDER_SCALE
173
- x2 = x2_img / RENDER_SCALE
174
- y2 = y2_img / RENDER_SCALE
175
-
176
- width = x2 - x
177
- height = y2 - y
178
-
179
- # Skip boxes that are too small
180
- if width < 1 or height < 1:
181
- continue
182
-
183
- # Validate bounds
184
- if x < 0 or y < 0 or x2 > page_width or y2 > page_height:
185
- # Allow small tolerance for rounding errors
186
- if (x < -2 or y < -2 or
187
- x2 > page_width + 2 or y2 > page_height + 2):
188
- print(f"Skipping out of bounds box: ({x:.1f},{y:.1f}) to ({x2:.1f},{y2:.1f})")
189
- continue
190
- # Clamp to valid bounds
191
- x = max(0, x)
192
- y = max(0, y)
193
- x2 = min(page_width, x2)
194
- y2 = min(page_height, y2)
195
- width = x2 - x
196
- height = y2 - y
197
-
198
- # Create box tuple for duplicate checking
199
- box_tuple = (round(x), round(y), round(width), round(height))
200
- if box_tuple in processed_boxes:
201
- continue
202
- processed_boxes.add(box_tuple)
203
-
204
- # Clean up token text (remove ## prefix from subwords)
205
- clean_token = token.replace('##', '')
206
-
207
- page_results.append({
208
- "text": clean_token,
209
- "bbox": {
210
- "x": x,
211
- "y": y,
212
- "width": width,
213
- "height": height
214
- }
215
- })
216
-
217
- all_results.append({
218
- "page": page_num + 1,
219
- "page_dimensions": {
220
- "width": page_width,
221
- "height": page_height
222
- },
223
- "rotation": page.rotation,
224
- "extractions": page_results
225
- })
226
-
227
- pdf_document.close()
228
- os.unlink(tmp_file.name) # Clean up temp file
229
-
230
- return {
231
- "document_type": "pdf",
232
- "total_pages": len(all_results),
233
- "pages": all_results
234
- }
235
-
236
- def process_image(image_bytes):
237
- """Process single image with proper coordinate scaling"""
238
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
239
  img_width, img_height = image.size
240
 
241
  try:
@@ -243,19 +78,18 @@ def process_image(image_bytes):
243
  image,
244
  truncation=True,
245
  padding="max_length",
246
- max_length=512,
247
  return_tensors="pt"
248
  )
249
  except Exception as e:
250
  print(f"OCR failed: {e}, using fallback")
251
- # Fallback without OCR
252
  encoding = processor(
253
  image,
254
- text=[""] * 512,
255
- boxes=[[0, 0, 0, 0]] * 512,
256
  truncation=True,
257
  padding="max_length",
258
- max_length=512,
259
  return_tensors="pt"
260
  )
261
 
@@ -272,37 +106,37 @@ def process_image(image_bytes):
272
 
273
  for token, box in zip(tokens, boxes):
274
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
275
- # LayoutLMv3 returns normalized coordinates (0-1000)
276
- # For images, we directly have the correct dimensions
277
  x_norm = box[0]
278
  y_norm = box[1]
279
  x2_norm = box[2]
280
  y2_norm = box[3]
281
 
282
- # Skip invalid boxes
283
  if x_norm == 0 and y_norm == 0 and x2_norm == 0 and y2_norm == 0:
284
  continue
285
 
286
- # Convert to actual image coordinates
287
  x = (x_norm / 1000.0) * img_width
288
  y = (y_norm / 1000.0) * img_height
289
  x2 = (x2_norm / 1000.0) * img_width
290
  y2 = (y2_norm / 1000.0) * img_height
291
 
 
 
 
 
 
 
292
  width = x2 - x
293
  height = y2 - y
294
 
295
- # Skip boxes that are too small
296
  if width < 1 or height < 1:
297
  continue
298
 
299
- # Check for duplicates
300
  box_tuple = (round(x), round(y), round(width), round(height))
301
  if box_tuple in processed_boxes:
302
  continue
303
  processed_boxes.add(box_tuple)
304
 
305
- # Clean up token text
306
  clean_token = token.replace('##', '')
307
 
308
  results.append({
@@ -315,6 +149,137 @@ def process_image(image_bytes):
315
  }
316
  })
317
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  return {
319
  "document_type": "image",
320
  "image_dimensions": {
 
1
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
 
29
  model.to(device)
30
  except Exception as e:
31
  print(f"Error loading model: {e}")
 
32
  processor = LayoutLMv3Processor.from_pretrained(
33
  "microsoft/layoutlmv3-base",
34
  apply_ocr=False
 
43
class DocumentRequest(BaseModel):
    """Request payload for /extract.

    Exactly one of `pdf` or `image` should be provided, each containing the
    base64-encoded bytes of the document. The handler sniffs the actual type
    from the decoded bytes, so either field is accepted for either format.
    """
    # BUG FIX: `str = None` is an invalid annotation (and rejected by
    # Pydantic v2); the fields are optional and must be typed as such.
    pdf: Optional[str] = None    # base64-encoded PDF bytes
    image: Optional[str] = None  # base64-encoded image bytes
    split_wide_pages: bool = True  # split very wide pages into overlapping chunks
47
 
48
  @app.get("/")
49
  def home():
 
52
@app.post("/extract")
async def extract_document(request: DocumentRequest):
    """Extract token/bounding-box data from a base64-encoded PDF or image.

    The document type is detected from the decoded bytes (`%PDF` magic
    header), not from which request field was populated.

    Raises:
        HTTPException 400 if no document was supplied.
        HTTPException 500 for any unexpected processing failure.
    """
    try:
        file_data = request.pdf or request.image
        if not file_data:
            raise HTTPException(status_code=400, detail="No PDF or image provided")

        file_bytes = base64.b64decode(file_data)

        if file_bytes.startswith(b'%PDF'):
            return process_pdf(pdf_bytes=file_bytes, split_wide=request.split_wide_pages)
        return process_image(file_bytes)

    except HTTPException:
        # BUG FIX: let deliberate HTTP errors (e.g. the 400 above) propagate
        # unchanged instead of being swallowed by the generic handler below
        # and re-emitted as a misleading 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
68
 
69
+ def process_image_chunk(image: Image.Image, offset_x: float = 0, offset_y: float = 0) -> List[Dict]:
70
+ """
71
+ Process a single image or image chunk and return extractions.
72
+ offset_x and offset_y are used when processing chunks of a larger image.
73
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  img_width, img_height = image.size
75
 
76
  try:
 
78
  image,
79
  truncation=True,
80
  padding="max_length",
81
+ max_length=1024, # Increased limit
82
  return_tensors="pt"
83
  )
84
  except Exception as e:
85
  print(f"OCR failed: {e}, using fallback")
 
86
  encoding = processor(
87
  image,
88
+ text=[""] * 1024,
89
+ boxes=[[0, 0, 0, 0]] * 1024,
90
  truncation=True,
91
  padding="max_length",
92
+ max_length=1024,
93
  return_tensors="pt"
94
  )
95
 
 
106
 
107
  for token, box in zip(tokens, boxes):
108
  if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
 
 
109
  x_norm = box[0]
110
  y_norm = box[1]
111
  x2_norm = box[2]
112
  y2_norm = box[3]
113
 
 
114
  if x_norm == 0 and y_norm == 0 and x2_norm == 0 and y2_norm == 0:
115
  continue
116
 
117
+ # Convert to chunk coordinates
118
  x = (x_norm / 1000.0) * img_width
119
  y = (y_norm / 1000.0) * img_height
120
  x2 = (x2_norm / 1000.0) * img_width
121
  y2 = (y2_norm / 1000.0) * img_height
122
 
123
+ # Add offset to get coordinates in full page space
124
+ x += offset_x
125
+ y += offset_y
126
+ x2 += offset_x
127
+ y2 += offset_y
128
+
129
  width = x2 - x
130
  height = y2 - y
131
 
 
132
  if width < 1 or height < 1:
133
  continue
134
 
 
135
  box_tuple = (round(x), round(y), round(width), round(height))
136
  if box_tuple in processed_boxes:
137
  continue
138
  processed_boxes.add(box_tuple)
139
 
 
140
  clean_token = token.replace('##', '')
141
 
142
  results.append({
 
149
  }
150
  })
151
 
152
+ return results
153
+
154
def process_pdf(pdf_bytes, split_wide: bool = True):
    """Process a PDF page by page with LayoutLMv3.

    Each page is rendered to an RGB image at RENDER_SCALE and run through
    `process_image_chunk`. Pages wider than MAX_WIDTH rendered pixels are
    optionally split into overlapping vertical strips (so OCR/token limits
    are not exceeded); duplicates created by the overlap are removed.

    Args:
        pdf_bytes: Raw PDF file contents.
        split_wide: When True, split very wide pages into chunks.

    Returns:
        Dict with ``document_type``, ``total_pages`` and per-page extraction
        results; bounding boxes are expressed in PDF point coordinates.
    """
    all_results = []

    # Write the PDF to disk for fitz. Close the handle before opening it with
    # fitz (opening a still-open NamedTemporaryFile fails on Windows).
    with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp_file:
        tmp_file.write(pdf_bytes)
        tmp_file.flush()
        tmp_path = tmp_file.name

    pdf_document = fitz.open(tmp_path)
    try:
        RENDER_SCALE = 2.0
        MAX_WIDTH = 2000  # max rendered width (px) before splitting
        OVERLAP = 200     # horizontal overlap (px) between adjacent chunks

        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            page_rect = page.rect
            page_width = page_rect.width
            page_height = page_rect.height

            print(f"Page {page_num + 1}: {page_width}x{page_height}, rotation={page.rotation}°")

            # Render the page; the matrix handles rotation automatically.
            mat = fitz.Matrix(RENDER_SCALE, RENDER_SCALE)
            pix = page.get_pixmap(matrix=mat)
            img_data = pix.tobytes("png")
            full_image = Image.open(io.BytesIO(img_data)).convert("RGB")
            img_width, img_height = full_image.size

            print(f"Rendered image: {img_width}x{img_height}")

            # Pixel -> PDF-point factors. Use the actual rendered ratio rather
            # than assuming img == page * RENDER_SCALE exactly (pixmap sizes
            # are rounded to whole pixels).
            scale_x = page_width / img_width
            scale_y = page_height / img_height

            page_results = []

            if split_wide and img_width > MAX_WIDTH:
                print(f"Page is wide ({img_width}px), splitting into chunks...")

                stride = MAX_WIDTH - OVERLAP
                num_chunks = (img_width + stride - 1) // stride

                for chunk_idx in range(num_chunks):
                    start_x = chunk_idx * stride
                    end_x = min(start_x + MAX_WIDTH, img_width)

                    chunk = full_image.crop((start_x, 0, end_x, img_height))
                    print(f"  Processing chunk {chunk_idx + 1}/{num_chunks}: x={start_x}-{end_x}")

                    # BUG FIX: the offset must be in the same units as the
                    # chunk-local pixel coordinates it is added to. The old
                    # code passed start_x / RENDER_SCALE and then divided the
                    # summed coordinate by RENDER_SCALE again, so every
                    # chunked box was shifted toward the left edge.
                    chunk_results = process_image_chunk(chunk, offset_x=start_x, offset_y=0)

                    # Convert full-image pixel coordinates to PDF points.
                    for result in chunk_results:
                        bbox = result['bbox']
                        bbox['x'] *= scale_x
                        bbox['y'] *= scale_y
                        bbox['width'] *= scale_x
                        bbox['height'] *= scale_y

                    page_results.extend(chunk_results)

                print(f"  Total extractions from all chunks: {len(page_results)}")
            else:
                # Narrow page: process in one shot, then convert to PDF points.
                page_results = process_image_chunk(full_image, 0, 0)
                for result in page_results:
                    bbox = result['bbox']
                    bbox['x'] *= scale_x
                    bbox['y'] *= scale_y
                    bbox['width'] *= scale_x
                    bbox['height'] *= scale_y

            # Remove duplicates produced by the overlap between chunks.
            unique_results = []
            seen_boxes = set()
            for result in page_results:
                bbox = result['bbox']
                box_tuple = (round(bbox['x']), round(bbox['y']),
                             round(bbox['width']), round(bbox['height']))
                if box_tuple not in seen_boxes:
                    seen_boxes.add(box_tuple)
                    unique_results.append(result)

            print(f"  After deduplication: {len(unique_results)} unique extractions")

            all_results.append({
                "page": page_num + 1,
                "page_dimensions": {
                    "width": page_width,
                    "height": page_height
                },
                "rotation": page.rotation,
                "extractions": unique_results
            })
    finally:
        # ROBUSTNESS FIX: previously the document handle and the temp file
        # leaked whenever any page raised; clean up unconditionally.
        pdf_document.close()
        os.unlink(tmp_path)

    return {
        "document_type": "pdf",
        "total_pages": len(all_results),
        "pages": all_results
    }
275
+
276
+ def process_image(image_bytes):
277
+ """Process single image"""
278
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
279
+ img_width, img_height = image.size
280
+
281
+ results = process_image_chunk(image, 0, 0)
282
+
283
  return {
284
  "document_type": "image",
285
  "image_dimensions": {