# File: Docker_ml/inference_sdk/http/utils/post_processing.py
# Repository page residue: Sankie005's picture - "Upload 434 files" - c446951
import base64
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from requests import Response
from inference_sdk.http.entities import VisualisationResponseFormat
from inference_sdk.http.utils.encoding import (
bytes_to_opencv_image,
bytes_to_pillow_image,
encode_base_64,
)
# Header-name spellings tried, in this order, when looking up the content type
# of a response (see response_contains_jpeg_image).
CONTENT_TYPE_HEADERS = ["content-type", "Content-Type"]
# Dispatch table used by transform_visualisation_bytes: each supported
# VisualisationResponseFormat maps to the callable that is applied to the raw
# visualisation bytes. Formats missing from this table are rejected there with
# NotImplementedError.
IMAGES_TRANSCODING_METHODS = {
VisualisationResponseFormat.BASE64: encode_base_64,
VisualisationResponseFormat.NUMPY: bytes_to_opencv_image,
VisualisationResponseFormat.PILLOW: bytes_to_pillow_image,
}
def response_contains_jpeg_image(response: Response) -> bool:
    """Tell whether *response* declares a JPEG payload in its content type.

    The header is looked up under each spelling listed in
    ``CONTENT_TYPE_HEADERS``; the first spelling that is present wins.

    Args:
        response: HTTP response whose headers are inspected.

    Returns:
        ``True`` when a content-type header is present and its value mentions
        ``image/jpeg`` (in any letter case), ``False`` otherwise - including
        when no content-type header is present at all.
    """
    content_type = next(
        (
            response.headers[header_name]
            for header_name in CONTENT_TYPE_HEADERS
            if header_name in response.headers
        ),
        None,
    )
    if content_type is None:
        return False
    # Media type names are case-insensitive, so "Image/JPEG" must match too.
    return "image/jpeg" in content_type.lower()
def transform_base64_visualisation(
    visualisation: str,
    expected_format: VisualisationResponseFormat,
) -> Union[str, np.ndarray, Image.Image]:
    """Decode a base64-encoded visualisation and transcode it.

    Args:
        visualisation: Base64 text holding the visualisation bytes.
        expected_format: Format the caller wants the visualisation in.

    Returns:
        Whatever ``transform_visualisation_bytes`` produces for the decoded
        bytes and ``expected_format``.
    """
    return transform_visualisation_bytes(
        visualisation=base64.b64decode(visualisation),
        expected_format=expected_format,
    )
def transform_visualisation_bytes(
    visualisation: bytes,
    expected_format: VisualisationResponseFormat,
) -> Union[str, np.ndarray, Image.Image]:
    """Transcode raw visualisation bytes into the requested format.

    Args:
        visualisation: Raw bytes of the visualisation.
        expected_format: Key selecting the converter registered in
            ``IMAGES_TRANSCODING_METHODS``.

    Returns:
        The result of the converter registered for ``expected_format``.

    Raises:
        NotImplementedError: If no converter is registered for
            ``expected_format``.
    """
    transcode = IMAGES_TRANSCODING_METHODS.get(expected_format)
    if transcode is None:
        raise NotImplementedError(
            f"Expected format: {expected_format} is not supported in terms of visualisations transcoding."
        )
    return transcode(visualisation)
def adjust_prediction_to_client_scaling_factor(
    prediction: dict,
    scaling_factor: Optional[float],
) -> dict:
    """Divide the sizes and coordinates stored in *prediction* by *scaling_factor*.

    The dictionary is modified in place and also returned. Nothing is touched
    when ``scaling_factor`` is ``None`` or the prediction is flagged with a
    truthy ``"is_stub"`` entry.

    Args:
        prediction: Response dictionary. An ``"image"`` entry, when present,
            is replaced by a dictionary holding only the rescaled, rounded
            ``"width"`` and ``"height"``. A non-empty list under
            ``"predictions"`` is rescaled according to the keys found in its
            first element.
        scaling_factor: Divisor applied to every size and coordinate, or
            ``None`` to leave the prediction as it is.

    Returns:
        The same ``prediction`` object that was passed in.
    """
    skip_rescaling = scaling_factor is None or prediction.get("is_stub", False)
    if skip_rescaling:
        return prediction
    if "image" in prediction:
        image_size = prediction["image"]
        prediction["image"] = {
            dimension: round(image_size[dimension] / scaling_factor)
            for dimension in ("width", "height")
        }
    if predictions_should_not_be_post_processed(prediction=prediction):
        return prediction
    # The first element decides how the whole list is treated.
    first_element = prediction["predictions"][0]
    points_key = next(
        (key for key in ("points", "keypoints") if key in first_element),
        None,
    )
    if points_key is not None:
        prediction["predictions"] = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key=points_key,
            )
        )
    elif "x" in first_element and "y" in first_element:
        prediction["predictions"] = (
            adjust_object_detection_predictions_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
            )
        )
    return prediction
def predictions_should_not_be_post_processed(prediction: dict) -> bool:
    """Tell whether *prediction* carries nothing that can be rescaled.

    Args:
        prediction: Response dictionary to inspect.

    Returns:
        ``True`` when ``"predictions"`` is missing, is not a list, or is an
        empty list; ``False`` when it is a non-empty list.
    """
    # Excluding classification output and empty predictions from post-processing.
    if "predictions" not in prediction:
        return True
    candidates = prediction["predictions"]
    if not isinstance(candidates, list):
        return True
    return not candidates
def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Rescale the bounding box of every element of *predictions*.

    Args:
        predictions: Dictionaries handed one by one to
            ``adjust_bbox_coordinates_to_client_scaling_factor``.
        scaling_factor: Divisor forwarded to that helper.

    Returns:
        A new list holding the helper's result for each element, in order.
    """
    return [
        adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        for detection in predictions
    ]
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    """Rescale the bounding box and the point list of every prediction.

    Args:
        predictions: Dictionaries that each hold bounding-box fields and a
            list of points under ``points_key``.
        scaling_factor: Divisor applied to all coordinates.
        points_key: Name of the entry holding the points of a prediction.

    Returns:
        A new list with the rescaled predictions, in the original order.
    """

    def _rescale(detection: dict) -> dict:
        # Bounding box first, then the points stored under ``points_key``.
        rescaled = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        rescaled[points_key] = adjust_points_coordinates_to_client_scaling_factor(
            points=rescaled[points_key],
            scaling_factor=scaling_factor,
        )
        return rescaled

    return [_rescale(detection) for detection in predictions]
def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    """Divide the ``x``, ``y``, ``width`` and ``height`` of *bbox* in place.

    Args:
        bbox: Dictionary holding the four fields; other entries are left
            untouched.
        scaling_factor: Divisor applied to each of the four fields.

    Returns:
        The same ``bbox`` object, after modification.
    """
    for field in ("x", "y", "width", "height"):
        bbox[field] = bbox[field] / scaling_factor
    return bbox
def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Divide the ``x`` and ``y`` of every point by *scaling_factor*.

    Each point dictionary is modified in place; other entries of a point are
    left untouched.

    Args:
        points: Dictionaries that each hold ``"x"`` and ``"y"``.
        scaling_factor: Divisor applied to both coordinates.

    Returns:
        A new list containing the same (now modified) point dictionaries.
    """

    def _rescale(point: dict) -> dict:
        for axis in ("x", "y"):
            point[axis] = point[axis] / scaling_factor
        return point

    return [_rescale(point) for point in points]