"""LayoutLMv3-backed FastAPI service extracting word boxes from PDFs/images.

Clients POST a base64-encoded PDF or image to ``/extract``. Text and boxes come
from the LayoutLMv3 processor's OCR step; boxes are returned in PDF points
(for PDFs) or image pixels (for images).
"""

import os

# Set before torch (and its OpenMP runtime) is imported; assigning these after
# the import is not guaranteed to take effect.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import base64  # noqa: E402
import binascii  # noqa: E402
import io  # noqa: E402
import itertools  # noqa: E402
import math  # noqa: E402
import tempfile  # noqa: E402,F401
import traceback  # noqa: E402
from typing import Any, Dict, List, Optional, Tuple  # noqa: E402

import fitz  # PyMuPDF  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from PIL import Image, UnidentifiedImageError  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor  # noqa: E402

MODEL_NAME = "microsoft/layoutlmv3-base"
# Longest sequence the base model can embed; longer encodings fail inside the
# model's position-embedding lookup, so requested lengths are clamped to this.
MODEL_MAX_TOKENS = 512
# Tokens that never carry page text, in addition to the tokenizer's own
# special tokens.
_EXTRA_SPECIAL_TOKENS = {'[CLS]', '[SEP]', '[PAD]'}

app = FastAPI()

device = torch.device("cpu")
print(f"Using device: {device}")

try:
    # Preferred: let the processor run OCR on the image itself.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=True)
except Exception as e:
    # Fall back to a processor without OCR. process_image_chunk then feeds it
    # placeholder words, which presumably yields no extractions.
    print(f"Error loading model: {e}")
    processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False)

model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_NAME)
model.eval()
model.to(device)


class DocumentRequest(BaseModel):
    """Body of POST /extract.

    ``pdf`` / ``image`` hold base64 data (``pdf`` wins if both are set); the
    actual type is sniffed from the decoded bytes. ``split_wide_pages``
    controls whether wide pages/images are OCR'd in overlapping chunks.
    """

    pdf: Optional[str] = None
    image: Optional[str] = None
    split_wide_pages: bool = True


@app.get("/")
def home():
    """Liveness endpoint."""
    return {"message": "LayoutLMv3 PDF/Image Extraction API", "status": "ready"}


def _decode_base64_payload(data: str) -> bytes:
    """Decode base64 ``data``, tolerating an optional ``data:...;base64,`` prefix.

    Raises:
        HTTPException: status 400 if ``data`` is not valid base64.
    """
    if data.startswith("data:") and "," in data:
        data = data.split(",", 1)[1]
    try:
        return base64.b64decode(data)
    except (binascii.Error, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 payload: {e}") from e


@app.post("/extract")
async def extract_document(request: DocumentRequest):
    """Decode the request payload and dispatch to the PDF or image pipeline.

    Raises:
        HTTPException: 400 for missing/undecodable input, 500 for any other
            processing failure.
    """
    # NOTE(review): the pipelines below are synchronous and CPU-heavy, so they
    # block the event loop (including GET /) for the whole request. Offloading
    # to a thread would also enable concurrent inference (more memory), so it
    # is left as a deliberate deployment decision.
    try:
        file_data = request.pdf or request.image
        if not file_data:
            raise HTTPException(status_code=400, detail="No PDF or image provided")

        file_bytes = _decode_base64_payload(file_data)

        if file_bytes.startswith(b'%PDF'):
            return process_pdf(pdf_bytes=file_bytes, split_wide=request.split_wide_pages)
        return process_image(file_bytes, split_wide=request.split_wide_pages)
    except HTTPException:
        # Client errors raised above must keep their status code rather than
        # being rewrapped as a 500 by the generic handler below.
        raise
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail=f"Unsupported or corrupt image: {e}") from e
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error in extract_document: {error_details}")
        raise HTTPException(status_code=500, detail=str(e))


