# Sankie005's picture
# Upload 435 files
# 143eb4d
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from dataclasses_json import DataClassJsonMixin
from PIL import Image
from inference_sdk.http.errors import ModelTaskTypeNotSupportedError
from inference_sdk.http.utils.iterables import remove_empty_values
# Accepted forms of an input image: a numpy array, a PIL image, or a string
# (presumably a path, URL or encoded payload -- confirm against the loaders).
ImagesReference = Union[np.ndarray, Image.Image, str]
# Default value of InferenceConfiguration.image_extensions_for_directory_scan.
DEFAULT_IMAGE_EXTENSIONS = ["jpg", "jpeg", "JPG", "JPEG", "png", "PNG"]
# Task types are plain strings; the four constants below are the values that
# InferenceConfiguration.to_api_call_parameters dispatches on.
TaskType = str
CLASSIFICATION_TASK = "classification"
OBJECT_DETECTION_TASK = "object-detection"
INSTANCE_SEGMENTATION_TASK = "instance-segmentation"
KEYPOINTS_DETECTION_TASK = "keypoints-detection"
# Default value of InferenceConfiguration.default_max_input_size
# (presumably a pixel dimension -- confirm against the downsizing code).
DEFAULT_MAX_INPUT_SIZE = 1024
@dataclass(frozen=True)
class ServerInfo(DataClassJsonMixin):
    """Immutable record with a server's name, version and uuid.

    Extends ``DataClassJsonMixin`` from ``dataclasses_json``.
    """

    name: str
    version: str
    uuid: str
@dataclass(frozen=True)
class ModelDescription(DataClassJsonMixin):
    """Immutable description of a single model.

    ``model_id`` and ``task_type`` are required; the remaining fields are
    optional and default to ``None``. Extends ``DataClassJsonMixin`` from
    ``dataclasses_json``.
    """

    model_id: str
    task_type: TaskType
    # May be an int or a string; the meaning of string values is not visible
    # in this file -- TODO confirm against the server response schema.
    batch_size: Optional[Union[int, str]] = None
    input_height: Optional[int] = None
    input_width: Optional[int] = None
@dataclass(frozen=True)
class RegisteredModels(DataClassJsonMixin):
    """Immutable container wrapping a list of ``ModelDescription`` entries.

    Extends ``DataClassJsonMixin`` from ``dataclasses_json``.
    """

    models: List[ModelDescription]
class HTTPClientMode(Enum):
    """API flavour the HTTP client talks to.

    ``InferenceConfiguration.to_api_call_parameters`` returns the legacy
    parameter set for ``V0`` and a task-specific parameter set otherwise.
    """

    V0 = "v0"
    V1 = "v1"
class VisualisationResponseFormat:
    """Namespace of string constants naming visualisation output formats.

    This is a plain class, not an ``Enum``: each attribute is an ordinary
    ``str``. ``BASE64`` is the default of
    ``InferenceConfiguration.output_visualisation_format``.
    """

    BASE64 = "base64"
    NUMPY = "numpy"
    PILLOW = "pillow"
