from contextlib import contextmanager from typing import Any, Generator, List, Optional, Tuple, Union import numpy as np import requests from requests import HTTPError from inference_sdk.http.entities import ( CLASSIFICATION_TASK, INSTANCE_SEGMENTATION_TASK, KEYPOINTS_DETECTION_TASK, OBJECT_DETECTION_TASK, HTTPClientMode, ImagesReference, InferenceConfiguration, ModelDescription, RegisteredModels, ServerInfo, ) from inference_sdk.http.errors import ( HTTPCallErrorError, HTTPClientError, InvalidModelIdentifier, ModelNotInitializedError, ModelNotSelectedError, ModelTaskTypeNotSupportedError, WrongClientModeError, ) from inference_sdk.http.utils.iterables import unwrap_single_element_list from inference_sdk.http.utils.loaders import ( load_static_inference_input, load_stream_inference_input, ) from inference_sdk.http.utils.post_processing import ( adjust_prediction_to_client_scaling_factor, response_contains_jpeg_image, transform_base64_visualisation, transform_visualisation_bytes, ) from inference_sdk.http.utils.requests import api_key_safe_raise_for_status SUCCESSFUL_STATUS_CODE = 200 DEFAULT_HEADERS = { "Content-Type": "application/json", } NEW_INFERENCE_ENDPOINTS = { INSTANCE_SEGMENTATION_TASK: "/infer/instance_segmentation", OBJECT_DETECTION_TASK: "/infer/object_detection", CLASSIFICATION_TASK: "/infer/classification", KEYPOINTS_DETECTION_TASK: "/infer/keypoints_detection", } def wrap_errors(function: callable) -> callable: def decorate(*args, **kwargs) -> Any: try: return function(*args, **kwargs) except HTTPError as error: if "application/json" in error.response.headers.get("Content-Type", ""): api_message = error.response.json().get("message") else: api_message = error.response.text raise HTTPCallErrorError( description=str(error), status_code=error.response.status_code, api_message=api_message, ) from error except ConnectionError as error: raise HTTPClientError( f"Error with server connection: {str(error)}" ) from error return decorate class InferenceHTTPClient: def __init__( self, api_url: str, api_key: str, ): self.__api_url = api_url self.__api_key = api_key self.__inference_configuration = InferenceConfiguration.init_default() self.__client_mode = _determine_client_mode(api_url=api_url) self.__selected_model: Optional[str] = None @property def inference_configuration(self) -> InferenceConfiguration: return self.__inference_configuration @property def client_mode(self) -> HTTPClientMode: return self.__client_mode @property def selected_model(self) -> Optional[str]: return self.__selected_model @contextmanager def use_configuration( self, inference_configuration: InferenceConfiguration ) -> Generator["InferenceHTTPClient", None, None]: previous_configuration = self.__inference_configuration self.__inference_configuration = inference_configuration try: yield self finally: self.__inference_configuration = previous_configuration def configure( self, inference_configuration: InferenceConfiguration ) -> "InferenceHTTPClient": self.__inference_configuration = inference_configuration return self def select_api_v0(self) -> "InferenceHTTPClient": self.__client_mode = HTTPClientMode.V0 return self def select_api_v1(self) -> "InferenceHTTPClient": self.__client_mode = HTTPClientMode.V1 return self @contextmanager def use_api_v0(self) -> Generator["InferenceHTTPClient", None, None]: previous_client_mode = self.__client_mode self.__client_mode = HTTPClientMode.V0 try: yield self finally: self.__client_mode = previous_client_mode @contextmanager def use_api_v1(self) -> Generator["InferenceHTTPClient", None, None]: previous_client_mode = self.__client_mode self.__client_mode = HTTPClientMode.V1 try: yield self finally: self.__client_mode = previous_client_mode def select_model(self, model_id: str) -> "InferenceHTTPClient": self.__selected_model = model_id return self @contextmanager def use_model(self, model_id: str) -> Generator["InferenceHTTPClient", None, None]: previous_model = self.__selected_model self.__selected_model = model_id try: yield self finally: self.__selected_model = previous_model @wrap_errors def get_server_info(self) -> ServerInfo: response = requests.get(f"{self.__api_url}/info") response.raise_for_status() response_payload = response.json() return ServerInfo.from_dict(response_payload) def infer_on_stream( self, input_uri: str, model_id: Optional[str] = None, ) -> Generator[Tuple[Union[str, int], np.ndarray, dict], None, None]: for reference, frame in load_stream_inference_input( input_uri=input_uri, image_extensions=self.__inference_configuration.image_extensions_for_directory_scan, ): prediction = self.infer( inference_input=frame, model_id=model_id, ) yield reference, frame, prediction @wrap_errors def infer( self, inference_input: Union[ImagesReference, List[ImagesReference]], model_id: Optional[str] = None, ) -> Union[dict, List[dict]]: if self.__client_mode is HTTPClientMode.V0: return self.infer_from_api_v0( inference_input=inference_input, model_id=model_id, ) return self.infer_from_api_v1( inference_input=inference_input, model_id=model_id, ) def infer_from_api_v0( self, inference_input: Union[ImagesReference, List[ImagesReference]], model_id: Optional[str] = None, ) -> Union[dict, List[dict]]: model_id_to_be_used = model_id or self.__selected_model _ensure_model_is_selected(model_id=model_id_to_be_used) model_id_chunks = model_id_to_be_used.split("/") if len(model_id_chunks) != 2: raise InvalidModelIdentifier( f"Invalid model identifier: {model_id} in use." ) max_height, max_width = _determine_client_downsizing_parameters( client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled, model_description=None, default_max_input_size=self.__inference_configuration.default_max_input_size, ) encoded_inference_inputs = load_static_inference_input( inference_input=inference_input, max_height=max_height, max_width=max_width, ) params = { "api_key": self.__api_key, } params.update(self.__inference_configuration.to_legacy_call_parameters()) results = [] for element in encoded_inference_inputs: image, scaling_factor = element response = requests.post( f"{self.__api_url}/{model_id_chunks[0]}/{model_id_chunks[1]}", headers=DEFAULT_HEADERS, params=params, data=image, ) api_key_safe_raise_for_status(response=response) if response_contains_jpeg_image(response=response): visualisation = transform_visualisation_bytes( visualisation=response.content, expected_format=self.__inference_configuration.output_visualisation_format, ) parsed_response = {"visualization": visualisation} else: parsed_response = response.json() parsed_response = adjust_prediction_to_client_scaling_factor( prediction=parsed_response, scaling_factor=scaling_factor, ) results.append(parsed_response) return unwrap_single_element_list(sequence=results) def infer_from_api_v1( self, inference_input: Union[ImagesReference, List[ImagesReference]], model_id: Optional[str] = None, ) -> Union[dict, List[dict]]: self.__ensure_v1_client_mode() model_id_to_be_used = model_id or self.__selected_model _ensure_model_is_selected(model_id=model_id_to_be_used) model_description = self.get_model_description(model_id=model_id_to_be_used) max_height, max_width = _determine_client_downsizing_parameters( client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled, model_description=model_description, default_max_input_size=self.__inference_configuration.default_max_input_size, ) if model_description.task_type not in NEW_INFERENCE_ENDPOINTS: raise ModelTaskTypeNotSupportedError( f"Model task {model_description.task_type} is not supported by API v1 client." ) encoded_inference_inputs = load_static_inference_input( inference_input=inference_input, max_height=max_height, max_width=max_width, ) payload = { "api_key": self.__api_key, "model_id": model_id_to_be_used, } endpoint = NEW_INFERENCE_ENDPOINTS[model_description.task_type] payload.update( self.__inference_configuration.to_api_call_parameters( client_mode=self.__client_mode, task_type=model_description.task_type, ) ) results = [] for element in encoded_inference_inputs: image, scaling_factor = element payload["image"] = {"type": "base64", "value": image} response = requests.post( f"{self.__api_url}{endpoint}", json=payload, headers=DEFAULT_HEADERS, ) response.raise_for_status() parsed_response = response.json() if parsed_response.get("visualization") is not None: parsed_response["visualization"] = transform_base64_visualisation( visualisation=parsed_response["visualization"], expected_format=self.__inference_configuration.output_visualisation_format, ) parsed_response = adjust_prediction_to_client_scaling_factor( prediction=parsed_response, scaling_factor=scaling_factor, ) results.append(parsed_response) return unwrap_single_element_list(sequence=results) def get_model_description( self, model_id: str, allow_loading: bool = True ) -> ModelDescription: self.__ensure_v1_client_mode() registered_models = self.list_loaded_models() matching_models = [ e for e in registered_models.models if e.model_id == model_id ] if len(matching_models) > 0: return matching_models[0] if allow_loading is True: self.load_model(model_id=model_id) return self.get_model_description(model_id=model_id, allow_loading=False) raise ModelNotInitializedError( f"Model {model_id} is not initialised and cannot retrieve its description." ) @wrap_errors def list_loaded_models(self) -> RegisteredModels: self.__ensure_v1_client_mode() response = requests.get(f"{self.__api_url}/model/registry") response.raise_for_status() response_payload = response.json() return RegisteredModels.from_dict(response_payload) @wrap_errors def load_model( self, model_id: str, set_as_default: bool = False ) -> RegisteredModels: self.__ensure_v1_client_mode() response = requests.post( f"{self.__api_url}/model/add", json={ "model_id": model_id, "api_key": self.__api_key, }, headers=DEFAULT_HEADERS, ) response.raise_for_status() response_payload = response.json() if set_as_default: self.__selected_model = model_id return RegisteredModels.from_dict(response_payload) @wrap_errors def unload_model(self, model_id: str) -> RegisteredModels: self.__ensure_v1_client_mode() response = requests.post( f"{self.__api_url}/model/remove", json={ "model_id": model_id, }, headers=DEFAULT_HEADERS, ) response.raise_for_status() response_payload = response.json() if model_id == self.__selected_model: self.__selected_model = None return RegisteredModels.from_dict(response_payload) @wrap_errors def unload_all_models(self) -> RegisteredModels: self.__ensure_v1_client_mode() response = requests.post(f"{self.__api_url}/model/clear") response.raise_for_status() response_payload = response.json() self.__selected_model = None return RegisteredModels.from_dict(response_payload) def __ensure_v1_client_mode(self) -> None: if self.__client_mode is not HTTPClientMode.V1: raise WrongClientModeError("Use client mode `v1` to run this operation.") def _determine_client_downsizing_parameters( client_downsizing_disabled: bool, model_description: Optional[ModelDescription], default_max_input_size: int, ) -> Tuple[Optional[int], Optional[int]]: if client_downsizing_disabled: return None, None if ( model_description is None or model_description.input_height is None or model_description.input_width is None ): return default_max_input_size, default_max_input_size return model_description.input_height, model_description.input_width def _determine_client_mode(api_url: str) -> HTTPClientMode: if "roboflow.com" in api_url: return HTTPClientMode.V0 return HTTPClientMode.V1 def _ensure_model_is_selected(model_id: Optional[str]) -> None: if model_id is None: raise ModelNotSelectedError("No model was selected to be used.")