def _encode_with_retries(image: Image.Image, max_tokens: int) -> Optional[Dict[str, torch.Tensor]]:
    """Encode ``image`` with the processor and check it runs through the model.

    Tries a descending ladder of sequence lengths (``max_tokens``, 384, 256,
    never exceeding ``max_tokens``). On the last length, if OCR itself failed,
    placeholder words are fed instead.

    Returns:
        The tensor encoding (bbox clamped to 0-1000) on success, or ``None``
        when every attempt failed. A RuntimeError that is neither a CUDA nor an
        index-out-of-range error is re-raised.
    """
    max_tokens = min(max_tokens, MODEL_MAX_TOKENS)
    token_limits = sorted({t for t in (max_tokens, 384, 256) if t <= max_tokens}, reverse=True)

    for token_limit in token_limits:
        is_last_attempt = token_limit == token_limits[-1]

        try:
            encoding = processor(
                image,
                truncation=True,
                padding="max_length",
                max_length=token_limit,
                return_tensors="pt",
            )
        except Exception as e:
            print(f"OCR failed with max_tokens={token_limit}: {e}")
            if not is_last_attempt:
                continue
            # Last attempt: feed placeholder words so we still get an encoding.
            try:
                encoding = processor(
                    image,
                    text=[""] * token_limit,
                    boxes=[[0, 0, 0, 0]] * token_limit,
                    truncation=True,
                    padding="max_length",
                    max_length=token_limit,
                    return_tensors="pt",
                )
            except Exception as e2:
                print(f"Fallback also failed: {e2}")
                return None

        # Keep only tensors, move them to the device, and clamp boxes into the
        # normalized 0-1000 range that is converted back to pixels later.
        tensors = {}
        for key, value in encoding.items():
            if isinstance(value, torch.Tensor):
                value = value.to(device)
                if key == "bbox":
                    value = torch.clamp(value, 0, 1000)
                tensors[key] = value

        # NOTE(review): the forward pass's outputs are never read — text and
        # boxes come from the encoding alone. It currently only validates the
        # encoding (and drives this retry ladder); drop it if predictions are
        # not going to be used.
        try:
            with torch.no_grad():
                model(**tensors)
            return tensors
        except RuntimeError as e:
            error_str = str(e)
            if "CUDA" in error_str:
                print(f"CUDA error encountered: {e}")
                tensors = {k: v.cpu() for k, v in tensors.items()}
                model.cpu()
                with torch.no_grad():
                    model(**tensors)
                model.to(device)
                return tensors
            if "index out of range" in error_str:
                print(f"Index error with max_tokens={token_limit}: {e}")
                if is_last_attempt:
                    print("All token limits exhausted, returning empty results")
                    return None
                print("Retrying with smaller token limit...")
                continue
            raise
        except Exception as e:
            print(f"Unexpected error in model processing: {e}")
            if is_last_attempt:
                return None
            continue

    return None


def _tokens_to_words(tokens: List[str], boxes: List[List[int]],
                     img_width: int, img_height: int) -> List[Dict]:
    """Merge sub-word tokens into words and convert their boxes to pixels.

    LayoutLMv3's tokenizer repeats a word's box on each of its sub-word
    tokens, so consecutive tokens sharing a box are joined back into one word
    and decoded with the tokenizer (which strips BPE markers such as ``Ġ``).
    Special tokens and all-zero boxes are skipped, as are boxes smaller than
    one pixel and boxes already emitted.
    """
    special_tokens = set(processor.tokenizer.all_special_tokens) | _EXTRA_SPECIAL_TOKENS
    word_tokens = (
        (token, tuple(box))
        for token, box in zip(tokens, boxes)
        if token not in special_tokens and any(box)
    )

    results = []
    seen_boxes = set()
    for box, group in itertools.groupby(word_tokens, key=lambda pair: pair[1]):
        try:
            text = processor.tokenizer.convert_tokens_to_string([tok for tok, _ in group]).strip()
            if not text:
                continue

            x_norm, y_norm, x2_norm, y2_norm = box
            # Normalized (0-1000) coordinates -> pixel coordinates.
            x = (x_norm / 1000.0) * img_width
            y = (y_norm / 1000.0) * img_height
            x2 = (x2_norm / 1000.0) * img_width
            y2 = (y2_norm / 1000.0) * img_height
            width = x2 - x
            height = y2 - y
            if width < 1 or height < 1:
                continue

            box_key = (round(x), round(y), round(width), round(height))
            if box_key in seen_boxes:
                continue
            seen_boxes.add(box_key)

            results.append({
                "text": text,
                "bbox": {"x": x, "y": y, "width": width, "height": height},
            })
        except Exception as e:
            print(f"Error processing token group with box {box}: {e}")
            continue

    return results


