Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import fitz # PyMuPDF | |
| import tempfile | |
| import os | |
| import math | |
# Thread / CUDA debugging knobs for the native libraries.
# NOTE(review): these are assigned after `torch` has been imported; confirm
# they still take effect at this point.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

app = FastAPI()


def _load_layoutlmv3(apply_ocr: bool):
    """Load the LayoutLMv3 processor and model and put the model in eval mode.

    Args:
        apply_ocr: Forwarded to the processor; selects whether it runs OCR
            on the images it is given.

    Returns:
        A ``(processor, model)`` pair.
    """
    loaded_processor = LayoutLMv3Processor.from_pretrained(
        "microsoft/layoutlmv3-base",
        apply_ocr=apply_ocr
    )
    loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(
        "microsoft/layoutlmv3-base"
    )
    loaded_model.eval()
    return loaded_processor, loaded_model


try:
    # Preferred configuration: the processor performs OCR itself.
    processor, model = _load_layoutlmv3(apply_ocr=True)
    device = torch.device("cpu")
    print(f"Using device: {device}")
    model.to(device)
except Exception as e:
    # Second attempt with OCR disabled; a failure here is not caught and
    # aborts module import.
    print(f"Error loading model: {e}")
    processor, model = _load_layoutlmv3(apply_ocr=False)
    device = torch.device("cpu")
    model.to(device)
class DocumentRequest(BaseModel):
    """Request body for the document extraction endpoint.

    Exactly one of ``pdf`` / ``image`` is expected to carry a base64 string;
    when both are present the handler prefers ``pdf``.
    """

    # Base64-encoded file contents. Declared Optional so that an omitted
    # field and an explicit JSON null are both accepted.
    pdf: Optional[str] = None
    image: Optional[str] = None
    # When true, wide PDF pages are OCR'd in overlapping vertical strips.
    split_wide_pages: bool = True
# Registered on the app root; without a route decorator FastAPI never
# exposes this handler.
@app.get("/")
def home():
    """Return a static status payload identifying the service."""
    return {"message": "LayoutLMv3 PDF/Image Extraction API", "status": "ready"}
