# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import sys
import torch
import importlib
import numpy as np
import numpy.typing as npt
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, List, Tuple, Union

from yolox.boxes import postprocess


def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
    """
    Defines and initializes the model based on the configuration.

    Args:
        config_name (str): Configuration name. Defaults to "page_element_v3".
        verbose (bool): Whether to print verbose output. Defaults to True.

    Returns:
        torch.nn.Module: The initialized YOLOX model.
    """
    # Import the experiment config module referenced by config_name
    sys.path.append(os.path.dirname(config_name))
    exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])

    config = exp_module.Exp()
    model = config.get_model()

    # Load weights
    if verbose:
        print(" -> Loading weights from", config.ckpt)
    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"], strict=True)

    model = YoloXWrapper(model, config)
    return model.eval().to(config.device)
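

# Note (illustrative, inferred from how the config is used in this file rather than
# stated by the original source): the Exp class exposed by the config module is expected
# to provide at least get_model(), ckpt, device, size, min_bbox_size, normalize_boxes,
# conf_thresh, iou_thresh, class_agnostic, thresholds_per_class, labels and num_classes,
# since those attributes are read in define_model and YoloXWrapper.__init__ below.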


def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resizes and pads an image to a given size.
    The goal is to preserve the aspect ratio of the image.

    Args:
        img (torch.Tensor[C x H x W]): The image to resize and pad.
        size (tuple[2]): The size to resize and pad the image to.

    Returns:
        torch.Tensor: The resized and padded image.
    """
    img = img.float()
    _, h, w = img.shape

    scale = min(size[0] / h, size[1] / w)
    nh = int(h * scale)
    nw = int(w * scale)
    img = F.interpolate(
        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
    ).squeeze(0)
    img = torch.clamp(img, 0, 255)

    pad_b = size[0] - nh
    pad_r = size[1] - nw
    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
    return img
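

# Worked example (illustrative numbers, not taken from the original source): resizing a
# 3 x 800 x 600 image to size=(1024, 1024) gives scale = min(1024/800, 1024/600) = 1.28,
# so the image is interpolated to 1024 x 768 and then padded on the right with
# 1024 - 768 = 256 columns of the constant value 114, yielding a 3 x 1024 x 1024 tensor.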


class YoloXWrapper(nn.Module):
    """
    Wrapper for YoloX models.
    """
    def __init__(self, model: nn.Module, config) -> None:
        """
        Constructor.

        Args:
            model (torch model): Yolo model.
            config (Config): Config object containing model parameters.
        """
        super().__init__()
        self.model = model
        self.config = config

        # Copy config parameters
        self.device = config.device
        self.img_size = config.size
        self.min_bbox_size = config.min_bbox_size
        self.normalize_boxes = config.normalize_boxes
        self.conf_thresh = config.conf_thresh
        self.iou_thresh = config.iou_thresh
        self.class_agnostic = config.class_agnostic
        self.thresholds_per_class = config.thresholds_per_class
        self.labels = config.labels
        self.num_classes = config.num_classes

    def reformat_input(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reformats the input data and original sizes to the correct format.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.

        Returns:
            torch.Tensor[BS x C x H x W]: Input image batch.
            torch.Tensor[BS x 2]: Original image sizes (before resizing and padding).
        """
        # Convert image sizes to a tensor
        if isinstance(orig_sizes, (list, tuple)):
            orig_sizes = np.array(orig_sizes)
        if orig_sizes.shape[-1] == 3:  # remove channel
            orig_sizes = orig_sizes[..., :2]
        if isinstance(orig_sizes, np.ndarray):
            orig_sizes = torch.from_numpy(orig_sizes).to(self.device)

        # Add batch dimension if not present
        if len(x.size()) == 3:
            x = x.unsqueeze(0)
        if len(orig_sizes.size()) == 1:
            orig_sizes = orig_sizes.unsqueeze(0)

        return x, orig_sizes

    def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        """
        YoloX preprocessing function:
        - Resizes the longest edge to img_size while preserving the aspect ratio
        - Pads the shortest edge to img_size

        Args:
            image (torch tensor or np array [H x W x 3]): Input image in uint8 format.

        Returns:
            torch.Tensor [3 x H x W]: Processed image.
        """
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
        image = resize_pad(image, self.img_size)
        return image.float()

    def forward(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Forward pass of the model.
        Applies NMS and reformats the predictions.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.

        Returns:
            list[dict]: List of prediction dictionaries. Each dictionary contains:
                - labels (torch.Tensor[N]): Class labels.
                - boxes (torch.Tensor[N x 4]): Bounding boxes.
                - scores (torch.Tensor[N]): Confidence scores.
        """
        x, orig_sizes = self.reformat_input(x, orig_sizes)

        # Scale to 0-255 if in range 0-1
        if x.max() <= 1:
            x *= 255

        pred_boxes = self.model(x.to(self.device))

        # NMS
        pred_boxes = postprocess(
            pred_boxes,
            self.config.num_classes,
            self.conf_thresh,
            self.iou_thresh,
            class_agnostic=self.class_agnostic,
        )

        # Reformat output
        preds = []
        for p, size in zip(pred_boxes, orig_sizes):
            if p is None:  # No detections
                preds.append({
                    "labels": torch.empty(0),
                    "boxes": torch.empty((0, 4)),
                    "scores": torch.empty(0),
                })
                continue

            p = p.view(-1, p.size(-1))

            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
            boxes = p[:, :4] / ratio

            # Clip
            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])

            # Remove too small
            kept = (
                (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size) &
                (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
            )
            boxes = boxes[kept]
            p = p[kept]

            # Normalize to 0-1
            if self.normalize_boxes:
                boxes[:, [0, 2]] /= size[1]
                boxes[:, [1, 3]] /= size[0]

            scores = p[:, 4] * p[:, 5]
            labels = p[:, 6]
            preds.append({"labels": labels, "boxes": boxes, "scores": scores})

        return preds
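

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # "page_element_v3" config module is importable from the working directory and
    # that its checkpoint path is valid; the dummy image size below is arbitrary.
    model = define_model("page_element_v3", verbose=True)

    # Dummy uint8 HWC image standing in for a rendered page.
    image = np.random.randint(0, 255, (792, 612, 3), dtype=np.uint8)
    orig_size = image.shape  # (H, W, C); the channel dim is stripped in reformat_input

    # Preprocess to the model input resolution, then run detection.
    x = model.preprocess(image)
    with torch.no_grad():
        preds = model(x, [orig_size])

    # Each element of preds holds "labels", "boxes" (x1, y1, x2, y2) and "scores".
    print({k: v.shape for k, v in preds[0].items()})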