def process_image_chunk(image: Image.Image, max_tokens: int = 512) -> List[Dict]:
    """OCR one image (chunk) and return its words with pixel bounding boxes.

    Args:
        image: RGB image to process.
        max_tokens: Requested sequence length; clamped to MODEL_MAX_TOKENS.

    Returns:
        ``[{"text": str, "bbox": {"x", "y", "width", "height"}}, ...]`` in the
        image's pixel coordinates; empty if the image is degenerate or every
        encoding attempt failed.
    """
    img_width, img_height = image.size
    if img_width < 1 or img_height < 1:
        print(f"Invalid image dimensions: {img_width}x{img_height}")
        return []

    encoding = _encode_with_retries(image, max_tokens)
    if encoding is None:
        return []

    try:
        tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
        boxes = encoding["bbox"][0].tolist()
    except Exception as e:
        print(f"Error extracting tokens/boxes: {e}")
        return []

    return _tokens_to_words(tokens, boxes, img_width, img_height)


def should_split_page(rendered_width: int, rendered_height: int,
                      max_width: int) -> Tuple[bool, Optional[str]]:
    """Decide whether a rendered page is too wide to OCR in one pass.

    Returns ``(True, "horizontal")`` when ``rendered_width`` exceeds
    ``max_width``, else ``(False, None)``. ``rendered_height`` is accepted for
    interface symmetry but currently unused.
    """
    if rendered_width > max_width:
        return (True, "horizontal")
    return (False, None)


def split_image_intelligently(image: Image.Image, max_width: int,
                              overlap_ratio: float = 0.1) -> List[Tuple[Image.Image, int]]:
    """Split ``image`` into overlapping full-height strips along its width.

    Consecutive strips overlap by ``int(max_width * overlap_ratio)`` pixels. A
    remaining tail narrower than 30% of ``max_width`` is absorbed into the
    previous strip instead of becoming its own sliver.

    Returns:
        ``[(chunk_image, x_offset), ...]`` where ``x_offset`` is the strip's
        left edge in the original image.

    Raises:
        ValueError: if ``max_width <= 0`` or ``overlap_ratio`` is outside
            ``[0, 1)`` (either would make the step size non-positive and the
            loop never terminate).
    """
    if max_width <= 0:
        raise ValueError(f"max_width must be positive, got {max_width}")
    if not 0 <= overlap_ratio < 1:
        raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")

    img_width, img_height = image.size
    if img_width <= max_width:
        return [(image, 0)]

    overlap_pixels = int(max_width * overlap_ratio)
    step_size = max_width - overlap_pixels

    chunks = []
    x_position = 0
    while x_position < img_width:
        right_edge = min(x_position + max_width, img_width)
        if right_edge < img_width and (img_width - right_edge) < (max_width * 0.3):
            right_edge = img_width

        chunk = image.crop((x_position, 0, right_edge, img_height))
        chunks.append((chunk, x_position))
        print(f" Created chunk at x={x_position}, width={right_edge - x_position}")

        if right_edge >= img_width:
            break
        x_position += step_size

    return chunks


def _rendered_bbox_to_pdf(bbox: Dict[str, float], x_offset: int, scale_x: float,
                          scale_y: float, render_scale: float) -> None:
    """Convert a chunk-pixel bbox, in place, to PDF points in effective page space."""
    bbox['x'] = (bbox['x'] + x_offset) / (render_scale * scale_x)
    bbox['y'] = bbox['y'] / (render_scale * scale_y)
    bbox['width'] = bbox['width'] / (render_scale * scale_x)
    bbox['height'] = bbox['height'] / (render_scale * scale_y)