# NOTE(review): no route decorator is visible for this handler, so it was
# unreachable; the "/extract" path is an assumption - confirm against clients.
@app.post("/extract")
async def extract_document(request: DocumentRequest):
    """Decode the uploaded base64 document and run extraction on it.

    The payload is taken from ``request.pdf`` or, failing that,
    ``request.image``. Content starting with the ``%PDF`` magic bytes is
    handled by ``process_pdf``; everything else is treated as an image.

    Raises:
        HTTPException: 400 when no payload is supplied or it is not valid
            base64; 500 when processing itself fails.
    """
    file_data = request.pdf or request.image
    if not file_data:
        raise HTTPException(status_code=400, detail="No PDF or image provided")

    try:
        file_bytes = base64.b64decode(file_data)
    except ValueError as e:
        # binascii.Error (bad padding etc.) is a ValueError subclass: this is
        # a client mistake, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid base64 payload: {e}") from e

    try:
        if file_bytes.startswith(b'%PDF'):
            return process_pdf(pdf_bytes=file_bytes, split_wide=request.split_wide_pages)
        return process_image(file_bytes)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewrapping them as 500.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in extract_document: {error_details}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def process_image_chunk(image: Image.Image, max_tokens: int = 512) -> List[Dict]:
    """Run the processor (OCR + tokenization) on one image and return word boxes.

    The image is encoded with ``max_tokens`` as the sequence length; when the
    encoding or the model forward pass fails, progressively smaller limits
    (384, then 256) are tried. Sub-word pieces that belong to the same word
    box are joined back into one text string.

    Args:
        image: The image (or page strip) to analyse.
        max_tokens: Preferred sequence length for the encoding.

    Returns:
        A list of ``{"text": str, "bbox": {"x", "y", "width", "height"}}``
        entries with boxes in pixels of ``image``. Empty when the image is
        degenerate or every attempt failed.

    Raises:
        RuntimeError: A model error that is neither CUDA-related nor an
            "index out of range" failure is propagated.
    """
    img_width, img_height = image.size
    if img_width < 1 or img_height < 1:
        print(f"Invalid image dimensions: {img_width}x{img_height}")
        return []

    # Retry ladder: only ever step DOWN from the requested limit, so a request
    # for e.g. 300 tokens is not "retried" with 384.
    token_limits = [max_tokens] + [limit for limit in (384, 256) if limit < max_tokens]
    final_attempt = len(token_limits) - 1

    for attempt, token_limit in enumerate(token_limits):
        is_last = attempt == final_attempt
        try:
            encoding = processor(
                image,
                truncation=True,
                padding="max_length",
                max_length=token_limit,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"OCR failed with max_tokens={token_limit}: {e}")
            if not is_last:
                continue
            # Last attempt: feed placeholder words/boxes so the call can still
            # succeed when the processor is not doing OCR itself.
            try:
                encoding = processor(
                    image,
                    text=[""] * token_limit,
                    boxes=[[0, 0, 0, 0]] * token_limit,
                    truncation=True,
                    padding="max_length",
                    max_length=token_limit,
                    return_tensors="pt"
                )
            except Exception as e2:
                print(f"Fallback also failed: {e2}")
                return []

        # Keep tensors only, moved to the target device; boxes are clamped to
        # the 0..1000 normalized range.
        encoding = {
            key: (torch.clamp(value.to(device), 0, 1000) if key == "bbox" else value.to(device))
            for key, value in encoding.items()
            if isinstance(value, torch.Tensor)
        }

        # NOTE(review): the forward pass result is never read; the call only
        # acts as a gate that decides which token limit is accepted. Confirm
        # whether inference is needed at all.
        try:
            with torch.no_grad():
                model(**encoding)
            break
        except RuntimeError as e:
            error_str = str(e)
            if "CUDA" in error_str:
                print(f"CUDA error encountered: {e}")
                encoding = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in encoding.items()}
                model.cpu()
                with torch.no_grad():
                    model(**encoding)
                model.to(device)
                break
            elif "index out of range" in error_str:
                print(f"Index error with max_tokens={token_limit}: {e}")
                if is_last:
                    print(f"All token limits exhausted, returning empty results")
                    return []
                print(f"Retrying with smaller token limit...")
                continue
            else:
                raise
        except Exception as e:
            print(f"Unexpected error in model processing: {e}")
            if is_last:
                return []
            continue

    try:
        tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
        boxes = encoding["bbox"][0].tolist()
    except Exception as e:
        print(f"Error extracting tokens/boxes: {e}")
        return []

    special_tokens = {'[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>'}

    # Group consecutive pieces that share a box. The LayoutLMv3 tokenizer is
    # expected to give every sub-word piece the box of its source word (see
    # the Transformers docs), so keeping only the first piece per box would
    # truncate the word's text.
    word_groups = []  # (index of first piece, box, [pieces])
    previous_box = None
    for idx, (token, box) in enumerate(zip(tokens, boxes)):
        if token in special_tokens or not any(box):
            # Special/padding tokens and all-zero boxes carry no text.
            previous_box = None
            continue
        if box == previous_box:
            word_groups[-1][2].append(token)
        else:
            word_groups.append((idx, box, [token]))
            previous_box = box

    results = []
    seen_boxes = set()
    for idx, box, pieces in word_groups:
        try:
            x_norm, y_norm, x2_norm, y2_norm = box
            # Convert normalized (0..1000) coordinates to pixel coordinates.
            x = (x_norm / 1000.0) * img_width
            y = (y_norm / 1000.0) * img_height
            width = (x2_norm / 1000.0) * img_width - x
            height = (y2_norm / 1000.0) * img_height - y
            if width < 1 or height < 1:
                continue
            box_key = (round(x), round(y), round(width), round(height))
            if box_key in seen_boxes:
                continue
            # Let the tokenizer undo its own sub-word encoding instead of
            # stripping a marker by hand.
            text = processor.tokenizer.convert_tokens_to_string(pieces).strip()
            if not text:
                continue
            seen_boxes.add(box_key)
            results.append({
                "text": text,
                "bbox": {
                    "x": x,
                    "y": y,
                    "width": width,
                    "height": height
                }
            })
        except Exception as e:
            print(f"Error processing token at index {idx}: {e}")
            continue
    return results
def should_split_page(rendered_width: int, rendered_height: int, max_width: int) -> Tuple[bool, Optional[str]]:
    """Decide whether a rendered page is too wide to process in one piece.

    Args:
        rendered_width: Page width in rendered pixels.
        rendered_height: Page height in rendered pixels; accepted for
            interface symmetry but not used in the decision.
        max_width: Largest width that is processed without splitting.

    Returns:
        ``(True, "horizontal")`` when ``rendered_width`` exceeds
        ``max_width``, otherwise ``(False, None)``.
    """
    if rendered_width > max_width:
        return (True, "horizontal")
    return (False, None)
