| | import io |
| | from typing import Dict, List, Any |
| | from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3FeatureExtractor, LayoutLMv3Tokenizer, LayoutLMv3Processor |
| | import torch |
| | from subprocess import run |
| | from PIL import Image |
| |
|
| | |
# One-time environment setup for the inference container:
#   - tesseract-ocr: system binary the LayoutLMv3 feature extractor shells out
#     to for OCR (apply_ocr defaults to True).
#   - detectron2 (CPU wheel): visual-backbone dependency pulled from the
#     official pre-built wheel index.
# Use argument lists with the default shell=False instead of shell=True
# command strings: no shell features (pipes, globs, expansion) are used, and
# list form avoids shell-injection and quoting pitfalls.
run(["apt", "install", "-y", "tesseract-ocr"], check=True)
run(
    [
        "python", "-m", "pip", "install", "detectron2",
        "-f", "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html",
    ],
    check=True,
)
| |
|
| | |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| |
|
class EndpointHandler:
    """Inference-endpoint wrapper around a LayoutLMv3 document classifier.

    Builds an OCR-backed LayoutLMv3 processor and loads a fine-tuned
    sequence-classification checkpoint; calling the handler with raw image
    bytes returns the predicted document label.
    """

    def __init__(self, path=""):
        # Assemble the processor from its two parts (feature extractor runs
        # Tesseract OCR; tokenizer handles the text side), then move the
        # fine-tuned classifier onto the module-level device.
        extractor = LayoutLMv3FeatureExtractor()
        tok = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
        self.FEATURE_EXTRACTOR = extractor
        self.TOKENIZER = tok
        self.PROCESSOR = LayoutLMv3Processor(extractor, tok)
        self.MODEL = LayoutLMv3ForSequenceClassification.from_pretrained(
            "OtraBoi/document_classifier_testing"
        ).to(device)

    def __call__(self, data: Dict):
        """Classify one document image.

        Expects raw image bytes under data["inputs"]; returns the predicted
        label string from the model's id2label mapping.
        """
        image = Image.open(io.BytesIO(data["inputs"])).convert("RGB")

        # OCR + tokenize, then move every tensor onto the model's device
        # before the forward pass.
        batch = self.PROCESSOR(
            image, return_tensors="pt", padding="max_length", truncation=True
        )
        batch = {name: tensor.to(self.MODEL.device) for name, tensor in batch.items()}

        with torch.inference_mode():
            logits = self.MODEL(**batch).logits

        # argmax over the label dimension -> human-readable class name
        best = logits.argmax(-1).item()
        return self.MODEL.config.id2label[best]
| |
|