@dataclass(frozen=True)
class InferenceConfiguration:
    """Immutable bundle of client-side inference options.

    Every ``Optional`` field defaults to ``None``. The ``to_*_parameters``
    methods translate field names into the parameter names used in API calls
    and pass the result through ``remove_empty_values`` (imported helper;
    presumably it drops unset entries -- confirm against its implementation).
    """

    confidence_threshold: Optional[float] = None
    keypoint_confidence_threshold: Optional[float] = None
    format: Optional[str] = None
    mask_decode_mode: Optional[str] = None
    tradeoff_factor: Optional[float] = None
    max_candidates: Optional[int] = None
    max_detections: Optional[int] = None
    iou_threshold: Optional[float] = None
    stroke_width: Optional[int] = None
    count_inference: Optional[bool] = None
    service_secret: Optional[str] = None
    disable_preproc_auto_orientation: Optional[bool] = None
    disable_preproc_contrast: Optional[bool] = None
    disable_preproc_grayscale: Optional[bool] = None
    disable_preproc_static_crop: Optional[bool] = None
    class_agnostic_nms: Optional[bool] = None
    class_filter: Optional[List[str]] = None
    fix_batch_size: Optional[bool] = None
    visualize_predictions: bool = False
    visualize_labels: Optional[bool] = None
    # VisualisationResponseFormat only holds plain string constants, so the
    # value stored here is a ``str`` (e.g. "base64"), not an instance of it.
    output_visualisation_format: str = VisualisationResponseFormat.BASE64
    # Hand every instance its own copy: returning the module-level list itself
    # would let a mutation through one configuration alter the global default
    # and every other configuration created with the default.
    image_extensions_for_directory_scan: Optional[List[str]] = field(
        default_factory=lambda: list(DEFAULT_IMAGE_EXTENSIONS),
    )
    client_downsizing_disabled: bool = False
    default_max_input_size: int = DEFAULT_MAX_INPUT_SIZE

    @classmethod
    def init_default(cls) -> "InferenceConfiguration":
        """Return a configuration with every field at its default value."""
        return cls()

    def to_api_call_parameters(
        self, client_mode: HTTPClientMode, task_type: TaskType
    ) -> Dict[str, Any]:
        """Build the request parameters for the given client mode and task.

        Args:
            client_mode: ``HTTPClientMode.V0`` selects the legacy parameter
                set regardless of ``task_type``; any other mode selects a
                task-specific set.
            task_type: One of the ``*_TASK`` string constants.

        Returns:
            Mapping of API parameter name to configured value.

        Raises:
            ModelTaskTypeNotSupportedError: If the mode is not ``V0`` and
                ``task_type`` matches none of the known task constants.
        """
        if client_mode is HTTPClientMode.V0:
            return self.to_legacy_call_parameters()
        if task_type == OBJECT_DETECTION_TASK:
            return self.to_object_detection_parameters()
        if task_type == INSTANCE_SEGMENTATION_TASK:
            return self.to_instance_segmentation_parameters()
        if task_type == CLASSIFICATION_TASK:
            return self.to_classification_parameters()
        if task_type == KEYPOINTS_DETECTION_TASK:
            return self.to_keypoints_detection_parameters()
        raise ModelTaskTypeNotSupportedError(
            f"Model task {task_type} is not supported by API v1 client."
        )

    def to_keypoints_detection_parameters(self) -> Dict[str, Any]:
        """Return object-detection parameters plus ``keypoint_confidence``."""
        parameters = self.to_object_detection_parameters()
        parameters["keypoint_confidence"] = self.keypoint_confidence_threshold
        # Filter again: the key added above may hold an unset value.
        return remove_empty_values(dictionary=parameters)

    def to_object_detection_parameters(self) -> Dict[str, Any]:
        """Return the parameters used for object-detection requests."""
        # Pairs of (attribute name on this object, parameter name in the call).
        parameters_specs = [
            ("disable_preproc_auto_orientation", "disable_preproc_auto_orient"),
            ("disable_preproc_contrast", "disable_preproc_contrast"),
            ("disable_preproc_grayscale", "disable_preproc_grayscale"),
            ("disable_preproc_static_crop", "disable_preproc_static_crop"),
            ("class_agnostic_nms", "class_agnostic_nms"),
            ("class_filter", "class_filter"),
            ("confidence_threshold", "confidence"),
            ("fix_batch_size", "fix_batch_size"),
            ("iou_threshold", "iou_threshold"),
            ("max_detections", "max_detections"),
            ("max_candidates", "max_candidates"),
            ("visualize_labels", "visualization_labels"),
            ("stroke_width", "visualization_stroke_width"),
            ("visualize_predictions", "visualize_predictions"),
        ]
        return get_non_empty_attributes(
            source_object=self,
            specification=parameters_specs,
        )

    def to_instance_segmentation_parameters(self) -> Dict[str, Any]:
        """Return object-detection parameters plus mask-related options."""
        parameters = self.to_object_detection_parameters()
        parameters_specs = [
            ("mask_decode_mode", "mask_decode_mode"),
            ("tradeoff_factor", "tradeoff_factor"),
        ]
        for internal_name, external_name in parameters_specs:
            parameters[external_name] = getattr(self, internal_name)
        # Filter again: the keys added above may hold unset values.
        return remove_empty_values(dictionary=parameters)

    def to_classification_parameters(self) -> Dict[str, Any]:
        """Return the parameters used for classification requests."""
        # Pairs of (attribute name on this object, parameter name in the call).
        parameters_specs = [
            ("disable_preproc_auto_orientation", "disable_preproc_auto_orient"),
            ("disable_preproc_contrast", "disable_preproc_contrast"),
            ("disable_preproc_grayscale", "disable_preproc_grayscale"),
            ("disable_preproc_static_crop", "disable_preproc_static_crop"),
            ("confidence_threshold", "confidence"),
            ("visualize_predictions", "visualize_predictions"),
            ("stroke_width", "visualization_stroke_width"),
        ]
        return get_non_empty_attributes(
            source_object=self,
            specification=parameters_specs,
        )

    def to_legacy_call_parameters(self) -> Dict[str, Any]:
        """Return the parameters used when the client mode is ``V0``.

        Several fields are sent under different names than in the
        task-specific sets (e.g. ``iou_threshold`` as ``overlap``,
        ``stroke_width`` as ``stroke``, ``visualize_labels`` as ``labels``).
        """
        # Pairs of (attribute name on this object, parameter name in the call).
        parameters_specs = [
            ("confidence_threshold", "confidence"),
            ("keypoint_confidence_threshold", "keypoint_confidence"),
            ("format", "format"),
            ("visualize_labels", "labels"),
            ("mask_decode_mode", "mask_decode_mode"),
            ("tradeoff_factor", "tradeoff_factor"),
            ("max_detections", "max_detections"),
            ("iou_threshold", "overlap"),
            ("stroke_width", "stroke"),
            ("count_inference", "countinference"),
            ("service_secret", "service_secret"),
            ("disable_preproc_auto_orientation", "disable_preproc_auto_orient"),
            ("disable_preproc_contrast", "disable_preproc_contrast"),
            ("disable_preproc_grayscale", "disable_preproc_grayscale"),
            ("disable_preproc_static_crop", "disable_preproc_static_crop"),
        ]
        return get_non_empty_attributes(
            source_object=self,
            specification=parameters_specs,
        )
def get_non_empty_attributes(
    source_object: object, specification: List[Tuple[str, str]]
) -> Dict[str, Any]:
    """Read selected attributes of an object into a renamed, filtered dict.

    Args:
        source_object: Object whose attributes are read with ``getattr``.
        specification: Pairs of ``(attribute name, output key)``.

    Returns:
        The collected ``{output key: attribute value}`` mapping after it has
        been passed through ``remove_empty_values``.
    """
    collected: Dict[str, Any] = {}
    for attribute_name, output_key in specification:
        collected[output_key] = getattr(source_object, attribute_name)
    return remove_empty_values(dictionary=collected)