def split_image_intelligently(image: "Image.Image", max_width: int,
                              overlap_ratio: float = 0.1) -> "List[Tuple[Image.Image, int]]":
    """Split an image into overlapping full-height strips along its width.

    Args:
        image: Image to split; only ``size`` and ``crop`` are used.
        max_width: Nominal strip width in pixels.
        overlap_ratio: Fraction of ``max_width`` shared by neighbouring
            strips; must satisfy ``0 <= overlap_ratio < 1``.

    Returns:
        ``(strip, x_offset)`` pairs ordered left to right, where ``x_offset``
        is the strip's left edge in the source image. An image no wider than
        ``max_width`` is returned unsplit as ``[(image, 0)]``.

    Raises:
        ValueError: When splitting is required but the parameters would make
            the window stand still, move backwards, or leave gaps.
    """
    img_width, img_height = image.size
    if img_width <= max_width:
        return [(image, 0)]

    if not 0 <= overlap_ratio < 1:
        raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")
    overlap_pixels = int(max_width * overlap_ratio)
    step_size = max_width - overlap_pixels
    if step_size < 1:
        # A non-positive step would never reach the right edge.
        raise ValueError(f"max_width={max_width} leaves no room to advance")

    chunks = []
    x_position = 0
    while x_position < img_width:
        right_edge = min(x_position + max_width, img_width)
        # Absorb a short leftover (< 30% of max_width) into this strip rather
        # than emitting a sliver as its own chunk.
        if right_edge < img_width and (img_width - right_edge) < (max_width * 0.3):
            right_edge = img_width
        chunk = image.crop((x_position, 0, right_edge, img_height))
        chunks.append((chunk, x_position))
        print(f" Created chunk at x={x_position}, width={right_edge - x_position}")
        if right_edge >= img_width:
            break
        x_position += step_size
    return chunks
def _bbox_to_pdf_points(bbox: Dict, x_offset: float, x_divisor: float, y_divisor: float) -> None:
    """Rewrite a pixel-space bbox in place as PDF points, shifting x by the chunk offset first."""
    bbox['x'] = (bbox['x'] + x_offset) / x_divisor
    bbox['y'] = bbox['y'] / y_divisor
    bbox['width'] = bbox['width'] / x_divisor
    bbox['height'] = bbox['height'] / y_divisor


