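"""HTTP client for the inference server API.

Exposes `InferenceHTTPClient`, which can talk either to the hosted
platform API (client mode `v0`) or to a self-hosted inference server
(client mode `v1`), plus module-level helpers for error wrapping and
client-side image downsizing."""
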
from contextlib import contextmanager
from typing import Any, Generator, List, Optional, Tuple, Union

import numpy as np
import requests
from requests import HTTPError

from inference_sdk.http.entities import (
    CLASSIFICATION_TASK,
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
    HTTPClientMode,
    ImagesReference,
    InferenceConfiguration,
    ModelDescription,
    RegisteredModels,
    ServerInfo,
)
from inference_sdk.http.errors import (
    HTTPCallErrorError,
    HTTPClientError,
    InvalidModelIdentifier,
    ModelNotInitializedError,
    ModelNotSelectedError,
    ModelTaskTypeNotSupportedError,
    WrongClientModeError,
)
from inference_sdk.http.utils.iterables import unwrap_single_element_list
from inference_sdk.http.utils.loaders import (
    load_static_inference_input,
    load_stream_inference_input,
)
from inference_sdk.http.utils.post_processing import (
    adjust_prediction_to_client_scaling_factor,
    response_contains_jpeg_image,
    transform_base64_visualisation,
    transform_visualisation_bytes,
)
from inference_sdk.http.utils.requests import api_key_safe_raise_for_status

SUCCESSFUL_STATUS_CODE = 200
DEFAULT_HEADERS = {
    "Content-Type": "application/json",
}
NEW_INFERENCE_ENDPOINTS = {
    INSTANCE_SEGMENTATION_TASK: "/infer/instance_segmentation",
    OBJECT_DETECTION_TASK: "/infer/object_detection",
    CLASSIFICATION_TASK: "/infer/classification",
    KEYPOINTS_DETECTION_TASK: "/infer/keypoints_detection",
}


def wrap_errors(function: callable) -> callable:
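    """Decorator translating transport-level failures into SDK exceptions:
    HTTPError becomes HTTPCallErrorError (carrying the status code and API
    error message) and connection failures become HTTPClientError."""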
    def decorate(*args, **kwargs) -> Any:
        try:
            return function(*args, **kwargs)
        except HTTPError as error:
            if "application/json" in error.response.headers.get("Content-Type", ""):
                api_message = error.response.json().get("message")
            else:
                api_message = error.response.text
            raise HTTPCallErrorError(
                description=str(error),
                status_code=error.response.status_code,
                api_message=api_message,
            ) from error
        except (ConnectionError, requests.exceptions.ConnectionError) as error:
            # requests' ConnectionError does not inherit from the builtin
            # ConnectionError, so both must be caught to cover connection failures.
            raise HTTPClientError(
                f"Error with server connection: {str(error)}"
            ) from error

    return decorate


class InferenceHTTPClient:
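    """HTTP client for an inference server, supporting both the hosted
    platform API (``v0``) and the self-hosted server API (``v1``).

    Example (a minimal sketch; the URL, API key and model id below are
    placeholders):

        client = InferenceHTTPClient(
            api_url="http://localhost:9001",
            api_key="<YOUR_API_KEY>",
        )
        prediction = client.infer("image.jpg", model_id="some-project/1")
    """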

    def __init__(
        self,
        api_url: str,
        api_key: str,
    ):
        self.__api_url = api_url
        self.__api_key = api_key
        self.__inference_configuration = InferenceConfiguration.init_default()
        self.__client_mode = _determine_client_mode(api_url=api_url)
        self.__selected_model: Optional[str] = None

    @property
    def inference_configuration(self) -> InferenceConfiguration:
        return self.__inference_configuration

    @property
    def client_mode(self) -> HTTPClientMode:
        return self.__client_mode

    @property
    def selected_model(self) -> Optional[str]:
        return self.__selected_model

    @contextmanager
    def use_configuration(
        self, inference_configuration: InferenceConfiguration
    ) -> Generator["InferenceHTTPClient", None, None]:
        previous_configuration = self.__inference_configuration
        self.__inference_configuration = inference_configuration
        try:
            yield self
        finally:
            self.__inference_configuration = previous_configuration

    def configure(
        self, inference_configuration: InferenceConfiguration
    ) -> "InferenceHTTPClient":
        self.__inference_configuration = inference_configuration
        return self

    def select_api_v0(self) -> "InferenceHTTPClient":
        self.__client_mode = HTTPClientMode.V0
        return self

    def select_api_v1(self) -> "InferenceHTTPClient":
        self.__client_mode = HTTPClientMode.V1
        return self

    @contextmanager
    def use_api_v0(self) -> Generator["InferenceHTTPClient", None, None]:
        previous_client_mode = self.__client_mode
        self.__client_mode = HTTPClientMode.V0
        try:
            yield self
        finally:
            self.__client_mode = previous_client_mode

    @contextmanager
    def use_api_v1(self) -> Generator["InferenceHTTPClient", None, None]:
        previous_client_mode = self.__client_mode
        self.__client_mode = HTTPClientMode.V1
        try:
            yield self
        finally:
            self.__client_mode = previous_client_mode

    def select_model(self, model_id: str) -> "InferenceHTTPClient":
        self.__selected_model = model_id
        return self

    @contextmanager
    def use_model(self, model_id: str) -> Generator["InferenceHTTPClient", None, None]:
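        """Temporarily select ``model_id`` as the default model within a
        ``with`` block; the previous selection is restored on exit."""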
        previous_model = self.__selected_model
        self.__selected_model = model_id
        try:
            yield self
        finally:
            self.__selected_model = previous_model

    @wrap_errors
    def get_server_info(self) -> ServerInfo:
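        """Fetch server metadata from the ``/info`` endpoint."""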
        response = requests.get(f"{self.__api_url}/info")
        response.raise_for_status()
        response_payload = response.json()
        return ServerInfo.from_dict(response_payload)

    def infer_on_stream(
        self,
        input_uri: str,
        model_id: Optional[str] = None,
    ) -> Generator[Tuple[Union[str, int], np.ndarray, dict], None, None]:
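        """Run inference frame by frame over the input referenced by
        ``input_uri`` (e.g. a video stream or a directory of images),
        yielding ``(reference, frame, prediction)`` triples."""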
        for reference, frame in load_stream_inference_input(
            input_uri=input_uri,
            image_extensions=self.__inference_configuration.image_extensions_for_directory_scan,
        ):
            prediction = self.infer(
                inference_input=frame,
                model_id=model_id,
            )
            yield reference, frame, prediction

    @wrap_errors
    def infer(
        self,
        inference_input: Union[ImagesReference, List[ImagesReference]],
        model_id: Optional[str] = None,
    ) -> Union[dict, List[dict]]:
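        """Dispatch inference to the v0 or v1 API according to the current
        client mode. Returns a single prediction dict for a single input,
        or a list of dicts for a list of inputs."""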
        if self.__client_mode is HTTPClientMode.V0:
            return self.infer_from_api_v0(
                inference_input=inference_input,
                model_id=model_id,
            )
        return self.infer_from_api_v1(
            inference_input=inference_input,
            model_id=model_id,
        )

    def infer_from_api_v0(
        self,
        inference_input: Union[ImagesReference, List[ImagesReference]],
        model_id: Optional[str] = None,
    ) -> Union[dict, List[dict]]:
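        """Run inference against the legacy (v0) API, which addresses models
        as ``<project>/<version>`` path segments and passes parameters in the
        query string."""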
        model_id_to_be_used = model_id or self.__selected_model
        _ensure_model_is_selected(model_id=model_id_to_be_used)
        model_id_chunks = model_id_to_be_used.split("/")
        if len(model_id_chunks) != 2:
            raise InvalidModelIdentifier(
                f"Invalid model identifier: {model_id_to_be_used} in use."
            )
        max_height, max_width = _determine_client_downsizing_parameters(
            client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled,
            model_description=None,
            default_max_input_size=self.__inference_configuration.default_max_input_size,
        )
        encoded_inference_inputs = load_static_inference_input(
            inference_input=inference_input,
            max_height=max_height,
            max_width=max_width,
        )
        params = {
            "api_key": self.__api_key,
        }
        params.update(self.__inference_configuration.to_legacy_call_parameters())
        results = []
        for element in encoded_inference_inputs:
            image, scaling_factor = element
            response = requests.post(
                f"{self.__api_url}/{model_id_chunks[0]}/{model_id_chunks[1]}",
                headers=DEFAULT_HEADERS,
                params=params,
                data=image,
            )
            api_key_safe_raise_for_status(response=response)
            if response_contains_jpeg_image(response=response):
                visualisation = transform_visualisation_bytes(
                    visualisation=response.content,
                    expected_format=self.__inference_configuration.output_visualisation_format,
                )
                parsed_response = {"visualization": visualisation}
            else:
                parsed_response = response.json()
                parsed_response = adjust_prediction_to_client_scaling_factor(
                    prediction=parsed_response,
                    scaling_factor=scaling_factor,
                )
            results.append(parsed_response)
        return unwrap_single_element_list(sequence=results)

    def infer_from_api_v1(
        self,
        inference_input: Union[ImagesReference, List[ImagesReference]],
        model_id: Optional[str] = None,
    ) -> Union[dict, List[dict]]:
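        """Run inference against the v1 API, which routes requests through
        task-specific endpoints (see NEW_INFERENCE_ENDPOINTS) and sends
        base64-encoded images in a JSON payload."""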
        self.__ensure_v1_client_mode()
        model_id_to_be_used = model_id or self.__selected_model
        _ensure_model_is_selected(model_id=model_id_to_be_used)
        model_description = self.get_model_description(model_id=model_id_to_be_used)
        max_height, max_width = _determine_client_downsizing_parameters(
            client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled,
            model_description=model_description,
            default_max_input_size=self.__inference_configuration.default_max_input_size,
        )
        if model_description.task_type not in NEW_INFERENCE_ENDPOINTS:
            raise ModelTaskTypeNotSupportedError(
                f"Model task {model_description.task_type} is not supported by API v1 client."
            )
        encoded_inference_inputs = load_static_inference_input(
            inference_input=inference_input,
            max_height=max_height,
            max_width=max_width,
        )
        payload = {
            "api_key": self.__api_key,
            "model_id": model_id_to_be_used,
        }
        endpoint = NEW_INFERENCE_ENDPOINTS[model_description.task_type]
        payload.update(
            self.__inference_configuration.to_api_call_parameters(
                client_mode=self.__client_mode,
                task_type=model_description.task_type,
            )
        )
        results = []
        for element in encoded_inference_inputs:
            image, scaling_factor = element
            payload["image"] = {"type": "base64", "value": image}
            response = requests.post(
                f"{self.__api_url}{endpoint}",
                json=payload,
                headers=DEFAULT_HEADERS,
            )
            response.raise_for_status()
            parsed_response = response.json()
            if parsed_response.get("visualization") is not None:
                parsed_response["visualization"] = transform_base64_visualisation(
                    visualisation=parsed_response["visualization"],
                    expected_format=self.__inference_configuration.output_visualisation_format,
                )
            parsed_response = adjust_prediction_to_client_scaling_factor(
                prediction=parsed_response,
                scaling_factor=scaling_factor,
            )
            results.append(parsed_response)
        return unwrap_single_element_list(sequence=results)

    def get_model_description(
        self, model_id: str, allow_loading: bool = True
    ) -> ModelDescription:
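        """Return the description of ``model_id`` from the server registry,
        loading the model first if it is not yet registered (unless
        ``allow_loading`` is False)."""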
        self.__ensure_v1_client_mode()
        registered_models = self.list_loaded_models()
        matching_models = [
            e for e in registered_models.models if e.model_id == model_id
        ]
        if len(matching_models) > 0:
            return matching_models[0]
        if allow_loading is True:
            self.load_model(model_id=model_id)
            return self.get_model_description(model_id=model_id, allow_loading=False)
        raise ModelNotInitializedError(
            f"Model {model_id} is not initialised, so its description cannot be retrieved."
        )

    @wrap_errors
    def list_loaded_models(self) -> RegisteredModels:
        self.__ensure_v1_client_mode()
        response = requests.get(f"{self.__api_url}/model/registry")
        response.raise_for_status()
        response_payload = response.json()
        return RegisteredModels.from_dict(response_payload)

    @wrap_errors
    def load_model(
        self, model_id: str, set_as_default: bool = False
    ) -> RegisteredModels:
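        """Ask the server to load ``model_id`` into its registry; with
        ``set_as_default=True`` the model also becomes this client's
        selected model."""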
        self.__ensure_v1_client_mode()
        response = requests.post(
            f"{self.__api_url}/model/add",
            json={
                "model_id": model_id,
                "api_key": self.__api_key,
            },
            headers=DEFAULT_HEADERS,
        )
        response.raise_for_status()
        response_payload = response.json()
        if set_as_default:
            self.__selected_model = model_id
        return RegisteredModels.from_dict(response_payload)

    @wrap_errors
    def unload_model(self, model_id: str) -> RegisteredModels:
        self.__ensure_v1_client_mode()
        response = requests.post(
            f"{self.__api_url}/model/remove",
            json={
                "model_id": model_id,
            },
            headers=DEFAULT_HEADERS,
        )
        response.raise_for_status()
        response_payload = response.json()
        if model_id == self.__selected_model:
            self.__selected_model = None
        return RegisteredModels.from_dict(response_payload)

    @wrap_errors
    def unload_all_models(self) -> RegisteredModels:
        self.__ensure_v1_client_mode()
        response = requests.post(f"{self.__api_url}/model/clear")
        response.raise_for_status()
        response_payload = response.json()
        self.__selected_model = None
        return RegisteredModels.from_dict(response_payload)

    def __ensure_v1_client_mode(self) -> None:
        if self.__client_mode is not HTTPClientMode.V1:
            raise WrongClientModeError("Use client mode `v1` to run this operation.")


def _determine_client_downsizing_parameters(
    client_downsizing_disabled: bool,
    model_description: Optional[ModelDescription],
    default_max_input_size: int,
) -> Tuple[Optional[int], Optional[int]]:
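    """Resolve the ``(max_height, max_width)`` the client should downsize
    images to: no limit when downsizing is disabled, the model's declared
    input size when known, and the configured default otherwise."""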
    if client_downsizing_disabled:
        return None, None
    if (
        model_description is None
        or model_description.input_height is None
        or model_description.input_width is None
    ):
        return default_max_input_size, default_max_input_size
    return model_description.input_height, model_description.input_width


def _determine_client_mode(api_url: str) -> HTTPClientMode:
    if "roboflow.com" in api_url:
        return HTTPClientMode.V0
    return HTTPClientMode.V1


def _ensure_model_is_selected(model_id: Optional[str]) -> None:
    if model_id is None:
        raise ModelNotSelectedError("No model was selected to be used.")
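

# Example usage (a minimal sketch; the URL, API key and model id below are
# placeholders, assuming an inference server reachable at localhost:9001):
#
#     client = InferenceHTTPClient(
#         api_url="http://localhost:9001",
#         api_key="<YOUR_API_KEY>",
#     )
#     prediction = client.infer("path/to/image.jpg", model_id="some-project/1")
#     for reference, frame, prediction in client.infer_on_stream(
#         input_uri="path/to/video.mp4", model_id="some-project/1"
#     ):
#         ...  # consume predictions frame by frame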