import base64
import io
import tempfile
from typing import Any, Dict, List

import fitz  # PyMuPDF
import torch
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor


class EndpointHandler():
    """Inference handler that extracts tokens and bounding boxes from documents.

    Accepts a base64-encoded image or PDF (optionally as a ``data:`` URI),
    runs it through the LayoutLMv3 processor (with built-in OCR) and returns
    every non-special token together with its bounding box.
    """

    # Hub repository the processor and model weights are loaded from.
    MODEL_ID = "microsoft/layoutlmv3-base"
    # Zoom factor used when rasterising PDF pages (2x for better OCR quality).
    PDF_RENDER_SCALE = 2.0
    # BERT-style names the original filter used; kept so that swapping in a
    # BERT-vocabulary tokenizer still filters correctly.
    _LEGACY_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    def __init__(self, path=""):
        """Load the processor and model and move the model to GPU if available.

        Args:
            path: Accepted for interface compatibility; it is not used —
                weights are always loaded from ``MODEL_ID``.
        """
        # Load from Microsoft's repo
        self.processor = LayoutLMv3Processor.from_pretrained(
            self.MODEL_ID,
            apply_ocr=True,
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            self.MODEL_ID
        )
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def process_image(self, image):
        """Process a single image and return extractions.

        Args:
            image: The page image handed to the processor (callers in this
                class pass an RGB ``PIL.Image``).

        Returns:
            A list with one dict per non-special token:
            ``{"text": <token>, "bbox": {"x", "y", "width", "height"}}``.
            ``text`` is the raw sub-word token string from the tokenizer and
            the box values are in whatever coordinate space the processor
            emits (presumably LayoutLMv3's 0-1000 normalised scale — confirm
            against the processor documentation).

        Note:
            The encoding is truncated/padded to 512 tokens, so text beyond
            that limit on a dense page is not returned.
        """
        encoding = self.processor(
            image,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        encoding = {
            key: value.to(self.device)
            for key, value in encoding.items()
            if isinstance(value, torch.Tensor)
        }

        # NOTE(review): the model's logits are not part of the response yet;
        # the forward pass is kept so runtime behaviour matches the original.
        with torch.no_grad():
            self.model(**encoding)

        tokenizer = self.processor.tokenizer
        # LayoutLMv3 uses a RoBERTa-style vocabulary (<s>, </s>, <pad>), so
        # filtering on the BERT names alone would let every padding token
        # through. Ask the tokenizer which tokens are special instead.
        special_tokens = set(tokenizer.all_special_tokens) | self._LEGACY_SPECIAL_TOKENS

        tokens = tokenizer.convert_ids_to_tokens(
            encoding["input_ids"][0].cpu().tolist()
        )
        boxes = encoding["bbox"][0].cpu().tolist()

        return [
            {
                "text": token,
                "bbox": {
                    "x": box[0],
                    "y": box[1],
                    "width": box[2] - box[0],
                    "height": box[3] - box[1],
                },
            }
            for token, box in zip(tokens, boxes)
            if token not in special_tokens
        ]

    def _process_pdf(self, file_bytes: bytes) -> Dict[str, Any]:
        """Rasterise every page of an in-memory PDF and extract its tokens."""
        all_results: List[Dict[str, Any]] = []
        zoom = fitz.Matrix(self.PDF_RENDER_SCALE, self.PDF_RENDER_SCALE)

        # Open straight from memory: no temporary file is left on disk, and
        # the context manager closes the document even if a page fails.
        with fitz.open(stream=file_bytes, filetype="pdf") as pdf_document:
            for page_number, page in enumerate(pdf_document, start=1):
                # Convert page to image (PIL format)
                pix = page.get_pixmap(matrix=zoom)
                image = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")

                all_results.append({
                    "page": page_number,
                    "page_width": page.rect.width,
                    "page_height": page.rect.height,
                    "extractions": self.process_image(image),
                })

        return {
            "document_type": "pdf",
            "total_pages": len(all_results),
            "pages": all_results,
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. The document is looked up under
                ``data["inputs"]`` (falling back to ``data`` itself). That
                value may be a base64 string (optionally prefixed with a
                ``data:...;base64,`` header), an already-decoded
                ``PIL.Image``, or a dict holding the string under ``"pdf"``,
                ``"image"`` or ``"inputs"``.

        Returns:
            For images: ``{"document_type": "image", "extractions": [...]}``.
            For PDFs: ``{"document_type": "pdf", "total_pages": n,
            "pages": [{"page", "page_width", "page_height",
            "extractions"}, ...]}``.

        Raises:
            ValueError: If the payload contains no document data.
        """
        # ``get`` rather than ``pop`` so the caller's dict is left untouched.
        inputs = data.get("inputs", data)

        # Handle different input formats
        if isinstance(inputs, dict):
            if "pdf" in inputs:
                file_data = inputs["pdf"]
            else:
                file_data = inputs.get("image", inputs.get("inputs", ""))
        else:
            file_data = inputs

        # An already-decoded image needs no base64 handling.
        if isinstance(file_data, Image.Image):
            return {
                "document_type": "image",
                "extractions": self.process_image(file_data.convert("RGB")),
            }

        # Remove base64 prefix if present
        if isinstance(file_data, str) and "base64," in file_data:
            file_data = file_data.split("base64,", 1)[1]

        file_bytes = base64.b64decode(file_data)
        if not file_bytes:
            raise ValueError(
                "No document data found in request; expected a base64-encoded "
                "image or PDF under 'inputs', 'image' or 'pdf'."
            )

        # PDFs are recognised by their magic number; anything else is
        # treated as an image.
        if file_bytes.startswith(b'%PDF'):
            return self._process_pdf(file_bytes)

        image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        return {
            "document_type": "image",
            "extractions": self.process_image(image),
        }