def _extract_pdf_page(page, page_number: int, split_wide: bool, render_scale: float,
                      max_width: int, max_tokens: int) -> Dict:
    """Render one PDF page, OCR it (in strips when wide) and build its result entry."""
    page_rect = page.rect
    original_width = page_rect.width
    original_height = page_rect.height
    original_rotation = page.rotation
    print(f"\nProcessing page {page_number}:")
    print(f" Original dimensions: {original_width}x{original_height}")
    print(f" Rotation: {original_rotation}°")

    # NOTE(review): PyMuPDF documents Page.rect as already reflecting the page
    # rotation. If that holds, this swap plus the image rotation below would
    # turn an upright render sideways - verify with a rotated sample before
    # relying on or changing this.
    if original_rotation in (90, 270):
        effective_pdf_width, effective_pdf_height = original_height, original_width
    else:
        effective_pdf_width, effective_pdf_height = original_width, original_height
    print(f" Effective PDF dimensions (after rotation): {effective_pdf_width}x{effective_pdf_height}")

    pix = page.get_pixmap(matrix=fitz.Matrix(render_scale, render_scale))
    full_image = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
    rendered_width, rendered_height = full_image.size
    print(f" Actual rendered dimensions: {rendered_width}x{rendered_height}")

    # If the render came out with width/height exchanged relative to the
    # effective dimensions (10 px tolerance), rotate it to match.
    expected_rendered_width = effective_pdf_width * render_scale
    expected_rendered_height = effective_pdf_height * render_scale
    if (abs(rendered_width - expected_rendered_height) < 10 and
            abs(rendered_height - expected_rendered_width) < 10):
        print(f" ⚠️ Dimensions are swapped! Rotating image 90° to match expected orientation.")
        full_image = full_image.rotate(-90, expand=True)
        rendered_width, rendered_height = full_image.size
        print(f" After rotation: {rendered_width}x{rendered_height}")

    # Ratio of actual to expected render size; folded into the divisors so
    # pixel coordinates map onto the effective page in PDF points.
    scale_x = rendered_width / (effective_pdf_width * render_scale)
    scale_y = rendered_height / (effective_pdf_height * render_scale)
    print(f" Scale factors: x={scale_x:.4f}, y={scale_y:.4f}")
    x_divisor = render_scale * scale_x
    y_divisor = render_scale * scale_y

    page_results = []
    should_split_decision, split_direction = should_split_page(
        rendered_width, rendered_height, max_width
    )
    if split_wide and should_split_decision:
        print(f" Splitting page ({split_direction})...")
        chunks = split_image_intelligently(full_image, max_width, overlap_ratio=0.2)
        print(f" Created {len(chunks)} chunks")
        for chunk_idx, (chunk_image, x_offset) in enumerate(chunks):
            chunk_width, chunk_height = chunk_image.size
            print(f" Processing chunk {chunk_idx + 1}: offset={x_offset}px, size={chunk_width}x{chunk_height}px")
            chunk_results = process_image_chunk(chunk_image, max_tokens=max_tokens)
            print(f" Extracted {len(chunk_results)} items from chunk {chunk_idx + 1}")
            if chunk_results and chunk_idx < 2:
                print(f" Sample items from chunk {chunk_idx + 1}:")
                for i, item in enumerate(chunk_results[:3]):
                    print(f" Item {i+1}: text='{item['text']}', chunk_x={item['bbox']['x']:.1f}px")
            # Chunk pixels -> full-render pixels (add offset) -> PDF points.
            for item_idx, result in enumerate(chunk_results):
                bbox = result['bbox']
                chunk_x = bbox['x']
                _bbox_to_pdf_points(bbox, x_offset, x_divisor, y_divisor)
                if item_idx == 0:
                    # Debug trace for the first item of each chunk.
                    print(f" Transform: chunk_x={chunk_x:.1f}px + offset={x_offset}px = rendered_x={chunk_x + x_offset:.1f}px → pdf_x={bbox['x']:.1f}pts")
            page_results.extend(chunk_results)
        print(f" Total items before deduplication: {len(page_results)}")
    else:
        print(" Processing full page without splitting...")
        chunk_results = process_image_chunk(full_image, max_tokens=max_tokens)
        for result in chunk_results:
            _bbox_to_pdf_points(result['bbox'], 0, x_divisor, y_divisor)
        page_results = chunk_results
        print(f" Extracted {len(chunk_results)} items")

    # Overlapping strips see the same words twice; collapse those.
    unique_results = deduplicate_results(page_results)
    print(f" After deduplication: {len(unique_results)} unique items")

    # Sanity report on the final coordinate ranges.
    if unique_results:
        x_coords = [item['bbox']['x'] for item in unique_results]
        y_coords = [item['bbox']['y'] for item in unique_results]
        print(f" Final coordinate ranges:")
        print(f" X: {min(x_coords):.1f} to {max(x_coords):.1f} (effective width: {effective_pdf_width:.1f})")
        print(f" Y: {min(y_coords):.1f} to {max(y_coords):.1f} (effective height: {effective_pdf_height:.1f})")
        if max(x_coords) > effective_pdf_width + 10:
            print(f" ⚠️ WARNING: Some X coordinates still exceed effective page width!")
        elif max(x_coords) > effective_pdf_width:
            print(f" ℹ️ Note: Max X slightly exceeds width (likely edge items), but within tolerance")
        else:
            print(f" ✓ All coordinates within expected bounds")

    return {
        "page": page_number,
        "page_dimensions": {
            "width": original_width,
            "height": original_height
        },
        "effective_dimensions": {
            "width": effective_pdf_width,
            "height": effective_pdf_height
        },
        "rotation": original_rotation,
        "extractions": unique_results
    }


def process_pdf(pdf_bytes: bytes, split_wide: bool = True):
    """Extract text boxes from every page of a PDF.

    Each page is rendered at 3x scale, optionally split into overlapping
    strips when wider than 1800 rendered pixels, passed through
    ``process_image_chunk`` and mapped back to PDF points. A page that fails
    is reported with an ``"error"`` field and empty extractions instead of
    aborting the whole document.

    Args:
        pdf_bytes: Raw PDF file contents.
        split_wide: Whether wide pages may be processed in strips.

    Returns:
        ``{"document_type": "pdf", "total_pages": int, "pages": [...]}``
        with one entry per page.
    """
    RENDER_SCALE = 3.0
    MAX_WIDTH = 1800  # Maximum width for a chunk in rendered pixels
    MAX_TOKENS = 512  # Kept at 512 to prevent index out of range errors

    all_results = []
    # Open straight from memory: no temporary file that could be left behind
    # when opening or processing raises.
    pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        # Index-based loop so that loading a page happens inside the per-page
        # error handler.
        for page_num in range(len(pdf_document)):
            try:
                page = pdf_document[page_num]
                all_results.append(_extract_pdf_page(
                    page, page_num + 1, split_wide, RENDER_SCALE, MAX_WIDTH, MAX_TOKENS
                ))
            except Exception as e:
                print(f"Error processing page {page_num + 1}: {e}")
                import traceback
                traceback.print_exc()
                all_results.append({
                    "page": page_num + 1,
                    "page_dimensions": {"width": 0, "height": 0},
                    "effective_dimensions": {"width": 0, "height": 0},
                    "rotation": 0,
                    "extractions": [],
                    "error": str(e)
                })
    finally:
        pdf_document.close()

    return {
        "document_type": "pdf",
        "total_pages": len(all_results),
        "pages": all_results
    }
