Alfonso Velasco committed on
Commit
c8bd4a2
·
0 Parent(s):

Add custom handler for LayoutLMv3 inference

Browse files
Files changed (2) hide show
  1. handler.py +60 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
3
+ import torch
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+
8
class EndpointHandler():
    """Custom inference handler that runs LayoutLMv3 (with built-in OCR) on an image.

    The handler loads the processor and token-classification model once at
    construction time and, per request, returns the OCR'd tokens together
    with their bounding boxes.
    """

    def __init__(self, path=""):
        """Load the processor and model and move the model to the best device.

        Args:
            path: Repository path supplied by the serving runtime. It is
                accepted for interface compatibility but not used: weights
                are always pulled from Microsoft's hub repo.
        """
        # Load from Microsoft's repo
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base",
            apply_ocr=True
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base"
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _load_image(self, image_data):
        """Turn the request payload into an RGB PIL image.

        Accepts an already-decoded PIL image, raw image bytes, a plain
        base64 string, or a ``data:...;base64,...`` URI.

        Raises:
            ValueError: If no image payload was supplied, or the payload is
                of an unsupported type.
        """
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        if isinstance(image_data, (bytes, bytearray)):
            if not image_data:
                raise ValueError("No image data provided")
            return Image.open(io.BytesIO(bytes(image_data))).convert("RGB")

        if not isinstance(image_data, str):
            raise ValueError(
                f"Unsupported image payload type: {type(image_data).__name__}"
            )
        if not image_data:
            raise ValueError("No image data provided")

        # Strip a data-URI header such as "data:image/png;base64,".
        if "base64," in image_data:
            image_data = image_data.split("base64,", 1)[1]

        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run OCR + LayoutLMv3 on the supplied image.

        Args:
            data: Request payload. The image is looked up under ``"inputs"``
                (falling back to the payload itself); if that value is a
                dict, the image is read from its ``"image"`` or ``"inputs"``
                key.

        Returns:
            ``{"extractions": [...]}`` where each entry has the token
            ``"text"`` and a ``"bbox"`` dict with ``x``, ``y``, ``width`` and
            ``height`` derived from the processor's box coordinates
            (presumably on LayoutLM's normalized 0-1000 scale - confirm).

        Raises:
            ValueError: If the image payload is missing or of an
                unsupported type.
        """
        inputs = data.pop("inputs", data)

        if isinstance(inputs, dict):
            image_data = inputs.get("image", inputs.get("inputs", ""))
        else:
            image_data = inputs

        image = self._load_image(image_data)

        encoding = self.processor(
            image,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )

        encoding = {
            k: v.to(self.device)
            for k, v in encoding.items()
            if isinstance(v, torch.Tensor)
        }

        # NOTE(review): the forward pass is kept from the original handler,
        # but its logits are not part of the response yet.
        with torch.no_grad():
            self.model(**encoding)

        tokenizer = self.processor.tokenizer
        tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu().tolist())
        boxes = encoding["bbox"][0].cpu().tolist()

        # Ask the tokenizer for its own special tokens: LayoutLMv3 uses
        # RoBERTa-style markers (<s>, </s>, <pad>), so hard-coded BERT
        # markers would never match and padding would leak into the output.
        special_tokens = {
            tokenizer.cls_token,
            tokenizer.sep_token,
            tokenizer.pad_token,
        }

        results = [
            {
                "text": token,
                "bbox": {
                    "x": box[0],
                    "y": box[1],
                    "width": box[2] - box[0],
                    "height": box[3] - box[1],
                },
            }
            for token, box in zip(tokens, boxes)
            if token not in special_tokens
        ]

        return {"extractions": results}
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers>=4.35.0
2
+ torch>=2.0.0
3
+ pillow>=9.0.0
4
+ pytesseract>=0.3.10