|
|
from typing import Dict, List, Any |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
import torch |
|
|
|
|
|
class EndpointHandler: |
|
|
def __init__(self, path=""): |
|
|
from doclayout_yolo import YOLOv10 |
|
|
|
|
|
|
|
|
model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt") |
|
|
self.model = YOLOv10(model_path) |
|
|
|
|
|
|
|
|
self.id_to_names = { |
|
|
0: 'title', |
|
|
1: 'plain_text', |
|
|
2: 'abandon', |
|
|
3: 'figure', |
|
|
4: 'figure_caption', |
|
|
5: 'table', |
|
|
6: 'table_caption', |
|
|
7: 'table_footnote', |
|
|
8: 'isolate_formula', |
|
|
9: 'formula_caption' |
|
|
} |
|
|
|
|
|
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Process image and return layout detections. |
|
|
|
|
|
Args: |
|
|
data: Dictionary with: |
|
|
- "inputs": base64 encoded image string or PIL Image |
|
|
- "parameters" (optional): { |
|
|
"confidence": float (default 0.2), |
|
|
"iou_threshold": float (default 0.45) |
|
|
} |
|
|
|
|
|
Returns: |
|
|
List of detections with label, score, and bounding box |
|
|
""" |
|
|
|
|
|
image = data.get("inputs") |
|
|
|
|
|
|
|
|
params = data.get("parameters", {}) |
|
|
conf_threshold = params.get("confidence", 0.2) |
|
|
iou_threshold = params.get("iou_threshold", 0.45) |
|
|
|
|
|
|
|
|
if isinstance(image, str): |
|
|
|
|
|
if "base64," in image: |
|
|
image = image.split("base64,")[1] |
|
|
image = Image.open(io.BytesIO(base64.b64decode(image))) |
|
|
|
|
|
|
|
|
results = self.model.predict( |
|
|
image, |
|
|
imgsz=1024, |
|
|
conf=conf_threshold, |
|
|
iou=iou_threshold, |
|
|
device=self.device |
|
|
)[0] |
|
|
|
|
|
|
|
|
detections = [] |
|
|
boxes = results.boxes |
|
|
|
|
|
for i in range(len(boxes)): |
|
|
box = boxes[i] |
|
|
cls_id = int(box.cls.item()) |
|
|
|
|
|
detections.append({ |
|
|
"label": self.id_to_names.get(cls_id, f"class_{cls_id}"), |
|
|
"score": round(float(box.conf.item()), 4), |
|
|
"box": { |
|
|
"x1": round(float(box.xyxy[0][0].item()), 2), |
|
|
"y1": round(float(box.xyxy[0][1].item()), 2), |
|
|
"x2": round(float(box.xyxy[0][2].item()), 2), |
|
|
"y2": round(float(box.xyxy[0][3].item()), 2) |
|
|
} |
|
|
}) |
|
|
|
|
|
|
|
|
detections.sort(key=lambda x: x["score"], reverse=True) |
|
|
|
|
|
return detections |