from contextlib import contextmanager
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import numpy as np
import requests
from requests import HTTPError
from requests.exceptions import ConnectionError as RequestsConnectionError
from inference_sdk.http.entities import (
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
HTTPClientMode,
ImagesReference,
InferenceConfiguration,
ModelDescription,
RegisteredModels,
ServerInfo,
)
from inference_sdk.http.errors import (
HTTPCallErrorError,
HTTPClientError,
InvalidModelIdentifier,
ModelNotInitializedError,
ModelNotSelectedError,
ModelTaskTypeNotSupportedError,
WrongClientModeError,
)
from inference_sdk.http.utils.iterables import unwrap_single_element_list
from inference_sdk.http.utils.loaders import (
load_static_inference_input,
load_stream_inference_input,
)
from inference_sdk.http.utils.post_processing import (
adjust_prediction_to_client_scaling_factor,
response_contains_jpeg_image,
transform_base64_visualisation,
transform_visualisation_bytes,
)
from inference_sdk.http.utils.requests import api_key_safe_raise_for_status
SUCCESSFUL_STATUS_CODE = 200
DEFAULT_HEADERS = {
"Content-Type": "application/json",
}
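# Task-type specific inference endpoints exposed by the v1 (self-hosted) API.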
NEW_INFERENCE_ENDPOINTS = {
INSTANCE_SEGMENTATION_TASK: "/infer/instance_segmentation",
OBJECT_DETECTION_TASK: "/infer/object_detection",
CLASSIFICATION_TASK: "/infer/classification",
KEYPOINTS_DETECTION_TASK: "/infer/keypoints_detection",
}
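# Decorator translating `requests` HTTP errors and connection failures into the
# SDK-specific HTTPCallErrorError / HTTPClientError exceptions.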
def wrap_errors(function: Callable) -> Callable:
def decorate(*args, **kwargs) -> Any:
try:
return function(*args, **kwargs)
except HTTPError as error:
if "application/json" in error.response.headers.get("Content-Type", ""):
api_message = error.response.json().get("message")
else:
api_message = error.response.text
raise HTTPCallErrorError(
description=str(error),
status_code=error.response.status_code,
api_message=api_message,
) from error
        except (ConnectionError, RequestsConnectionError) as error:
raise HTTPClientError(
f"Error with server connection: {str(error)}"
) from error
return decorate
class InferenceHTTPClient:
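    """HTTP client for a Roboflow inference API.

    The client operates in one of two modes: ``v0`` (the hosted API, selected
    automatically when ``roboflow.com`` appears in ``api_url``) and ``v1``
    (a self-hosted inference server exposing the ``/infer/*`` endpoints).
    Minimal usage sketch (the URL and model id below are illustrative
    placeholders):

        client = InferenceHTTPClient(api_url="http://127.0.0.1:9001", api_key="<API_KEY>")
        prediction = client.infer("path/to/image.jpg", model_id="project-name/1")
    """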
def __init__(
self,
api_url: str,
api_key: str,
):
self.__api_url = api_url
self.__api_key = api_key
self.__inference_configuration = InferenceConfiguration.init_default()
self.__client_mode = _determine_client_mode(api_url=api_url)
self.__selected_model: Optional[str] = None
@property
def inference_configuration(self) -> InferenceConfiguration:
return self.__inference_configuration
@property
def client_mode(self) -> HTTPClientMode:
return self.__client_mode
@property
def selected_model(self) -> Optional[str]:
return self.__selected_model
@contextmanager
def use_configuration(
self, inference_configuration: InferenceConfiguration
) -> Generator["InferenceHTTPClient", None, None]:
previous_configuration = self.__inference_configuration
self.__inference_configuration = inference_configuration
try:
yield self
finally:
self.__inference_configuration = previous_configuration
def configure(
self, inference_configuration: InferenceConfiguration
) -> "InferenceHTTPClient":
self.__inference_configuration = inference_configuration
return self
def select_api_v0(self) -> "InferenceHTTPClient":
self.__client_mode = HTTPClientMode.V0
return self
def select_api_v1(self) -> "InferenceHTTPClient":
self.__client_mode = HTTPClientMode.V1
return self
@contextmanager
def use_api_v0(self) -> Generator["InferenceHTTPClient", None, None]:
previous_client_mode = self.__client_mode
self.__client_mode = HTTPClientMode.V0
try:
yield self
finally:
self.__client_mode = previous_client_mode
@contextmanager
def use_api_v1(self) -> Generator["InferenceHTTPClient", None, None]:
previous_client_mode = self.__client_mode
self.__client_mode = HTTPClientMode.V1
try:
yield self
finally:
self.__client_mode = previous_client_mode
def select_model(self, model_id: str) -> "InferenceHTTPClient":
self.__selected_model = model_id
return self
@contextmanager
def use_model(self, model_id: str) -> Generator["InferenceHTTPClient", None, None]:
previous_model = self.__selected_model
self.__selected_model = model_id
try:
yield self
finally:
self.__selected_model = previous_model
@wrap_errors
def get_server_info(self) -> ServerInfo:
response = requests.get(f"{self.__api_url}/info")
response.raise_for_status()
response_payload = response.json()
return ServerInfo.from_dict(response_payload)
def infer_on_stream(
self,
input_uri: str,
model_id: Optional[str] = None,
) -> Generator[Tuple[Union[str, int], np.ndarray, dict], None, None]:
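        """Run inference sequentially on the frames/images yielded from ``input_uri``
        (e.g. a video source or a directory of images), producing
        ``(reference, frame, prediction)`` tuples."""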
for reference, frame in load_stream_inference_input(
input_uri=input_uri,
image_extensions=self.__inference_configuration.image_extensions_for_directory_scan,
):
prediction = self.infer(
inference_input=frame,
model_id=model_id,
)
yield reference, frame, prediction
@wrap_errors
def infer(
self,
inference_input: Union[ImagesReference, List[ImagesReference]],
model_id: Optional[str] = None,
) -> Union[dict, List[dict]]:
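        """Run inference against one image or a list of images, dispatching to the
        v0 or v1 API according to the current client mode."""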
if self.__client_mode is HTTPClientMode.V0:
return self.infer_from_api_v0(
inference_input=inference_input,
model_id=model_id,
)
return self.infer_from_api_v1(
inference_input=inference_input,
model_id=model_id,
)
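    # Legacy (v0 / hosted API) path: the API key travels as a query parameter and
    # each encoded image is POSTed in the request body to an endpoint built from
    # the two-part model identifier.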
def infer_from_api_v0(
self,
inference_input: Union[ImagesReference, List[ImagesReference]],
model_id: Optional[str] = None,
) -> Union[dict, List[dict]]:
model_id_to_be_used = model_id or self.__selected_model
_ensure_model_is_selected(model_id=model_id_to_be_used)
model_id_chunks = model_id_to_be_used.split("/")
if len(model_id_chunks) != 2:
raise InvalidModelIdentifier(
f"Invalid model identifier: {model_id} in use."
)
max_height, max_width = _determine_client_downsizing_parameters(
client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled,
model_description=None,
default_max_input_size=self.__inference_configuration.default_max_input_size,
)
encoded_inference_inputs = load_static_inference_input(
inference_input=inference_input,
max_height=max_height,
max_width=max_width,
)
params = {
"api_key": self.__api_key,
}
params.update(self.__inference_configuration.to_legacy_call_parameters())
results = []
for element in encoded_inference_inputs:
image, scaling_factor = element
response = requests.post(
f"{self.__api_url}/{model_id_chunks[0]}/{model_id_chunks[1]}",
headers=DEFAULT_HEADERS,
params=params,
data=image,
)
api_key_safe_raise_for_status(response=response)
if response_contains_jpeg_image(response=response):
visualisation = transform_visualisation_bytes(
visualisation=response.content,
expected_format=self.__inference_configuration.output_visualisation_format,
)
parsed_response = {"visualization": visualisation}
else:
parsed_response = response.json()
parsed_response = adjust_prediction_to_client_scaling_factor(
prediction=parsed_response,
scaling_factor=scaling_factor,
)
results.append(parsed_response)
return unwrap_single_element_list(sequence=results)
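    # v1 (self-hosted server) path: the model description is fetched first so that
    # the request can be routed to the task-specific /infer/* endpoint and images
    # can be downsized to the model's declared input resolution.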
def infer_from_api_v1(
self,
inference_input: Union[ImagesReference, List[ImagesReference]],
model_id: Optional[str] = None,
) -> Union[dict, List[dict]]:
self.__ensure_v1_client_mode()
model_id_to_be_used = model_id or self.__selected_model
_ensure_model_is_selected(model_id=model_id_to_be_used)
model_description = self.get_model_description(model_id=model_id_to_be_used)
max_height, max_width = _determine_client_downsizing_parameters(
client_downsizing_disabled=self.__inference_configuration.client_downsizing_disabled,
model_description=model_description,
default_max_input_size=self.__inference_configuration.default_max_input_size,
)
if model_description.task_type not in NEW_INFERENCE_ENDPOINTS:
raise ModelTaskTypeNotSupportedError(
f"Model task {model_description.task_type} is not supported by API v1 client."
)
encoded_inference_inputs = load_static_inference_input(
inference_input=inference_input,
max_height=max_height,
max_width=max_width,
)
payload = {
"api_key": self.__api_key,
"model_id": model_id_to_be_used,
}
endpoint = NEW_INFERENCE_ENDPOINTS[model_description.task_type]
payload.update(
self.__inference_configuration.to_api_call_parameters(
client_mode=self.__client_mode,
task_type=model_description.task_type,
)
)
results = []
for element in encoded_inference_inputs:
image, scaling_factor = element
payload["image"] = {"type": "base64", "value": image}
response = requests.post(
f"{self.__api_url}{endpoint}",
json=payload,
headers=DEFAULT_HEADERS,
)
response.raise_for_status()
parsed_response = response.json()
if parsed_response.get("visualization") is not None:
parsed_response["visualization"] = transform_base64_visualisation(
visualisation=parsed_response["visualization"],
expected_format=self.__inference_configuration.output_visualisation_format,
)
parsed_response = adjust_prediction_to_client_scaling_factor(
prediction=parsed_response,
scaling_factor=scaling_factor,
)
results.append(parsed_response)
return unwrap_single_element_list(sequence=results)
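    # Looks the model up among the models already loaded on the server; if absent
    # and allow_loading is True, loads it once and retries the lookup.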
def get_model_description(
self, model_id: str, allow_loading: bool = True
) -> ModelDescription:
self.__ensure_v1_client_mode()
registered_models = self.list_loaded_models()
matching_models = [
e for e in registered_models.models if e.model_id == model_id
]
if len(matching_models) > 0:
return matching_models[0]
if allow_loading is True:
self.load_model(model_id=model_id)
return self.get_model_description(model_id=model_id, allow_loading=False)
raise ModelNotInitializedError(
f"Model {model_id} is not initialised and cannot retrieve its description."
)
@wrap_errors
def list_loaded_models(self) -> RegisteredModels:
self.__ensure_v1_client_mode()
response = requests.get(f"{self.__api_url}/model/registry")
response.raise_for_status()
response_payload = response.json()
return RegisteredModels.from_dict(response_payload)
@wrap_errors
def load_model(
self, model_id: str, set_as_default: bool = False
) -> RegisteredModels:
self.__ensure_v1_client_mode()
response = requests.post(
f"{self.__api_url}/model/add",
json={
"model_id": model_id,
"api_key": self.__api_key,
},
headers=DEFAULT_HEADERS,
)
response.raise_for_status()
response_payload = response.json()
if set_as_default:
self.__selected_model = model_id
return RegisteredModels.from_dict(response_payload)
@wrap_errors
def unload_model(self, model_id: str) -> RegisteredModels:
self.__ensure_v1_client_mode()
response = requests.post(
f"{self.__api_url}/model/remove",
json={
"model_id": model_id,
},
headers=DEFAULT_HEADERS,
)
response.raise_for_status()
response_payload = response.json()
if model_id == self.__selected_model:
self.__selected_model = None
return RegisteredModels.from_dict(response_payload)
@wrap_errors
def unload_all_models(self) -> RegisteredModels:
self.__ensure_v1_client_mode()
response = requests.post(f"{self.__api_url}/model/clear")
response.raise_for_status()
response_payload = response.json()
self.__selected_model = None
return RegisteredModels.from_dict(response_payload)
def __ensure_v1_client_mode(self) -> None:
if self.__client_mode is not HTTPClientMode.V1:
raise WrongClientModeError("Use client mode `v1` to run this operation.")
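# Client-side downsizing is either disabled entirely, bounded by the model's
# declared input size when known, or capped at the configured default otherwise.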
def _determine_client_downsizing_parameters(
client_downsizing_disabled: bool,
model_description: Optional[ModelDescription],
default_max_input_size: int,
) -> Tuple[Optional[int], Optional[int]]:
if client_downsizing_disabled:
return None, None
if (
model_description is None
or model_description.input_height is None
or model_description.input_width is None
):
return default_max_input_size, default_max_input_size
return model_description.input_height, model_description.input_width
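# URLs pointing at roboflow.com are treated as the hosted (v0) API; anything else
# is assumed to be a self-hosted v1 inference server.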
def _determine_client_mode(api_url: str) -> HTTPClientMode:
if "roboflow.com" in api_url:
return HTTPClientMode.V0
return HTTPClientMode.V1
def _ensure_model_is_selected(model_id: Optional[str]) -> None:
if model_id is None:
raise ModelNotSelectedError("No model was selected to be used.")