|
|
from typing import Dict, List, Any |
|
|
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification |
|
|
import torch |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import fitz |
|
|
import tempfile |
|
|
|
|
|
class EndpointHandler():
    """Custom inference handler that runs LayoutLMv3 OCR over images and PDFs.

    The handler accepts a base64-encoded image or PDF, runs the
    ``LayoutLMv3Processor`` (with its built-in OCR) on every page and
    returns the recognised tokens together with their bounding boxes.
    """

    # BERT-style markers the original filter looked for. The LayoutLMv3
    # tokenizer actually reports its own special tokens (queried at runtime
    # in ``process_image``); these literals are kept so that anything the
    # previous filter removed is still removed.
    _LEGACY_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    # Leading bytes that identify a PDF payload.
    _PDF_MAGIC = b"%PDF"

    # Zoom factor applied on both axes when rasterising PDF pages for OCR.
    _PDF_RENDER_ZOOM = 2.0

    def __init__(self, path=""):
        """Load the processor and model and move the model to the device.

        Args:
            path: Accepted for interface compatibility; the checkpoint
                name is hard-coded, so this argument is not used.
        """
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base",
            apply_ocr=True,
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base"
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def process_image(self, image) -> List[Dict[str, Any]]:
        """Process a single image and return its token extractions.

        Args:
            image: The page image handed to the processor (the callers in
                this class pass an RGB ``PIL.Image``).

        Returns:
            One dict per non-special token:
            ``{"text": token, "bbox": {"x", "y", "width", "height"}}``.
            The box is derived from the processor's ``bbox`` output, read
            as ``(x0, y0, x1, y1)``.
            NOTE(review): those coordinates are presumably on LayoutLMv3's
            normalised 0-1000 scale rather than in pixels -- confirm
            against the processor documentation before mixing them with
            the page dimensions.
            Input is truncated to 512 tokens, so text beyond that limit is
            not returned.
        """
        encoding = self.processor(
            image,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        # Keep tensor entries only and place them on the model's device.
        encoding = {
            key: value.to(self.device)
            for key, value in encoding.items()
            if isinstance(value, torch.Tensor)
        }

        # NOTE(review): the forward pass is retained from the original
        # implementation, but its output is not used to build the result.
        # Either consume the logits or drop the call to save compute.
        with torch.no_grad():
            self.model(**encoding)

        tokenizer = self.processor.tokenizer
        tokens = tokenizer.convert_ids_to_tokens(
            encoding["input_ids"][0].cpu().tolist()
        )
        boxes = encoding["bbox"][0].cpu().tolist()

        # Ask the tokenizer which tokens are special instead of assuming
        # BERT-style names; otherwise padding to max_length floods the
        # result with pad entries.
        special_tokens = self._LEGACY_SPECIAL_TOKENS.union(
            tokenizer.all_special_tokens
        )

        return [
            {
                "text": token,
                "bbox": {
                    "x": box[0],
                    "y": box[1],
                    "width": box[2] - box[0],
                    "height": box[3] - box[1],
                },
            }
            for token, box in zip(tokens, boxes)
            if token not in special_tokens
        ]

    @staticmethod
    def _extract_payload(data):
        """Return the raw document payload from the request body."""
        # .get() rather than .pop(): do not mutate the caller's request dict.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return inputs
        if "pdf" in inputs:
            return inputs["pdf"]
        return inputs.get("image", inputs.get("inputs", ""))

    @staticmethod
    def _decode_payload(file_data) -> bytes:
        """Strip an optional data-URL prefix and base64-decode the payload."""
        if isinstance(file_data, str):
            if "base64," in file_data:
                # e.g. "data:application/pdf;base64,<payload>"
                file_data = file_data.split("base64,", 1)[1]
        elif not isinstance(file_data, (bytes, bytearray)):
            raise TypeError(
                "Expected a base64-encoded string or bytes payload, got "
                f"{type(file_data).__name__}"
            )
        if not file_data:
            raise ValueError(
                "No document provided: expected base64 data under "
                "'inputs', 'image' or 'pdf'"
            )
        # Malformed base64 surfaces as binascii.Error, as before.
        return base64.b64decode(file_data)

    def _process_pdf(self, file_bytes: bytes) -> Dict[str, Any]:
        """Rasterise every page of an in-memory PDF and run OCR on it."""
        pages = []
        zoom = fitz.Matrix(self._PDF_RENDER_ZOOM, self._PDF_RENDER_ZOOM)
        # Open straight from memory: no temporary file is left behind, and
        # the context manager closes the document even if a page fails.
        with fitz.open(stream=file_bytes, filetype="pdf") as pdf_document:
            for page_num, page in enumerate(pdf_document, start=1):
                pix = page.get_pixmap(matrix=zoom)
                image = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
                pages.append({
                    "page": page_num,
                    # Dimensions of the PDF page rectangle, not of the
                    # zoomed raster that was sent to OCR.
                    "page_width": page.rect.width,
                    "page_height": page.rect.height,
                    "extractions": self.process_image(image),
                })
        return {
            "document_type": "pdf",
            "total_pages": len(pages),
            "pages": pages,
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request containing a base64-encoded image or PDF.

        Args:
            data: Request body. The payload is looked up under ``"inputs"``
                (falling back to ``data`` itself); when that value is a
                dict, its ``"pdf"``, ``"image"`` or nested ``"inputs"``
                entry is used. The payload may carry a ``...base64,``
                data-URL prefix.

        Returns:
            For PDFs: ``{"document_type": "pdf", "total_pages": int,
            "pages": [...]}`` with one entry per page. Otherwise:
            ``{"document_type": "image", "extractions": [...]}``.

        Raises:
            ValueError: If the payload is missing or empty (malformed
                base64 raises ``binascii.Error``, a ``ValueError``
                subclass).
            TypeError: If the payload is neither a string nor bytes.
        """
        file_bytes = self._decode_payload(self._extract_payload(data))

        if file_bytes.startswith(self._PDF_MAGIC):
            return self._process_pdf(file_bytes)

        image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        return {
            "document_type": "image",
            "extractions": self.process_image(image),
        }