| import json |
| from typing import List, Tuple |
|
|
| from inference.core.active_learning.entities import ( |
| Prediction, |
| PredictionFileType, |
| PredictionType, |
| SerialisedPrediction, |
| ) |
| from inference.core.constants import ( |
| CLASSIFICATION_TASK, |
| INSTANCE_SEGMENTATION_TASK, |
| OBJECT_DETECTION_TASK, |
| ) |
| from inference.core.exceptions import PredictionFormatNotSupported |
|
|
|
|
def adjust_prediction_to_client_scaling_factor(
    prediction: dict, scaling_factor: float, prediction_type: PredictionType
) -> dict:
    """Undo client-side image scaling on a serialised prediction.

    The ``prediction`` dict is modified in place and the same object is
    returned. A ``scaling_factor`` within 1e-5 of 1.0 is treated as "no
    scaling" and the prediction is returned untouched.

    Args:
        prediction: Serialised prediction. It may hold an ``"image"`` entry
            (with ``"width"`` / ``"height"``) and a ``"predictions"`` list.
        scaling_factor: Factor the image was scaled by; every coordinate and
            dimension is divided by it.
        prediction_type: Task type. Only instance-segmentation and
            object-detection predictions get their boxes/points rescaled.

    Returns:
        The (mutated) ``prediction`` dict.
    """
    if abs(scaling_factor - 1.0) < 1e-5:
        return prediction

    if "image" in prediction:
        scaled_size = prediction["image"]
        # Note: the image entry is rebuilt, keeping only width and height.
        prediction["image"] = {
            "width": round(scaled_size["width"] / scaling_factor),
            "height": round(scaled_size["height"] / scaling_factor),
        }

    skip = predictions_should_not_be_post_processed(
        prediction=prediction, prediction_type=prediction_type
    )
    if skip:
        return prediction

    is_segmentation = prediction_type == INSTANCE_SEGMENTATION_TASK
    is_detection = prediction_type == OBJECT_DETECTION_TASK
    if is_segmentation:
        # Segmentation items carry both a box and a polygon under "points".
        prediction["predictions"] = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key="points",
            )
        )
    if is_detection:
        prediction["predictions"] = (
            adjust_object_detection_predictions_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
            )
        )
    return prediction
|
|
|
|
def predictions_should_not_be_post_processed(
    prediction: dict, prediction_type: PredictionType
) -> bool:
    """Decide whether coordinate post-processing must be skipped.

    Returns True when any of these holds (checked in this order):
      * the prediction has an ``"is_stub"`` key (whatever its value);
      * the prediction has no ``"predictions"`` key;
      * ``CLASSIFICATION_TASK`` is contained in ``prediction_type``;
      * the ``"predictions"`` collection is empty.
    """
    if "is_stub" in prediction:
        return True
    if "predictions" not in prediction:
        return True
    if CLASSIFICATION_TASK in prediction_type:
        return True
    return len(prediction["predictions"]) == 0
|
|
|
|
def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Rescale the bounding box of every object-detection prediction.

    Args:
        predictions: Detection dicts with ``x``, ``y``, ``width``, ``height``.
        scaling_factor: Divisor applied to each box field.

    Returns:
        A new list holding whatever the bbox helper returns for each item.
    """
    return [
        adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        for detection in predictions
    ]
|
|
|
|
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    """Rescale both the bounding box and the point list of each prediction.

    Args:
        predictions: Prediction dicts with box fields (``x``, ``y``,
            ``width``, ``height``) plus a list of ``{"x", "y"}`` point dicts
            stored under ``points_key``.
        scaling_factor: Divisor applied to every coordinate.
        points_key: Name of the key holding the point list.

    Returns:
        A new list with one rescaled prediction per input item.
    """

    def rescale_single(detection: dict) -> dict:
        # Box first, then the points stored on the (possibly new) box dict.
        detection = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        detection[points_key] = adjust_points_coordinates_to_client_scaling_factor(
            points=detection[points_key],
            scaling_factor=scaling_factor,
        )
        return detection

    return [rescale_single(detection) for detection in predictions]
|
|
|
|
def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    """Divide the ``x``, ``y``, ``width`` and ``height`` of ``bbox`` in place.

    Other keys are left untouched. The same dict object is returned.
    """
    for field in ("x", "y", "width", "height"):
        bbox[field] = bbox[field] / scaling_factor
    return bbox
|
|
|
|
def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Divide ``x`` and ``y`` of every point dict by ``scaling_factor``.

    The point dicts are mutated in place; a new list containing those same
    dicts is returned.
    """

    def rescale(point: dict) -> dict:
        point["x"] = point["x"] / scaling_factor
        point["y"] = point["y"] / scaling_factor
        return point

    return [rescale(point) for point in points]
|
|
|
|
def encode_prediction(
    prediction: Prediction,
    prediction_type: PredictionType,
) -> Tuple[SerialisedPrediction, PredictionFileType]:
    """Serialise a prediction and report the file type of the payload.

    Args:
        prediction: The prediction to encode.
        prediction_type: Task type of the prediction.

    Returns:
        ``(json.dumps(prediction), "json")`` when ``CLASSIFICATION_TASK`` is
        not contained in ``prediction_type``; otherwise
        ``(prediction["top"], "txt")``.

    Raises:
        PredictionFormatNotSupported: For classification-type predictions
            that have no ``"top"`` key.
    """
    if CLASSIFICATION_TASK not in prediction_type:
        return json.dumps(prediction), "json"
    if "top" in prediction:
        return prediction["top"], "txt"
    # Plain literal: the message has no placeholders, so the former f-string
    # prefix was redundant (pyflakes/ruff F541). The message text is unchanged.
    raise PredictionFormatNotSupported(
        "Prediction type or prediction format not supported."
    )
|
|