Spaces:
Running
Running
| import abc | |
| import cv2 | |
| import numpy as np | |
| import contextlib | |
| from huggingface_hub import hf_hub_download | |
class DocLayoutModel(abc.ABC):
    """Common interface for the document-layout detection backends.

    The static ``load_*`` helpers construct a concrete backend; subclasses
    must provide the ``stride`` property and the ``predict`` method.
    """

    @staticmethod
    def load_torch():
        """Return a ``TorchModel`` built from the DocLayout-YOLO ``.pt`` weights."""
        return TorchModel.from_pretrained(
            repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
            filename="doclayout_yolo_docstructbench_imgsz1024.pt",
        )

    @staticmethod
    def load_onnx():
        """Return an ``OnnxModel`` built from the DocLayout-YOLO ``.onnx`` export."""
        return OnnxModel.from_pretrained(
            repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
            filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
        )

    @staticmethod
    def load_available():
        """Return the first backend that loads without an ``ImportError``.

        The Torch backend is tried first, then the ONNX backend. Only
        ``ImportError`` is suppressed; any other failure propagates.

        Raises:
            ImportError: If both backends raised ``ImportError``.
        """
        with contextlib.suppress(ImportError):
            return DocLayoutModel.load_torch()

        with contextlib.suppress(ImportError):
            return DocLayoutModel.load_onnx()

        raise ImportError(
            "Please install the `torch` or `onnx` feature to use the DocLayout model."
        )

    @property
    @abc.abstractmethod
    def stride(self) -> int:
        """Stride of the model input."""

    @abc.abstractmethod
    def predict(self, image, imgsz=1024, **kwargs) -> list:
        """
        Predict the layout of a document page.

        Args:
            image: The image of the document page.
            imgsz: Resize the image to this size. Must be a multiple of the stride.
            **kwargs: Additional arguments.

        Returns:
            A list of prediction results; the element type is defined by the
            concrete backend.
        """
class TorchModel(DocLayoutModel):
    """DocLayout backend backed by the ``doclayout_yolo`` YOLOv10 model."""

    def __init__(self, model_path: str):
        """Load the YOLOv10 weights stored at ``model_path``.

        Args:
            model_path: Local path of the weights file.

        Raises:
            ImportError: If ``doclayout_yolo`` cannot be imported.
        """
        try:
            import doclayout_yolo
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the `torch` feature to use the Torch model."
            ) from err

        self.model_path = model_path
        self.model = doclayout_yolo.YOLOv10(model_path)

    @staticmethod
    def from_pretrained(repo_id: str, filename: str):
        """Fetch ``filename`` from the Hugging Face repo ``repo_id`` and wrap it.

        Args:
            repo_id: Hugging Face Hub repository identifier.
            filename: Name of the weights file inside that repository.

        Returns:
            A ``TorchModel`` constructed from the downloaded file path.
        """
        pth = hf_hub_download(repo_id=repo_id, filename=filename)
        return TorchModel(pth)

    @property
    def stride(self):
        """Stride of the model input; this backend always reports 32."""
        return 32

    def predict(self, *args, **kwargs):
        """Forward every argument unchanged to the wrapped model's ``predict``."""
        return self.model.predict(*args, **kwargs)
class YoloResult:
    """Container for the detections produced by one ONNX inference."""

    def __init__(self, boxes, names):
        """Wrap every entry of ``boxes`` and order them by confidence.

        Args:
            boxes: Iterable of raw detection rows; each row is handed to
                ``YoloBox`` as its ``data`` argument.
            names: Class-name lookup, stored unchanged on the instance.
        """
        wrapped = [YoloBox(data=row) for row in boxes]
        # Highest confidence first; equal confidences keep their input order.
        self.boxes = sorted(wrapped, key=lambda box: box.conf, reverse=True)
        self.names = names
class YoloBox:
    """A single detection row split into coordinates, confidence and class."""

    def __init__(self, data):
        """Slice one raw detection row into its named parts.

        Args:
            data: Indexable row whose first four entries are the box
                coordinates, whose second-to-last entry is the confidence
                and whose last entry is the class value.
        """
        coords, score, label = data[:4], data[-2], data[-1]
        self.xyxy = coords
        self.conf = score
        self.cls = label