def deduplicate_results(results: List[Dict], tolerance: float = 10.0) -> List[Dict]:
    """Collapse extractions that describe the same box.

    Entries are visited in order. Each not-yet-consumed entry seeds a cluster
    that absorbs every later entry whose box center lies closer than
    ``tolerance`` and whose width and height are within 0.7x-1.3x of the
    seed's. The entry with the longest ``text`` represents the cluster (the
    earliest one on ties).

    Args:
        results: Entries with a ``bbox`` dict (``x``, ``y``, ``width``,
            ``height``) and optionally ``text``.
        tolerance: Maximum center distance, in the boxes' own units, for two
            entries to count as the same item.

    Returns:
        One representative per cluster, in order of each cluster's seed.
    """
    if not results:
        return []

    # Centers do not change while clustering, so compute each exactly once.
    centers = [
        (r['bbox']['x'] + r['bbox']['width'] / 2, r['bbox']['y'] + r['bbox']['height'] / 2)
        for r in results
    ]

    unique_results = []
    consumed = set()
    total = len(results)
    for i, seed in enumerate(results):
        if i in consumed:
            continue
        consumed.add(i)
        seed_bbox = seed['bbox']
        seed_cx, seed_cy = centers[i]
        cluster = [seed]
        # Earlier entries are either consumed or were compared already.
        for j in range(i + 1, total):
            if j in consumed:
                continue
            other_cx, other_cy = centers[j]
            if math.hypot(seed_cx - other_cx, seed_cy - other_cy) < tolerance:
                other_bbox = results[j]['bbox']
                # A zero-sized candidate is treated as matching in that dimension.
                ratio_w = seed_bbox['width'] / other_bbox['width'] if other_bbox['width'] > 0 else 1
                ratio_h = seed_bbox['height'] / other_bbox['height'] if other_bbox['height'] > 0 else 1
                if 0.7 < ratio_w < 1.3 and 0.7 < ratio_h < 1.3:
                    cluster.append(results[j])
                    consumed.add(j)
        unique_results.append(max(cluster, key=lambda r: len(r.get('text', ''))))
    return unique_results
def process_image(image_bytes):
    """Extract text boxes from a single encoded image.

    Images wider than 2000 px are processed as overlapping vertical strips
    whose results are shifted back into full-image coordinates and
    de-duplicated; narrower images are processed in one pass.

    Args:
        image_bytes: Encoded image file contents readable by PIL.

    Returns:
        ``{"document_type": "image", "image_dimensions": {...},
        "extractions": [...]}`` with boxes in pixels of the full image.
    """
    MAX_WIDTH = 2000
    # Same limit as the PDF path, which notes that larger values run into
    # index-out-of-range errors; asking for more only costs a failed pass and
    # a fallback to a smaller limit.
    MAX_TOKENS = 512

    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_width, img_height = image.size
    print(f"Processing single image: {img_width}x{img_height}")

    should_split_decision, _ = should_split_page(img_width, img_height, MAX_WIDTH)
    if should_split_decision:
        print(" Image is wide, splitting into chunks...")
        chunks = split_image_intelligently(image, MAX_WIDTH, overlap_ratio=0.2)
        all_results = []
        for chunk_image, x_offset in chunks:
            chunk_results = process_image_chunk(chunk_image, max_tokens=MAX_TOKENS)
            # Strips span the full height, so only x needs shifting.
            for result in chunk_results:
                result['bbox']['x'] += x_offset
            all_results.extend(chunk_results)
        results = deduplicate_results(all_results)
    else:
        results = process_image_chunk(image, max_tokens=MAX_TOKENS)

    print(f" Total extractions: {len(results)}")
    return {
        "document_type": "image",
        "image_dimensions": {
            "width": img_width,
            "height": img_height
        },
        "extractions": results
    }
if __name__ == "__main__":
    # Direct execution: serve the app on every interface, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)