File size: 3,003 Bytes
7f44b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from typing import Dict, List, Any
from PIL import Image
import base64
import io
import os
import torch

class EndpointHandler:
    """Inference Endpoint handler that runs DocLayout-YOLO layout detection.

    The model is loaded once from ``doclayout_yolo_docstructbench_imgsz1024.pt``
    in the repository directory. Each call decodes one image, runs prediction,
    and returns the detected regions as JSON-serialisable dicts sorted by
    descending confidence.
    """

    # Thresholds used when a request omits a parameter or sends it as null.
    DEFAULT_CONFIDENCE = 0.2
    DEFAULT_IOU_THRESHOLD = 0.45
    # Inference image size; matches the "imgsz1024" in the checkpoint file name.
    IMAGE_SIZE = 1024

    def __init__(self, path=""):
        """Load the detection model and the class-id -> label mapping.

        Args:
            path: Directory that contains the model checkpoint (the repository
                root when running inside an endpoint).
        """
        # Imported here rather than at module level; presumably to defer the
        # dependency until the handler is actually constructed.
        from doclayout_yolo import YOLOv10

        # Load model from repo path
        model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt")
        self.model = YOLOv10(model_path)

        # Class index -> human-readable layout label.
        self.id_to_names = {
            0: 'title',
            1: 'plain_text',
            2: 'abandon',
            3: 'figure',
            4: 'figure_caption',
            5: 'table',
            6: 'table_caption',
            7: 'table_footnote',
            8: 'isolate_formula',
            9: 'formula_caption'
        }

        # Passed to every predict() call; GPU when one is visible to torch.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process image and return layout detections.

        Args:
            data: Dictionary with:
                - "inputs": base64 encoded image string (optionally a data URL
                  such as "data:image/png;base64,..."), raw image bytes, or a
                  PIL Image
                - "parameters" (optional, may be null): {
                    "confidence": float (default 0.2),
                    "iou_threshold": float (default 0.45)
                  }
                  Numeric strings are accepted; null values use the default.

        Returns:
            List of detections, each {"label", "score", "box": {x1, y1, x2, y2}},
            sorted by score (highest first). Scores are rounded to 4 decimals,
            coordinates to 2.

        Raises:
            ValueError: if "inputs" is missing or null, or a threshold is not
                convertible to float.
        """
        image = self._load_image(data.get("inputs"))

        # "parameters" may be absent or explicitly null in the JSON payload.
        params = data.get("parameters") or {}
        conf_threshold = self._as_float(params.get("confidence"), self.DEFAULT_CONFIDENCE)
        iou_threshold = self._as_float(params.get("iou_threshold"), self.DEFAULT_IOU_THRESHOLD)

        # Run inference; a single image yields a single Results object.
        results = self.model.predict(
            image,
            imgsz=self.IMAGE_SIZE,
            conf=conf_threshold,
            iou=iou_threshold,
            device=self.device
        )[0]

        detections = self._format_detections(results.boxes)

        # Sort by confidence score (stable, highest first).
        detections.sort(key=lambda det: det["score"], reverse=True)

        return detections

    @staticmethod
    def _load_image(image: Any) -> Any:
        """Convert the request's "inputs" value into an object predict() accepts.

        Strings are treated as base64 (an optional "...base64," data-URL prefix
        is stripped); bytes are opened with PIL; anything else is passed through.
        """
        if image is None:
            # NOTE(review): Ultralytics-derived predictors fall back to bundled
            # sample images when source is None, so fail loudly instead.
            raise ValueError("Request is missing the 'inputs' image.")
        if isinstance(image, str):
            # Remove data URL prefix if present.
            if "base64," in image:
                image = image.split("base64,", 1)[1]
            image = base64.b64decode(image)
        if isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        return image

    @staticmethod
    def _as_float(value: Any, default: float) -> float:
        """Coerce a request parameter to float, using ``default`` when it is None."""
        return default if value is None else float(value)

    def _format_detections(self, boxes: Any) -> List[Dict[str, Any]]:
        """Convert a Boxes container into a list of plain detection dicts."""
        if boxes is None or len(boxes) == 0:
            return []

        # Pull each attribute out in one transfer instead of one .item() per value.
        coords = boxes.xyxy.tolist()
        scores = boxes.conf.tolist()
        class_ids = boxes.cls.tolist()

        detections = []
        for (x1, y1, x2, y2), score, cls in zip(coords, scores, class_ids):
            cls_id = int(cls)
            detections.append({
                "label": self.id_to_names.get(cls_id, f"class_{cls_id}"),
                "score": round(float(score), 4),
                "box": {
                    "x1": round(float(x1), 2),
                    "y1": round(float(y1), 2),
                    "x2": round(float(x2), 2),
                    "y2": round(float(y2), 2)
                }
            })
        return detections