def _process_pdf_page(page, page_num: int, split_wide: bool, render_scale: float,
                      max_width: int, max_tokens: int) -> Dict[str, Any]:
    """Render one PDF page, OCR it (in strips if wide) and build its result entry."""
    # PyMuPDF's page.rect is the displayed page rectangle, i.e. /Rotate is
    # already applied, and get_pixmap renders that same orientation. So it is
    # the effective size as-is; the unrotated size is derived from it rather
    # than swapping the rect a second time for 90/270° pages.
    original_rotation = page.rotation
    effective_pdf_width = page.rect.width
    effective_pdf_height = page.rect.height
    if original_rotation in (90, 270):
        original_width, original_height = effective_pdf_height, effective_pdf_width
    else:
        original_width, original_height = effective_pdf_width, effective_pdf_height

    print(f"\nProcessing page {page_num + 1}:")
    print(f" Original dimensions: {original_width}x{original_height}")
    print(f" Rotation: {original_rotation}°")
    print(f" Effective PDF dimensions (after rotation): {effective_pdf_width}x{effective_pdf_height}")

    pix = page.get_pixmap(matrix=fitz.Matrix(render_scale, render_scale))
    full_image = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
    rendered_width, rendered_height = full_image.size
    print(f" Actual rendered dimensions: {rendered_width}x{rendered_height}")

    # Defensive: if a render comes back in the other orientation, rotate it.
    # Only do so when it does NOT already match — otherwise square (or nearly
    # square) pages would match both ways and be rotated needlessly.
    expected_rendered_width = effective_pdf_width * render_scale
    expected_rendered_height = effective_pdf_height * render_scale
    matches_expected = (abs(rendered_width - expected_rendered_width) < 10
                        and abs(rendered_height - expected_rendered_height) < 10)
    looks_swapped = (abs(rendered_width - expected_rendered_height) < 10
                     and abs(rendered_height - expected_rendered_width) < 10)
    if looks_swapped and not matches_expected:
        print(" ⚠️ Dimensions are swapped! Rotating image 90° to match expected orientation.")
        full_image = full_image.rotate(-90, expand=True)
        rendered_width, rendered_height = full_image.size
        print(f" After rotation: {rendered_width}x{rendered_height}")

    # Residual rendered-pixel -> effective-point scale (absorbs rounding).
    scale_x = rendered_width / expected_rendered_width
    scale_y = rendered_height / expected_rendered_height
    print(f" Scale factors: x={scale_x:.4f}, y={scale_y:.4f}")

    should_split_decision, split_direction = should_split_page(
        rendered_width, rendered_height, max_width
    )

    if split_wide and should_split_decision:
        print(f" Splitting page ({split_direction})...")
        chunks = split_image_intelligently(full_image, max_width, overlap_ratio=0.2)
        print(f" Created {len(chunks)} chunks")

        page_results: List[Dict] = []
        for chunk_idx, (chunk_image, x_offset) in enumerate(chunks):
            chunk_width, chunk_height = chunk_image.size
            print(f" Processing chunk {chunk_idx + 1}: offset={x_offset}px, size={chunk_width}x{chunk_height}px")

            chunk_results = process_image_chunk(chunk_image, max_tokens=max_tokens)
            print(f" Extracted {len(chunk_results)} items from chunk {chunk_idx + 1}")

            if chunk_results and chunk_idx < 2:
                print(f" Sample items from chunk {chunk_idx + 1}:")
                for i, item in enumerate(chunk_results[:3]):
                    print(f" Item {i+1}: text='{item['text']}', chunk_x={item['bbox']['x']:.1f}px")

            for item_idx, result in enumerate(chunk_results):
                bbox = result['bbox']
                chunk_x = bbox['x']
                rendered_x = chunk_x + x_offset
                _rendered_bbox_to_pdf(bbox, x_offset, scale_x, scale_y, render_scale)
                if item_idx == 0:
                    print(f" Transform: chunk_x={chunk_x:.1f}px + offset={x_offset}px = rendered_x={rendered_x:.1f}px → pdf_x={bbox['x']:.1f}pts")

            page_results.extend(chunk_results)

        print(f" Total items before deduplication: {len(page_results)}")
    else:
        print(" Processing full page without splitting...")
        page_results = process_image_chunk(full_image, max_tokens=max_tokens)
        for result in page_results:
            _rendered_bbox_to_pdf(result['bbox'], 0, scale_x, scale_y, render_scale)
        print(f" Extracted {len(page_results)} items")

    unique_results = deduplicate_results(page_results)
    print(f" After deduplication: {len(unique_results)} unique items")

    if unique_results:
        x_coords = [item['bbox']['x'] for item in unique_results]
        y_coords = [item['bbox']['y'] for item in unique_results]
        print(" Final coordinate ranges:")
        print(f" X: {min(x_coords):.1f} to {max(x_coords):.1f} (effective width: {effective_pdf_width:.1f})")
        print(f" Y: {min(y_coords):.1f} to {max(y_coords):.1f} (effective height: {effective_pdf_height:.1f})")
        if max(x_coords) > effective_pdf_width + 10:
            print(" ⚠️ WARNING: Some X coordinates still exceed effective page width!")
        elif max(x_coords) > effective_pdf_width:
            print(" ℹ️ Note: Max X slightly exceeds width (likely edge items), but within tolerance")
        else:
            print(" ✓ All coordinates within expected bounds")

    return {
        "page": page_num + 1,
        "page_dimensions": {"width": original_width, "height": original_height},
        "effective_dimensions": {"width": effective_pdf_width, "height": effective_pdf_height},
        "rotation": original_rotation,
        "extractions": unique_results,
    }


