File size: 3,003 Bytes
7f44b94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from typing import Dict, List, Any
from PIL import Image
import base64
import io
import os
import torch
class EndpointHandler:
def __init__(self, path=""):
from doclayout_yolo import YOLOv10
# Load model from repo path
model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt")
self.model = YOLOv10(model_path)
# Label mapping
self.id_to_names = {
0: 'title',
1: 'plain_text',
2: 'abandon',
3: 'figure',
4: 'figure_caption',
5: 'table',
6: 'table_caption',
7: 'table_footnote',
8: 'isolate_formula',
9: 'formula_caption'
}
# Set device
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Process image and return layout detections.
Args:
data: Dictionary with:
- "inputs": base64 encoded image string or PIL Image
- "parameters" (optional): {
"confidence": float (default 0.2),
"iou_threshold": float (default 0.45)
}
Returns:
List of detections with label, score, and bounding box
"""
# Get image from request
image = data.get("inputs")
# Get optional parameters
params = data.get("parameters", {})
conf_threshold = params.get("confidence", 0.2)
iou_threshold = params.get("iou_threshold", 0.45)
# Handle base64 encoded image
if isinstance(image, str):
# Remove data URL prefix if present
if "base64," in image:
image = image.split("base64,")[1]
image = Image.open(io.BytesIO(base64.b64decode(image)))
# Run inference
results = self.model.predict(
image,
imgsz=1024,
conf=conf_threshold,
iou=iou_threshold,
device=self.device
)[0]
# Format output
detections = []
boxes = results.boxes
for i in range(len(boxes)):
box = boxes[i]
cls_id = int(box.cls.item())
detections.append({
"label": self.id_to_names.get(cls_id, f"class_{cls_id}"),
"score": round(float(box.conf.item()), 4),
"box": {
"x1": round(float(box.xyxy[0][0].item()), 2),
"y1": round(float(box.xyxy[0][1].item()), 2),
"x2": round(float(box.xyxy[0][2].item()), 2),
"y2": round(float(box.xyxy[0][3].item()), 2)
}
})
# Sort by confidence score
detections.sort(key=lambda x: x["score"], reverse=True)
return detections |