| from abc import abstractmethod |
| from time import perf_counter |
| from typing import Any, List, Tuple, Union |
|
|
| import numpy as np |
|
|
| from inference.core.cache.model_artifacts import clear_cache, initialise_cache |
| from inference.core.entities.requests.inference import InferenceRequest |
| from inference.core.entities.responses.inference import InferenceResponse, StubResponse |
| from inference.core.models.base import Model |
| from inference.core.models.types import PreprocessReturnMetadata |
| from inference.core.utils.image_utils import np_image_to_base64 |
|
|
|
|
class ModelStub(Model):
    """Placeholder model that short-circuits the real inference pipeline.

    Every stage (preprocess / predict / postprocess) emits fixed dummy
    data, and the final prediction is a small dict flagged with
    ``is_stub``. Subclasses provide :meth:`make_response` to wrap that
    dict into the task-specific response type.
    """

    def __init__(self, model_id: str, api_key: str):
        """Record identifiers and prepare the artifact cache.

        Args:
            model_id: Identifier in ``"<dataset_id>/<version_id>"`` form.
            api_key: API key; stored but not used by the stub itself.

        Raises:
            ValueError: If ``model_id`` does not split into exactly two
                parts on ``/``.
        """
        super().__init__()
        self.model_id = model_id
        self.api_key = api_key
        self.dataset_id, self.version_id = model_id.split("/")
        # Static placeholder metrics; the stub performs no real work.
        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
        initialise_cache(model_id=model_id)

    def infer_from_request(
        self, request: InferenceRequest
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Run the stub pipeline for ``request`` and time it.

        Args:
            request: Incoming request; its fields are forwarded to
                :meth:`infer` via ``request.dict()``.

        Returns:
            The response built by :meth:`make_response`, with its ``time``
            attribute set to the wall-clock duration of this call.
        """
        started_at = perf_counter()
        prediction = self.infer(**request.dict())
        response = self.make_response(request=request, prediction=prediction)
        response.time = perf_counter() - started_at
        return response

    def infer(self, *args, **kwargs) -> Any:
        """Drive the three pipeline stages; all arguments are ignored."""
        self.preprocess()
        return self.postprocess(self.predict())

    def preprocess(
        self, *args, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Return a black 128x128 RGB uint8 image and empty metadata."""
        dummy_image = np.zeros((128, 128, 3), dtype=np.uint8)
        return dummy_image, {}

    def predict(self, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Return a one-element tuple holding an all-zero (1, 8) array."""
        return (np.zeros((1, 8)),)

    def postprocess(self, predictions: Tuple[np.ndarray, ...], *args, **kwargs) -> Any:
        """Discard the raw prediction and emit the stub marker payload."""
        return {"is_stub": True, "model_id": self.model_id}

    def clear_cache(self) -> None:
        """Delete cached artifacts for this model id."""
        clear_cache(model_id=self.model_id)

    @abstractmethod
    def make_response(
        self, request: InferenceRequest, prediction: dict, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Wrap the stub prediction into a task-specific response object."""
        pass
|
|
|
|
class ClassificationModelStub(ModelStub):
    """Stub standing in for a classification model."""

    task_type = "classification"

    def make_response(
        self, request: InferenceRequest, prediction: dict, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build a ``StubResponse`` for a classification request.

        A blank 128x128 image is encoded only when the request asks for
        ``visualize_predictions``; otherwise the field stays ``None``.
        """
        wants_visualisation = getattr(request, "visualize_predictions", False)
        blank_visualisation = (
            np_image_to_base64(np.zeros((128, 128, 3), dtype=np.uint8))
            if wants_visualisation
            else None
        )
        return StubResponse(
            is_stub=prediction["is_stub"],
            model_id=prediction["model_id"],
            task_type=self.task_type,
            visualization=blank_visualisation,
        )
|
|
|
|
class ObjectDetectionModelStub(ModelStub):
    """Stub standing in for an object-detection model."""

    task_type = "object-detection"

    def make_response(
        self, request: InferenceRequest, prediction: dict, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build a ``StubResponse`` for an object-detection request.

        A blank 128x128 image is encoded only when the request asks for
        ``visualize_predictions``; otherwise the field stays ``None``.
        """
        wants_visualisation = getattr(request, "visualize_predictions", False)
        blank_visualisation = (
            np_image_to_base64(np.zeros((128, 128, 3), dtype=np.uint8))
            if wants_visualisation
            else None
        )
        return StubResponse(
            is_stub=prediction["is_stub"],
            model_id=prediction["model_id"],
            task_type=self.task_type,
            visualization=blank_visualisation,
        )
|
|
|
|
class InstanceSegmentationModelStub(ModelStub):
    """Stub standing in for an instance-segmentation model."""

    task_type = "instance-segmentation"

    def make_response(
        self, request: InferenceRequest, prediction: dict, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build a ``StubResponse`` for an instance-segmentation request.

        A blank 128x128 image is encoded only when the request asks for
        ``visualize_predictions``; otherwise the field stays ``None``.
        """
        wants_visualisation = getattr(request, "visualize_predictions", False)
        blank_visualisation = (
            np_image_to_base64(np.zeros((128, 128, 3), dtype=np.uint8))
            if wants_visualisation
            else None
        )
        return StubResponse(
            is_stub=prediction["is_stub"],
            model_id=prediction["model_id"],
            task_type=self.task_type,
            visualization=blank_visualisation,
        )
|
|
|
|
class KeypointsDetectionModelStub(ModelStub):
    """Stub standing in for a keypoint-detection model."""

    task_type = "keypoint-detection"

    def make_response(
        self, request: InferenceRequest, prediction: dict, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build a ``StubResponse`` for a keypoint-detection request.

        A blank 128x128 image is encoded only when the request asks for
        ``visualize_predictions``; otherwise the field stays ``None``.
        """
        wants_visualisation = getattr(request, "visualize_predictions", False)
        blank_visualisation = (
            np_image_to_base64(np.zeros((128, 128, 3), dtype=np.uint8))
            if wants_visualisation
            else None
        )
        return StubResponse(
            is_stub=prediction["is_stub"],
            model_id=prediction["model_id"],
            task_type=self.task_type,
            visualization=blank_visualisation,
        )
|
|