Spaces:
Sleeping
Sleeping
s-egg-mentation / deployments /deployment /Instance segmentation task /python /demo_package /model_wrapper.py
| # Copyright (C) 2024 Intel Corporation | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| """ModelWrapper class used for loading a Model API model together with its parameters.""" | |
| from __future__ import annotations | |
| from enum import Enum | |
| from typing import TYPE_CHECKING, Any, NamedTuple | |
| from model_api.adapters import OpenvinoAdapter, create_core | |
| from model_api.models import Model | |
| from .utils import get_model_path, get_parameters | |
| if TYPE_CHECKING: | |
| from pathlib import Path | |
| import numpy as np | |
| from model_api.tilers import DetectionTiler, InstanceSegmentationTiler | |
class TaskType(str, Enum):
    """Task types a deployed OTX model can be tagged with.

    Every member is also a ``str`` whose value equals its own name, so a
    member can be looked up by name (``TaskType["DETECTION"]``) or by value
    (``TaskType("DETECTION")``) and compares equal to the plain string.
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    CLASSIFICATION = "CLASSIFICATION"
    DETECTION = "DETECTION"
    INSTANCE_SEGMENTATION = "INSTANCE_SEGMENTATION"
    SEGMENTATION = "SEGMENTATION"
class ModelWrapper:
    """Class for storing a Model API model together with the parameters read from its config.

    Args:
        model_dir (Path): path to the model directory; it must contain ``config.json``
        device (str): device name passed to the OpenVINO adapter, ``"CPU"`` by default
    """

    def __init__(self, model_dir: Path, device: str = "CPU") -> None:
        model_adapter = OpenvinoAdapter(create_core(), get_model_path(model_dir / "model.xml"), device=device)

        config_path = model_dir / "config.json"
        if not config_path.exists():
            msg = "config.json doesn't exist in the model directory."
            raise RuntimeError(msg)

        self.parameters = get_parameters(config_path)
        # Keep the labels before they are removed from the model parameters below.
        self._labels = self.parameters["model_parameters"]["labels"]
        self._task_type = TaskType[self.parameters["task_type"].upper()]

        # labels for modelAPI wrappers can be empty, because unused in pre- and postprocessing
        # NOTE: this is the same dict object as self.parameters["model_parameters"],
        # so the pop below removes "labels" from self.parameters as well.
        self.model_parameters = self.parameters["model_parameters"]

        # model already contains correct labels
        self.model_parameters.pop("labels")

        self.core_model = Model.create_model(
            model_adapter,
            self.parameters["model_type"],
            self.model_parameters,
            preload=True,
        )
        self.tiler = self.setup_tiler(model_dir, device)

    def setup_tiler(
        self,
        model_dir: Path,
        device: str,
    ) -> DetectionTiler | InstanceSegmentationTiler | None:
        """Set up tiler for model.

        Args:
            model_dir (Path): model directory (currently unused)
            device (str): device to run model on (currently unused)

        Returns:
            Optional: ``None`` when the config has no tiling section, the section is empty,
                or ``enable_tiling`` is missing or falsy.

        Raises:
            NotImplementedError: if tiling is enabled in the config.
        """
        tiling_parameters = self.parameters.get("tiling_parameters")
        # A tiling section without an "enable_tiling" key is treated as "tiling disabled"
        # instead of failing with KeyError.
        if not tiling_parameters or not tiling_parameters.get("enable_tiling"):
            return None

        msg = "Tiling has not been implemented yet"
        raise NotImplementedError(msg)

    # NOTE(review): the original docstrings of task_type() and labels() call them
    # properties, but neither is decorated with @property. They are kept as plain
    # methods so existing ``obj.task_type()`` / ``obj.labels()`` callers keep working;
    # confirm against the callers which form is intended.
    def task_type(self) -> TaskType:
        """Return the task type parsed from the ``task_type`` entry of the config."""
        return self._task_type

    def labels(self) -> dict:
        """Return the labels taken from the ``model_parameters`` section of the config."""
        return self._labels

    def infer(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Infer with original image.

        Args:
            frame: np.ndarray, input image

        Returns:
            predictions: NamedTuple, prediction
            frame_meta: Dict, dict with original shape
        """
        # Per the original author's note, calling the model covers preprocessing,
        # inference and postprocessing for synchronous inference.
        predictions = self.core_model(frame)
        frame_meta = {"original_shape": frame.shape}

        return predictions, frame_meta

    def infer_tile(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Infer by patching full image to tiles.

        Args:
            frame: np.ndarray - input image

        Returns:
            Tuple[NamedTuple, Dict]: prediction and original shape

        Raises:
            RuntimeError: if no tiler has been set up.
        """
        if self.tiler is None:
            msg = "Tiler is not set"
            raise RuntimeError(msg)

        detections = self.tiler(frame)
        return detections, {"original_shape": frame.shape}

    def __call__(self, input_data: np.ndarray) -> tuple[Any, dict]:
        """Call the ModelWrapper class.

        Uses tiled inference when a tiler is set, plain inference otherwise.

        Args:
            input_data (np.ndarray): The input image.

        Returns:
            Tuple[Any, dict]: A tuple containing predictions and the meta information.
        """
        if self.tiler is not None:
            return self.infer_tile(input_data)
        return self.infer(input_data)