# NOTE: extraction artifacts (file size, commit hash, line-number gutter) removed.
import base64
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from requests import Response
from inference_sdk.http.entities import VisualisationResponseFormat
from inference_sdk.http.utils.encoding import (
bytes_to_opencv_image,
bytes_to_pillow_image,
encode_base_64,
)
# Header spellings probed, in order, when reading a response's content type.
CONTENT_TYPE_HEADERS = ["content-type", "Content-Type"]
# Maps each supported visualisation response format to the function that
# converts raw image bytes into that representation.
IMAGES_TRANSCODING_METHODS = {
    VisualisationResponseFormat.BASE64: encode_base_64,
    VisualisationResponseFormat.NUMPY: bytes_to_opencv_image,
    VisualisationResponseFormat.PILLOW: bytes_to_pillow_image,
}
def response_contains_jpeg_image(response: Response) -> bool:
    """Tell whether the HTTP response declares a JPEG payload.

    Checks both lowercase and capitalised content-type header spellings;
    returns False when neither header is present.
    """
    content_type = next(
        (
            response.headers[header_name]
            for header_name in CONTENT_TYPE_HEADERS
            if header_name in response.headers
        ),
        None,
    )
    return content_type is not None and "image/jpeg" in content_type
def transform_base64_visualisation(
    visualisation: str,
    expected_format: VisualisationResponseFormat,
) -> Union[str, np.ndarray, Image.Image]:
    """Decode a base64-encoded visualisation and transcode it to the requested format.

    Delegates the actual transcoding to transform_visualisation_bytes().
    """
    raw_bytes = base64.b64decode(visualisation)
    return transform_visualisation_bytes(
        visualisation=raw_bytes, expected_format=expected_format
    )
def transform_visualisation_bytes(
    visualisation: bytes,
    expected_format: VisualisationResponseFormat,
) -> Union[str, np.ndarray, Image.Image]:
    """Transcode raw visualisation bytes into the requested representation.

    Raises:
        NotImplementedError: when no transcoder is registered for
            ``expected_format``.
    """
    transcoding_method = IMAGES_TRANSCODING_METHODS.get(expected_format)
    if transcoding_method is None:
        raise NotImplementedError(
            f"Expected format: {expected_format} is not supported in terms of visualisations transcoding."
        )
    return transcoding_method(visualisation)
def adjust_prediction_to_client_scaling_factor(
    prediction: dict,
    scaling_factor: Optional[float],
) -> dict:
    """Rescale prediction coordinates back into the client's original image space.

    Mutates ``prediction`` in place (image metadata and per-detection
    coordinates) and returns it. Stub predictions and calls made without a
    scaling factor are passed through untouched, as are outputs without a
    non-empty ``predictions`` list (e.g. classification results).
    """
    if scaling_factor is None or prediction.get("is_stub", False):
        return prediction
    if "image" in prediction:
        image_metadata = prediction["image"]
        prediction["image"] = {
            "width": round(image_metadata["width"] / scaling_factor),
            "height": round(image_metadata["height"] / scaling_factor),
        }
    if predictions_should_not_be_post_processed(prediction=prediction):
        return prediction
    # Inspect the first detection to decide which coordinate layout is present.
    first_prediction = prediction["predictions"][0]
    if "points" in first_prediction:
        prediction["predictions"] = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key="points",
            )
        )
    elif "keypoints" in first_prediction:
        prediction["predictions"] = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key="keypoints",
            )
        )
    elif "x" in first_prediction and "y" in first_prediction:
        prediction["predictions"] = (
            adjust_object_detection_predictions_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
            )
        )
    return prediction
def predictions_should_not_be_post_processed(prediction: dict) -> bool:
    """Tell whether coordinate rescaling should be skipped for this output.

    True for classification-style outputs (no ``predictions`` list), for
    outputs whose ``predictions`` is not a list, and for empty detections.
    """
    detections = prediction.get("predictions")
    return not isinstance(detections, list) or len(detections) == 0
def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Rescale bounding boxes of all detections by ``scaling_factor``.

    Each detection dict is mutated in place; a new list holding the same
    dict objects is returned.
    """
    return [
        adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        for detection in predictions
    ]
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    """Rescale bounding boxes and associated point lists of all detections.

    ``points_key`` selects which per-detection list to rescale (e.g.
    ``"points"`` or ``"keypoints"``). Detections are mutated in place; a new
    list of the same dict objects is returned.
    """

    def _rescale(detection: dict) -> dict:
        # Rescale the bbox first, then the point list stored under points_key.
        detection = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        detection[points_key] = adjust_points_coordinates_to_client_scaling_factor(
            points=detection[points_key],
            scaling_factor=scaling_factor,
        )
        return detection

    return [_rescale(detection) for detection in predictions]
def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    """Divide the bbox centre and size by ``scaling_factor``.

    Mutates ``bbox`` in place and returns the same dict.
    """
    for coordinate_key in ("x", "y", "width", "height"):
        bbox[coordinate_key] = bbox[coordinate_key] / scaling_factor
    return bbox
def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Divide each point's ``x`` and ``y`` by ``scaling_factor``.

    Points are mutated in place; a new list containing the same dict
    objects is returned.
    """
    for point in points:
        point["x"] = point["x"] / scaling_factor
        point["y"] = point["y"] / scaling_factor
    return list(points)