def process_pdf(pdf_bytes: bytes, split_wide: bool = True):
    """Extract word boxes from every page of a PDF.

    Each page is rendered at RENDER_SCALE, OCR'd (in overlapping full-height
    strips when wider than MAX_WIDTH rendered pixels), and boxes are mapped
    back to PDF points in the page's effective (displayed, rotation-applied)
    coordinate space.

    Args:
        pdf_bytes: Raw PDF file contents.
        split_wide: Whether wide pages may be split into strips.

    Returns:
        ``{"document_type": "pdf", "total_pages": int, "pages": [...]}``. A page
        that fails gets an entry with an ``"error"`` key instead of aborting
        the whole document.
    """
    RENDER_SCALE = 3.0
    MAX_WIDTH = 1800  # Maximum strip width in rendered pixels
    MAX_TOKENS = MODEL_MAX_TOKENS

    all_results = []
    # Open from memory: no temp file to leak on error, and no reopening of a
    # still-open NamedTemporaryFile (which fails on Windows).
    pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        for page_num in range(len(pdf_document)):
            try:
                all_results.append(_process_pdf_page(
                    pdf_document[page_num], page_num, split_wide,
                    RENDER_SCALE, MAX_WIDTH, MAX_TOKENS,
                ))
            except Exception as e:
                print(f"Error processing page {page_num + 1}: {e}")
                traceback.print_exc()
                all_results.append({
                    "page": page_num + 1,
                    "page_dimensions": {"width": 0, "height": 0},
                    "effective_dimensions": {"width": 0, "height": 0},
                    "rotation": 0,
                    "extractions": [],
                    "error": str(e),
                })
    finally:
        pdf_document.close()

    return {
        "document_type": "pdf",
        "total_pages": len(all_results),
        "pages": all_results,
    }


def deduplicate_results(results: List[Dict], tolerance: float = 10.0) -> List[Dict]:
    """Collapse near-duplicate extractions (e.g. from overlapping strips).

    Two items cluster together when their box centers are closer than
    ``tolerance`` (same units as the boxes) and their widths and heights are
    within a 0.7-1.3 ratio of each other. From each cluster, the item with the
    longest text is kept; first-seen order of clusters is preserved.
    """
    if not results:
        return []

    def _center(b: Dict[str, float]) -> Tuple[float, float]:
        return b['x'] + b['width'] / 2, b['y'] + b['height'] / 2

    unique_results = []
    consumed = set()
    for i, result in enumerate(results):
        if i in consumed:
            continue
        consumed.add(i)
        bbox = result['bbox']
        center_x, center_y = _center(bbox)
        cluster = [result]

        for j in range(i + 1, len(results)):
            if j in consumed:
                continue
            other_bbox = results[j]['bbox']
            other_x, other_y = _center(other_bbox)
            if math.hypot(center_x - other_x, center_y - other_y) >= tolerance:
                continue
            ratio_w = bbox['width'] / other_bbox['width'] if other_bbox['width'] > 0 else 1
            ratio_h = bbox['height'] / other_bbox['height'] if other_bbox['height'] > 0 else 1
            if 0.7 < ratio_w < 1.3 and 0.7 < ratio_h < 1.3:
                cluster.append(results[j])
                consumed.add(j)

        unique_results.append(max(cluster, key=lambda r: len(r.get('text', ''))))

    return unique_results


def process_image(image_bytes: bytes, split_wide: bool = True):
    """OCR a single image, splitting it into strips when wider than 2000 px.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).
        split_wide: Whether wide images may be split into strips.

    Returns:
        ``{"document_type": "image", "image_dimensions": {...}, "extractions": [...]}``
        with boxes in the image's pixel coordinates.
    """
    MAX_WIDTH = 2000
    MAX_TOKENS = MODEL_MAX_TOKENS

    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_width, img_height = image.size
    print(f"Processing single image: {img_width}x{img_height}")

    should_split_decision, _ = should_split_page(img_width, img_height, MAX_WIDTH)
    if split_wide and should_split_decision:
        print(" Image is wide, splitting into chunks...")
        all_results = []
        for chunk_image, x_offset in split_image_intelligently(image, MAX_WIDTH, overlap_ratio=0.2):
            chunk_results = process_image_chunk(chunk_image, max_tokens=MAX_TOKENS)
            for result in chunk_results:
                result['bbox']['x'] += x_offset
            all_results.extend(chunk_results)
        results = deduplicate_results(all_results)
    else:
        results = process_image_chunk(image, max_tokens=MAX_TOKENS)

    print(f" Total extractions: {len(results)}")
    return {
        "document_type": "image",
        "image_dimensions": {"width": img_width, "height": img_height},
        "extractions": results,
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)