File size: 4,535 Bytes
c8bd4a2
 
 
 
 
 
f007bd6
 
c8bd4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f007bd6
 
c8bd4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f007bd6
 
 
 
 
 
c8bd4a2
 
f007bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from typing import Dict, List, Any
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
import torch
from PIL import Image
import io
import base64
import fitz  # PyMuPDF
import tempfile

class EndpointHandler:
    """Inference-endpoint handler that tokenizes documents with LayoutLMv3.

    A request carries a base64-encoded image or PDF. Each image (or each
    rendered PDF page) is passed through the LayoutLMv3 processor with OCR
    enabled, and the resulting tokens are returned with their bounding boxes.
    """

    # BERT-style markers this handler has always filtered out; kept in
    # addition to the tokenizer's own special tokens for backward
    # compatibility.
    _LEGACY_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    # Zoom factor applied when rasterizing PDF pages (2x for better OCR quality).
    _PDF_RENDER_SCALE = 2.0

    def __init__(self, path=""):
        """Load the processor and model and move the model to the best device.

        Args:
            path: Accepted for interface compatibility with the endpoint
                runtime; the weights are always loaded from
                ``microsoft/layoutlmv3-base`` and this value is not used.
        """
        # Load from Microsoft's repo
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base",
            apply_ocr=True
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base"
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def process_image(self, image) -> List[Dict[str, Any]]:
        """Process a single image and return its token extractions.

        Args:
            image: The page image handed to the processor (the callers in
                this class pass an RGB PIL image).

        Returns:
            One dict per non-special token, in sequence order, of the form
            ``{"text": token, "bbox": {"x", "y", "width", "height"}}``.
            ``text`` is the raw tokenizer token (tokens are not merged back
            into words). Box values are exactly what the processor emits.
            NOTE(review): the processor is expected to normalize boxes to a
            0-1000 grid rather than pixels - confirm against the
            transformers documentation before mapping them onto a page.
        """
        encoding = self.processor(
            image,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )

        # Keep tensor entries only and move them next to the model.
        encoding = {
            key: value.to(self.device)
            for key, value in encoding.items()
            if isinstance(value, torch.Tensor)
        }

        # NOTE(review): the forward pass is kept so behavior matches the
        # previous revision, but its output is not part of the response.
        with torch.no_grad():
            self.model(**encoding)

        tokenizer = self.processor.tokenizer
        tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
        boxes = encoding["bbox"][0].tolist()

        # Filter on the tokenizer's actual special tokens; a hard-coded
        # BERT-style list alone does not match tokenizers that use other
        # markers, which would leak padding entries into the output.
        special_tokens = self._LEGACY_SPECIAL_TOKENS | set(tokenizer.all_special_tokens)

        return [
            {
                "text": token,
                "bbox": {
                    "x": box[0],
                    "y": box[1],
                    "width": box[2] - box[0],
                    "height": box[3] - box[1],
                },
            }
            for token, box in zip(tokens, boxes)
            if token not in special_tokens
        ]

    @staticmethod
    def _extract_payload(data: Dict[str, Any]) -> Any:
        """Return the raw (still encoded) file payload from a request dict."""
        # Read without popping so the caller's dict is left untouched.
        inputs = data.get("inputs", data)

        if not isinstance(inputs, dict):
            return inputs
        if "pdf" in inputs:
            return inputs["pdf"]
        return inputs.get("image", inputs.get("inputs", ""))

    @staticmethod
    def _decode_payload(file_data: Any) -> bytes:
        """Strip an optional data-URL prefix and base64-decode the payload.

        Raises:
            ValueError: If the payload is missing or decodes to no bytes.
        """
        # Remove base64 prefix if present (e.g. "data:image/png;base64,")
        if isinstance(file_data, str) and "base64," in file_data:
            file_data = file_data.split("base64,", 1)[1]

        if not file_data:
            raise ValueError(
                "No input data provided: expected a base64-encoded image or PDF"
            )

        file_bytes = base64.b64decode(file_data)
        if not file_bytes:
            raise ValueError("Input data decoded to an empty payload")
        return file_bytes

    def _process_pdf(self, file_bytes: bytes) -> List[Dict[str, Any]]:
        """Render every page of an in-memory PDF and extract its tokens."""
        # Open straight from memory: no temporary file is written, so there
        # is nothing left behind on disk.
        pdf_document = fitz.open(stream=file_bytes, filetype="pdf")
        try:
            zoom = fitz.Matrix(self._PDF_RENDER_SCALE, self._PDF_RENDER_SCALE)
            pages = []
            for page_number, page in enumerate(pdf_document, start=1):
                # Convert page to image (PIL format)
                pixmap = page.get_pixmap(matrix=zoom)
                image = Image.open(io.BytesIO(pixmap.tobytes("png"))).convert("RGB")

                pages.append({
                    "page": page_number,
                    # Taken from page.rect, i.e. the unscaled page size, not
                    # the size of the rendered image.
                    "page_width": page.rect.width,
                    "page_height": page.rect.height,
                    "extractions": self.process_image(image),
                })
            return pages
        finally:
            # Close the document even when a page fails to render/process.
            pdf_document.close()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request dict. The payload is looked up under ``"inputs"``
                (falling back to ``data`` itself); it is either the encoded
                file directly, or a dict carrying it under ``"pdf"``,
                ``"image"`` or ``"inputs"``. The file is base64-encoded and
                may carry a ``...base64,`` data-URL prefix.

        Returns:
            For PDFs: ``{"document_type": "pdf", "total_pages": int,
            "pages": [...]}`` with one entry per page. For anything else:
            ``{"document_type": "image", "extractions": [...]}``.

        Raises:
            ValueError: If no payload is present or it decodes to nothing.
        """
        file_bytes = self._decode_payload(self._extract_payload(data))

        # Check if it's a PDF or image
        if file_bytes.startswith(b'%PDF'):
            pages = self._process_pdf(file_bytes)
            return {
                "document_type": "pdf",
                "total_pages": len(pages),
                "pages": pages
            }

        # Process as image
        image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        return {
            "document_type": "image",
            "extractions": self.process_image(image)
        }