class OnnxModel(DocLayoutModel):
    """DocLayout backend that runs the exported ONNX model with onnxruntime."""

    def __init__(self, model_path: str):
        """Load the ONNX file at ``model_path`` and open an inference session.

        The ``stride`` and ``names`` entries of the model metadata are parsed
        with ``ast.literal_eval`` and kept for pre- and post-processing.

        Args:
            model_path: Local path of the ``.onnx`` file.

        Raises:
            ImportError: If ``onnx`` or ``onnxruntime`` cannot be imported.
            KeyError: If the metadata has no ``stride`` or ``names`` entry.
        """
        import ast

        try:
            import onnx
            import onnxruntime
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the `onnx` feature to use the ONNX model."
            ) from err

        self.model_path = model_path

        model = onnx.load(model_path)
        metadata = {d.key: d.value for d in model.metadata_props}
        self._stride = ast.literal_eval(metadata["stride"])
        self._names = ast.literal_eval(metadata["names"])

        self.model = onnxruntime.InferenceSession(model.SerializeToString())

    @staticmethod
    def from_pretrained(repo_id: str, filename: str):
        """Fetch ``filename`` from the Hugging Face repo ``repo_id`` and wrap it.

        Args:
            repo_id: Hugging Face Hub repository identifier.
            filename: Name of the model file inside that repository.

        Returns:
            An ``OnnxModel`` constructed from the downloaded file path.
        """
        pth = hf_hub_download(repo_id=repo_id, filename=filename)
        return OnnxModel(pth)

    @property
    def stride(self):
        """Stride of the model input, as read from the ONNX metadata."""
        # Must be a property: resize_and_pad_image uses ``self.stride`` as a
        # number in a modulo expression.
        return self._stride

    def resize_and_pad_image(self, image, new_shape):
        """
        Resize the image to fit inside ``new_shape`` and add constant padding.

        The aspect ratio is preserved. Only the remainder of the leftover
        space modulo ``self.stride`` is padded, so each output dimension
        differs from the requested one by a whole multiple of the stride and
        may be smaller than ``new_shape``.

        Parameters:
        - image: Input image array; its first two axes are (height, width)
        - new_shape: Target size (integer or (height, width) tuple)

        Returns:
        - Processed image
        """
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        h, w = image.shape[:2]
        new_h, new_w = new_shape

        # Calculate scaling ratio (the tighter axis decides)
        r = min(new_h / h, new_w / w)
        resized_h, resized_w = int(round(h * r)), int(round(w * r))

        # Resize image
        image = cv2.resize(
            image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR
        )

        # Calculate padding size and align to stride multiple
        pad_w = (new_w - resized_w) % self.stride
        pad_h = (new_h - resized_h) % self.stride
        # Split the padding between both sides; the odd pixel goes bottom/right
        top, bottom = pad_h // 2, pad_h - pad_h // 2
        left, right = pad_w // 2, pad_w - pad_w // 2

        # Add padding
        image = cv2.copyMakeBorder(
            image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

        return image

    def scale_boxes(self, img1_shape, boxes, img0_shape):
        """
        Rescale xyxy bounding boxes from the padded model input (img1_shape)
        back to the coordinate system of the original image (img0_shape).

        The first four columns of ``boxes`` are overwritten in place and the
        same array is returned.

        Args:
            img1_shape (tuple): The shape of the image that the bounding boxes are for,
                in the format of (height, width).
            boxes (numpy.ndarray): the bounding boxes of the objects in the image,
                in the format of (x1, y1, x2, y2) along the last axis
            img0_shape (tuple): the shape of the target image, in the format of (height, width).

        Returns:
            boxes (numpy.ndarray): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
        """
        # Calculate scaling ratio
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])

        # Calculate padding size
        pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
        pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)

        # Remove padding and scale boxes
        boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain
        return boxes

    def predict(self, image, imgsz=1024, **kwargs):
        """
        Run layout detection on one page image.

        Args:
            image: Page image as a 3-axis array laid out (height, width, channels).
            imgsz: Target size passed to ``resize_and_pad_image``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A one-element list holding a ``YoloResult`` whose boxes are
            expressed in the coordinates of the original image.
        """
        # Preprocess input image
        orig_h, orig_w = image.shape[:2]
        pix = self.resize_and_pad_image(image, new_shape=imgsz)
        pix = np.transpose(pix, (2, 0, 1))  # CHW
        pix = np.expand_dims(pix, axis=0)  # BCHW
        pix = pix.astype(np.float32) / 255.0  # Normalize to [0, 1]
        new_h, new_w = pix.shape[2:]

        # Run inference
        preds = self.model.run(None, {"images": pix})[0]

        # Postprocess predictions: keep rows whose column 4 exceeds 0.25
        # (presumably the confidence score - YoloBox reads it as data[-2])
        preds = preds[preds[..., 4] > 0.25]
        preds[..., :4] = self.scale_boxes(
            (new_h, new_w), preds[..., :4], (orig_h, orig_w)
        )
        return [YoloResult(boxes=preds, names=self._names)]