layoutlmv3-custom / handler.py
Alfonso Velasco
pdf
f007bd6
from typing import Dict, List, Any
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
import torch
from PIL import Image
import io
import base64
import fitz # PyMuPDF
import tempfile
class EndpointHandler:
    """Request handler running LayoutLMv3 (with processor-side OCR) on a
    base64-encoded image or PDF and returning per-token text + bounding boxes.
    """

    # Zoom applied when rasterising PDF pages (2x for better OCR quality).
    PDF_ZOOM = 2.0

    def __init__(self, path=""):
        """Load the LayoutLMv3 processor and token-classification model.

        Args:
            path: Model directory passed in by the serving toolkit. Currently
                unused: weights are always loaded from
                "microsoft/layoutlmv3-base".
        """
        # Load from Microsoft's repo. apply_ocr=True makes the processor
        # extract words and boxes from the image itself.
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base",
            apply_ocr=True,
        )
        # NOTE(review): the base checkpoint has no fine-tuned token-classification
        # head, so its logits would be untrained; they are not used below.
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base"
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def process_image(self, image):
        """Run OCR + LayoutLMv3 on one image and return its non-special tokens.

        Args:
            image: An RGB ``PIL.Image.Image``.

        Returns:
            A list of ``{"text": token, "bbox": {"x", "y", "width", "height"}}``
            dicts, one per sub-word token. Boxes are in the processor's
            normalised coordinate space (0-1000 per the transformers docs), not
            pixels. Tokens past 512 (including special tokens) are truncated.
        """
        encoding = self.processor(
            image,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        encoding = {
            k: v.to(self.device)
            for k, v in encoding.items()
            if isinstance(v, torch.Tensor)
        }
        with torch.no_grad():
            # NOTE(review): the model output is currently discarded.
            self.model(**encoding)

        tokenizer = self.processor.tokenizer
        input_ids = encoding["input_ids"][0].cpu().tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        boxes = encoding["bbox"][0].cpu().tolist()
        # Filter by id, not by hard-coded strings: LayoutLMv3 uses RoBERTa-style
        # specials (<s>, </s>, <pad>), so '[CLS]'/'[SEP]'/'[PAD]' never matched
        # and all padding tokens leaked into the results.
        return self._tokens_to_extractions(
            tokens, input_ids, boxes, set(tokenizer.all_special_ids)
        )

    @staticmethod
    def _tokens_to_extractions(tokens, token_ids, boxes, special_ids):
        """Pair tokens with their boxes, dropping tokens whose id is special.

        Each box is ``[x0, y0, x1, y1]`` and is converted to x/y/width/height.
        """
        return [
            {
                "text": token,
                "bbox": {
                    "x": box[0],
                    "y": box[1],
                    "width": box[2] - box[0],
                    "height": box[3] - box[1],
                },
            }
            for token, token_id, box in zip(tokens, token_ids, boxes)
            if token_id not in special_ids
        ]

    @staticmethod
    def _decode_payload(data):
        """Locate the base64 document in a request payload and decode it.

        The document is read from ``data["inputs"]`` (or from ``data`` itself
        when that key is absent). If that value is a dict, the ``"pdf"`` key
        wins, then ``"image"``, then ``"inputs"``; otherwise the value is used
        directly. A data-URL prefix ending in ``"base64,"`` is stripped.
        ``data`` is not modified.

        Returns:
            The decoded bytes.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            if "pdf" in inputs:
                file_data = inputs["pdf"]
            else:
                file_data = inputs.get("image", inputs.get("inputs", ""))
        else:
            file_data = inputs
        # Remove base64 prefix if present (e.g. "data:application/pdf;base64,").
        if isinstance(file_data, str) and "base64," in file_data:
            file_data = file_data.partition("base64,")[2]
        return base64.b64decode(file_data)

    def _process_pdf(self, file_bytes):
        """Rasterise every PDF page and run ``process_image`` on each.

        ``page_width``/``page_height`` are the unscaled ``page.rect`` size,
        not the size of the zoomed image that was processed.
        """
        zoom = fitz.Matrix(self.PDF_ZOOM, self.PDF_ZOOM)
        pages = []
        # Open straight from memory: no temp file is left behind, and the
        # context manager closes the document even if a page fails.
        with fitz.open(stream=file_bytes, filetype="pdf") as pdf_document:
            for page_num, page in enumerate(pdf_document, start=1):
                pix = page.get_pixmap(matrix=zoom)
                image = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
                pages.append({
                    "page": page_num,
                    "page_width": page.rect.width,
                    "page_height": page.rect.height,
                    "extractions": self.process_image(image),
                })
        return {
            "document_type": "pdf",
            "total_pages": len(pages),
            "pages": pages,
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request containing a base64 PDF or image.

        Returns:
            For PDFs (bytes starting with ``%PDF``): ``{"document_type": "pdf",
            "total_pages": n, "pages": [...]}``. Otherwise the bytes are opened
            as an image: ``{"document_type": "image", "extractions": [...]}``.
        """
        file_bytes = self._decode_payload(data)
        if file_bytes.startswith(b"%PDF"):
            return self._process_pdf(file_bytes)
        image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        return {
            "document_type": "image",
            "extractions": self.process_image(image),
        }