"""Simple DocLayout model for inference."""

import json
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
from PIL import Image
from ultralytics import YOLO


class DocLayoutModel:
    """
    Document layout detection model.

    Wraps an ultralytics ``YOLO`` model together with a mapping from class
    ids to class names. The YOLO weights are loaded on first access to
    :attr:`model`, not at construction time.

    Examples
    --------
    >>> model = DocLayoutModel("model.pt")
    >>> results = model.predict("document.png")
    >>> for det in results:
    ...     print(f"{det['class_name']}: {det['confidence']:.2f}")
    """

    # Default class mappings
    DOCSTRUCTBENCH_CLASSES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    DOCLAYNET_CLASSES = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title",
    }

    def __init__(
        self,
        weights_path: Union[str, Path],
        config_path: Union[str, Path, None] = None,
        model_type: str = "auto",
    ):
        """
        Initialize model.

        Parameters
        ----------
        weights_path : str or Path
            Path to model weights (.pt file)
        config_path : str or Path, optional
            Path to config.json with class names. If None, auto-detects
            from weights filename. When given, ``model_type`` is ignored.
        model_type : str, default="auto"
            Model type: "docstructbench", "doclaynet", or "auto"
            (detect from filename)

        Raises
        ------
        ValueError
            If no config is given and ``model_type`` is not one of the
            supported values.
        """
        self.weights_path = Path(weights_path)
        self._model = None

        # Load class names from config or auto-detect
        if config_path:
            self.class_names = self._load_class_names(config_path)
        else:
            self.class_names = self._get_class_names(model_type)

    @staticmethod
    def _load_class_names(config_path: Union[str, Path]) -> Dict[int, str]:
        """
        Read the class-id -> class-name mapping from a JSON config file.

        The ``"class_names"`` entry may be a list (the position is the class
        id) or an object mapping class ids to names.

        Raises
        ------
        KeyError
            If the config has no ``"class_names"`` entry.
        ValueError
            If ``"class_names"`` is an object with a non-integer key.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        names = config["class_names"]
        if isinstance(names, dict):
            # JSON object keys are always strings; convert them back to ids.
            return {int(class_id): name for class_id, name in names.items()}
        return dict(enumerate(names))

    def _get_class_names(self, model_type: str) -> Dict[int, str]:
        """
        Get class names based on model type.

        Returns a fresh copy of the matching default mapping so that changes
        to one instance's ``class_names`` never leak into the class-level
        defaults shared by all instances.

        Raises
        ------
        ValueError
            If ``model_type`` is not "auto", "doclaynet" or "docstructbench".
        """
        if model_type == "auto":
            # Anything whose filename does not mention DocLayNet is treated
            # as a DocStructBench model.
            name = self.weights_path.stem.lower()
            if "doclaynet" in name:
                return dict(self.DOCLAYNET_CLASSES)
            return dict(self.DOCSTRUCTBENCH_CLASSES)
        elif model_type == "doclaynet":
            return dict(self.DOCLAYNET_CLASSES)
        elif model_type == "docstructbench":
            return dict(self.DOCSTRUCTBENCH_CLASSES)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    @property
    def model(self) -> YOLO:
        """Lazy-load the YOLO model."""
        if self._model is None:
            self._model = YOLO(str(self.weights_path))
        return self._model

    def predict(
        self,
        source: Union[str, Path, Image.Image, np.ndarray],
        confidence: float = 0.2,
        image_size: int = 1024,
        device: str = "cpu",
    ) -> List[Dict]:
        """
        Run inference on an image.

        Parameters
        ----------
        source : str, Path, PIL.Image, or np.ndarray
            Input image
        confidence : float, default=0.2
            Confidence threshold
        image_size : int, default=1024
            Input image size
        device : str, default="cpu"
            Device to run on ("cpu", "cuda", "mps")

        Returns
        -------
        List[Dict]
            List of detections, each with keys:
            - class_id: int
            - class_name: str (``"class_<id>"`` when the id has no known name)
            - confidence: float
            - bbox: [x1, y1, x2, y2]

            Detections from every result returned by the model are
            flattened into this single list.
        """
        results = self.model.predict(
            # Path objects are handed over as plain strings.
            source=str(source) if isinstance(source, Path) else source,
            imgsz=image_size,
            conf=confidence,
            device=device,
            save=False,
            verbose=False,
        )

        detections = []
        for result in results:
            # `boxes` can be None for results that carry no detection output.
            if result.boxes is None:
                continue
            for box in result.boxes:
                cls = int(box.cls[0])
                detections.append(
                    {
                        "class_id": cls,
                        "class_name": self.class_names.get(cls, f"class_{cls}"),
                        "confidence": float(box.conf[0]),
                        "bbox": box.xyxy[0].tolist(),
                    }
                )

        return detections