from typing import Dict, List, Any
from PIL import Image
import base64
import io
import os
import torch


class EndpointHandler:
    """Callable request handler that runs a DocLayout-YOLO model on one image.

    The handler loads the model weights once at construction time and, on each
    call, returns the detected layout regions as plain dictionaries.
    """

    def __init__(self, path=""):
        """Load the model weights and pick the inference device.

        Args:
            path: Directory containing
                ``doclayout_yolo_docstructbench_imgsz1024.pt``.
        """
        # Imported lazily so the dependency is only required when a handler is
        # actually constructed.
        from doclayout_yolo import YOLOv10

        # Load model from repo path
        model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt")
        self.model = YOLOv10(model_path)

        # Class id -> human-readable label; ids missing from this table are
        # reported as "class_<id>" by __call__.
        self.id_to_names = {
            0: 'title',
            1: 'plain_text',
            2: 'abandon',
            3: 'figure',
            4: 'figure_caption',
            5: 'table',
            6: 'table_caption',
            7: 'table_footnote',
            8: 'isolate_formula',
            9: 'formula_caption',
        }

        # Use the GPU when one is available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process an image and return layout detections.

        Args:
            data: Dictionary with:
                - "inputs": base64 encoded image string (optionally carrying a
                  ``data:...;base64,`` prefix), the raw bytes of an encoded
                  image file, or a PIL Image. Any other object is handed to
                  the model unchanged.
                - "parameters" (optional): {
                    "confidence": float (default 0.2),
                    "iou_threshold": float (default 0.45)
                  }

        Returns:
            List of detections, sorted by descending score. Each detection is
            ``{"label": str, "score": float, "box": {"x1", "y1", "x2", "y2"}}``
            with the score rounded to 4 decimals and coordinates to 2.

        Raises:
            ValueError: If "inputs" is missing or ``None``.
        """
        image = data.get("inputs")
        if image is None:
            # Fail loudly instead of passing None on to the model.
            raise ValueError('Request is missing the "inputs" field')

        # `or {}` also covers an explicit `"parameters": null` in the request,
        # which `.get("parameters", {})` would return as None.
        params = data.get("parameters") or {}
        conf_threshold = params.get("confidence", 0.2)
        iou_threshold = params.get("iou_threshold", 0.45)

        # Handle base64 encoded image
        if isinstance(image, str):
            # Remove data URL prefix if present
            if "base64," in image:
                image = image.split("base64,", 1)[1]
            image = base64.b64decode(image)

        # Decoded base64 payloads and raw byte payloads share one decode path.
        if isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))

        # Run inference; a single input yields a single result.
        results = self.model.predict(
            image,
            imgsz=1024,
            conf=conf_threshold,
            iou=iou_threshold,
            device=self.device,
        )[0]

        # Convert each tensor to Python lists once rather than reading every
        # scalar with a separate `.item()` call.
        boxes = results.boxes
        corners = boxes.xyxy.tolist()
        scores = boxes.conf.tolist()
        class_ids = boxes.cls.tolist()

        detections = []
        for (x1, y1, x2, y2), score, raw_cls in zip(corners, scores, class_ids):
            cls_id = int(raw_cls)
            detections.append({
                "label": self.id_to_names.get(cls_id, f"class_{cls_id}"),
                "score": round(float(score), 4),
                "box": {
                    "x1": round(float(x1), 2),
                    "y1": round(float(y1), 2),
                    "x2": round(float(x2), 2),
                    "y2": round(float(y2), 2),
                },
            })

        # Sort by confidence score, highest first.
        detections.sort(key=lambda det: det["score"], reverse=True)
        return detections