|
|
import urllib.parse |
|
|
from enum import Enum |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union |
|
|
|
|
|
import numpy as np |
|
|
import requests |
|
|
from requests import Response |
|
|
from requests_toolbelt import MultipartEncoder |
|
|
|
|
|
from inference.core import logger |
|
|
from inference.core.entities.types import ( |
|
|
DatasetID, |
|
|
ModelType, |
|
|
TaskType, |
|
|
VersionID, |
|
|
WorkspaceID, |
|
|
) |
|
|
from inference.core.env import API_BASE_URL |
|
|
from inference.core.exceptions import ( |
|
|
MalformedRoboflowAPIResponseError, |
|
|
MissingDefaultModelError, |
|
|
RoboflowAPIConnectionError, |
|
|
RoboflowAPIIAlreadyAnnotatedError, |
|
|
RoboflowAPIIAnnotationRejectionError, |
|
|
RoboflowAPIImageUploadRejectionError, |
|
|
RoboflowAPINotAuthorizedError, |
|
|
RoboflowAPINotNotFoundError, |
|
|
RoboflowAPIUnsuccessfulRequestError, |
|
|
WorkspaceLoadError, |
|
|
) |
|
|
from inference.core.utils.requests import api_key_safe_raise_for_status |
|
|
from inference.core.utils.url_utils import wrap_url |
|
|
|
|
|
# Fallback model architecture per project task type; get_roboflow_model_type
# uses it when a dataset version reports no "modelType".
MODEL_TYPE_DEFAULTS = dict(
    [
        ("object-detection", "yolov5v2s"),
        ("instance-segmentation", "yolact"),
        ("classification", "vit"),
    ]
)

# Key names kept as module-level constants (not referenced in this section;
# presumably imported by other modules — verify before removing).
PROJECT_TASK_TYPE_KEY = "project_task_type"
MODEL_TYPE_KEY = "model_type"

# Message used for "not found" style failures (404 by default, and 500 for
# the model-version endpoint).
NOT_FOUND_ERROR_MESSAGE = (
    "Could not find requested Roboflow resource. "
    "Check that the provided dataset and version are correct, "
    "and check that the provided Roboflow API key has the correct permissions."
)
|
|
|
|
|
|
|
|
def raise_from_lambda(
    inner_error: Exception, exception_type: Type[Exception], message: str
) -> None:
    """Raise ``exception_type(message)`` with ``inner_error`` as its cause.

    ``raise`` is a statement and cannot be written inside a ``lambda``; this
    helper lets the lambda-based HTTP error handlers in this module re-raise
    a domain exception while keeping the original error chained.

    Args:
        inner_error: Exception to attach as ``__cause__``.
        exception_type: Exception class to instantiate and raise.
        message: Message passed to ``exception_type``.

    Raises:
        exception_type: Always.
    """
    replacement = exception_type(message)
    raise replacement from inner_error
|
|
|
|
|
|
|
|
# Status-code -> handler fallbacks used by wrap_roboflow_api_errors when the
# decorated function does not override that status code. Each handler receives
# the requests HTTPError and raises a domain exception chained to it.
DEFAULT_ERROR_HANDLERS = {
    401: lambda http_error: raise_from_lambda(
        http_error,
        RoboflowAPINotAuthorizedError,
        "Unauthorized access to roboflow API - check API key.",
    ),
    404: lambda http_error: raise_from_lambda(
        http_error,
        RoboflowAPINotNotFoundError,
        NOT_FOUND_ERROR_MESSAGE,
    ),
}
|
|
|
|
|
|
|
|
def wrap_roboflow_api_errors(
    http_errors_handlers: Optional[
        Dict[int, Callable[[requests.exceptions.HTTPError], None]]
    ] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that translates ``requests`` failures into Roboflow errors.

    Exceptions raised by the decorated function are mapped as follows:

    * ``requests.exceptions.ConnectionError`` or builtin ``ConnectionError`` ->
      ``RoboflowAPIConnectionError``.
    * ``requests.exceptions.HTTPError`` -> dispatched on the response status
      code to a handler from ``http_errors_handlers`` (takes precedence) or
      from ``DEFAULT_ERROR_HANDLERS``. If no handler matches, or the handler
      returns without raising, ``RoboflowAPIUnsuccessfulRequestError`` is
      raised.
    * ``requests.exceptions.InvalidJSONError`` ->
      ``MalformedRoboflowAPIResponseError``.

    Every translated exception is chained (``raise ... from``) to the original.

    Args:
        http_errors_handlers: Optional mapping of HTTP status code to a callable
            that receives the ``HTTPError`` and is expected to raise.

    Returns:
        A decorator applying the error translation above.
    """
    import functools

    # Resolved once per decoration rather than on every failure; the caller's
    # dict is referenced (not copied), as before.
    user_handler_override = (
        http_errors_handlers if http_errors_handlers is not None else {}
    )

    def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
        # Preserve __name__/__doc__/__wrapped__ so logs, docs and introspection
        # show the real API function instead of "wrapper".
        @functools.wraps(function)
        def wrapper(*args, **kwargs) -> Any:
            try:
                return function(*args, **kwargs)
            except (requests.exceptions.ConnectionError, ConnectionError) as error:
                logger.error(f"Could not connect to Roboflow API. Error: {error}")
                raise RoboflowAPIConnectionError(
                    "Could not connect to Roboflow API."
                ) from error
            except requests.exceptions.HTTPError as error:
                logger.error(
                    f"HTTP error encountered while requesting Roboflow API response: {error}"
                )
                # An HTTPError may be raised without an attached response; do not
                # mask the original failure with an AttributeError. Use an
                # explicit None check: a requests Response is falsy for 4xx/5xx.
                response = error.response
                status_code = response.status_code if response is not None else None
                error_handler = user_handler_override.get(
                    status_code, DEFAULT_ERROR_HANDLERS.get(status_code)
                )
                if error_handler is not None:
                    error_handler(error)
                raise RoboflowAPIUnsuccessfulRequestError(
                    f"Unsuccessful request to Roboflow API with response code: {status_code}"
                ) from error
            except requests.exceptions.InvalidJSONError as error:
                logger.error(
                    f"Could not decode JSON response from Roboflow API. Error: {error}."
                )
                raise MalformedRoboflowAPIResponseError(
                    "Could not decode JSON response from Roboflow API."
                ) from error

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_workspace(api_key: str) -> WorkspaceID:
    """Resolve the workspace that ``api_key`` belongs to.

    Queries ``API_BASE_URL`` with the key (and ``nocache=true``) and reads the
    ``workspace`` field of the JSON response.

    Raises:
        WorkspaceLoadError: If the response has no (or a null) ``workspace``.
    """
    query = [("api_key", api_key), ("nocache", "true")]
    key_info = _get_from_url(url=_add_params_to_url(url=API_BASE_URL, params=query))
    workspace = key_info.get("workspace")
    if workspace is not None:
        return workspace
    raise WorkspaceLoadError("Empty workspace encountered, check your API key.")
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_dataset_type(
    api_key: str, workspace_id: WorkspaceID, dataset_id: DatasetID
) -> TaskType:
    """Return the task type of a Roboflow project (dataset).

    Reads ``project.type`` from ``{API_BASE_URL}/{workspace_id}/{dataset_id}``.
    When the field is missing a warning is logged and ``"object-detection"``
    is returned.
    """
    endpoint = f"{API_BASE_URL}/{workspace_id}/{dataset_id}"
    query = [("api_key", api_key), ("nocache", "true")]
    dataset_info = _get_from_url(url=_add_params_to_url(url=endpoint, params=query))
    # Despite the lookup key, this is the whole "project" object, not the type.
    project_info = dataset_info.get("project", {})
    task_type_missing = "type" not in project_info
    if task_type_missing:
        logger.warning(
            f"Project task type not defined for workspace={workspace_id} and "
            f"dataset={dataset_id}, defaulting to object-detection."
        )
    return project_info.get("type", "object-detection")
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors(
    http_errors_handlers={
        # NOTE(review): for this endpoint a 500 response is reported as "not
        # found" instead of a generic failure; presumably the API answers 500
        # for unknown versions — confirm.
        500: lambda e: raise_from_lambda(
            e, RoboflowAPINotNotFoundError, NOT_FOUND_ERROR_MESSAGE
        )
    }
)
def get_roboflow_model_type(
    api_key: str,
    workspace_id: WorkspaceID,
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
) -> ModelType:
    """Return the model architecture of a dataset version.

    Args:
        api_key: Roboflow API key.
        workspace_id: Workspace owning the dataset.
        dataset_id: Dataset (project) identifier.
        version_id: Dataset version identifier.
        project_task_type: Task type of the project (a key of
            ``MODEL_TYPE_DEFAULTS``), used to pick a default model when the
            version reports none.

    Returns:
        ``version.modelType`` from the API response when present, otherwise
        ``MODEL_TYPE_DEFAULTS[project_task_type]``.

    Raises:
        MalformedRoboflowAPIResponseError: If the response has no ``version``.
        MissingDefaultModelError: If ``modelType`` is absent and there is no
            default for ``project_task_type``.
    """
    api_url = _add_params_to_url(
        url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}/{version_id}",
        params=[("api_key", api_key), ("nocache", "true")],
    )
    version_info = _get_from_url(url=api_url)
    # Report a structurally unexpected payload as a malformed response rather
    # than letting a bare KeyError escape the error-translation decorator.
    if "version" not in version_info:
        raise MalformedRoboflowAPIResponseError(
            "Roboflow API response for dataset version does not contain 'version' key."
        )
    version_details = version_info["version"]
    if "modelType" not in version_details:
        if project_task_type not in MODEL_TYPE_DEFAULTS:
            raise MissingDefaultModelError(
                f"Could not set default model for {project_task_type}"
            )
        logger.warning(
            f"Model type not defined - using default for {project_task_type} task."
        )
    return version_details.get("modelType", MODEL_TYPE_DEFAULTS[project_task_type])
|
|
|
|
|
|
|
|
class ModelEndpointType(Enum):
    """Roboflow API endpoint families from which model data can be fetched.

    ``get_roboflow_model_data`` inserts the member's value as the path segment
    right after ``API_BASE_URL``.
    """

    ORT = "ort"
    CORE_MODEL = "core_model"
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_model_data(
    api_key: str,
    model_id: str,
    endpoint_type: ModelEndpointType,
    device_id: str,
) -> dict:
    """Fetch model metadata as decoded JSON.

    Issues ``GET {API_BASE_URL}/{endpoint_type.value}/{model_id}`` with the
    API key, ``nocache=true``, the device id and ``dynamic=true`` as query
    parameters.
    """
    endpoint = f"{API_BASE_URL}/{endpoint_type.value}/{model_id}"
    query = [
        ("api_key", api_key),
        ("nocache", "true"),
        ("device", device_id),
        ("dynamic", "true"),
    ]
    request_url = _add_params_to_url(url=endpoint, params=query)
    return _get_from_url(url=request_url)
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_active_learning_configuration(
    api_key: str,
    workspace_id: WorkspaceID,
    dataset_id: DatasetID,
) -> dict:
    """Return the active learning configuration for a dataset.

    NOTE(review): this returns a hard-coded configuration and does not use
    ``api_key``, ``workspace_id`` or ``dataset_id``; no request is made. It
    looks like a placeholder until a real API call exists — confirm.

    A fresh dict (with fresh nested lists/dicts) is built on every call, so
    callers may mutate the result safely.
    """

    def usage_limits() -> List[dict]:
        # Same limit values for every strategy, but new objects each time so
        # strategies never share mutable state.
        return [
            {"type": "minutely", "value": 10},
            {"type": "hourly", "value": 100},
            {"type": "daily", "value": 1000},
        ]

    random_strategy = {
        "name": "default_strategy",
        "type": "random",
        "traffic_percentage": 0.1,
        "tags": ["random-traffic"],
        "limits": usage_limits(),
    }
    hard_examples_strategy = {
        "name": "hard_examples",
        "type": "close_to_threshold",
        "threshold": 0.3,
        "epsilon": 0.3,
        "probability": 0.3,
        "tags": ["hard-case"],
        "limits": usage_limits(),
    }
    multiple_detections_strategy = {
        "name": "multiple_detections",
        "type": "detections_number_based",
        "probability": 0.2,
        "more_than": 3,
        "tags": ["crowded"],
        "limits": usage_limits(),
    }
    underrepresented_classes_strategy = {
        "name": "underrepresented_classes",
        "type": "classes_based",
        "selected_class_names": ["cat"],
        "probability": 1.0,
        "tags": ["hard-classes"],
        "limits": usage_limits(),
    }
    return {
        "enabled": True,
        "max_image_size": (1200, 1200),
        "jpeg_compression_level": 75,
        "persist_predictions": True,
        "sampling_strategies": [
            random_strategy,
            hard_examples_strategy,
            multiple_detections_strategy,
            underrepresented_classes_strategy,
        ],
        "batching_strategy": {
            "batches_name_prefix": "al_batch",
            "recreation_interval": "daily",
            "max_batch_images": None,
        },
        "tags": ["a", "b"],
    }
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def register_image_at_roboflow(
    api_key: str,
    dataset_id: DatasetID,
    local_image_id: str,
    image_bytes: bytes,
    batch_name: str,
    tags: Optional[List[str]] = None,
) -> dict:
    """Upload an image to ``{API_BASE_URL}/dataset/{dataset_id}/upload``.

    The image is sent as multipart form data, named ``{local_image_id}.jpg``
    with content type ``image/jpeg``. ``batch_name`` and each of ``tags`` are
    passed as query parameters (one ``tag`` entry per tag).

    Returns:
        The decoded JSON response. A response flagged ``duplicate`` is treated
        as accepted.

    Raises:
        RoboflowAPIImageUploadRejectionError: If the response reports neither
            ``duplicate`` nor ``success``.
    """
    query_params = [("api_key", api_key), ("batch", batch_name)]
    if tags is not None:
        query_params.extend(("tag", tag) for tag in tags)
    upload_endpoint = f"{API_BASE_URL}/dataset/{dataset_id}/upload"
    target_url = wrap_url(_add_params_to_url(url=upload_endpoint, params=query_params))
    encoder = MultipartEncoder(
        fields={
            "name": f"{local_image_id}.jpg",
            "file": ("imageToUpload", image_bytes, "image/jpeg"),
        }
    )
    response = requests.post(
        url=target_url,
        data=encoder,
        headers={"Content-Type": encoder.content_type},
    )
    api_key_safe_raise_for_status(response=response)
    payload = response.json()
    accepted = payload.get("duplicate") or payload.get("success")
    if not accepted:
        raise RoboflowAPIImageUploadRejectionError(
            f"Server rejected image: {payload}"
        )
    return payload
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors(
    http_errors_handlers={
        # A 409 from the annotate endpoint is surfaced as "already annotated".
        409: lambda http_error: raise_from_lambda(
            http_error,
            RoboflowAPIIAlreadyAnnotatedError,
            "Given datapoint already has annotation.",
        )
    }
)
def annotate_image_at_roboflow(
    api_key: str,
    dataset_id: DatasetID,
    local_image_id: str,
    roboflow_image_id: str,
    annotation_content: str,
    annotation_file_type: str,
    is_prediction: bool = True,
) -> dict:
    """Attach an annotation to an already uploaded Roboflow image.

    POSTs ``annotation_content`` as ``text/plain`` to
    ``{API_BASE_URL}/dataset/{dataset_id}/annotate/{roboflow_image_id}``,
    naming it ``{local_image_id}.{annotation_file_type}`` and passing
    ``is_prediction`` as a lowercase ``prediction`` query flag.

    Returns:
        The decoded JSON response.

    Raises:
        RoboflowAPIIAnnotationRejectionError: If the response contains an
            ``error`` key or does not report ``success``.
    """
    annotation_name = f"{local_image_id}.{annotation_file_type}"
    query_params = [
        ("api_key", api_key),
        ("name", annotation_name),
        ("prediction", str(is_prediction).lower()),
    ]
    endpoint = f"{API_BASE_URL}/dataset/{dataset_id}/annotate/{roboflow_image_id}"
    target_url = wrap_url(_add_params_to_url(url=endpoint, params=query_params))
    response = requests.post(
        target_url,
        data=annotation_content,
        headers={"Content-Type": "text/plain"},
    )
    api_key_safe_raise_for_status(response=response)
    payload = response.json()
    accepted = "error" not in payload and payload.get("success")
    if not accepted:
        raise RoboflowAPIIAnnotationRejectionError(
            f"Failed to save annotation for {roboflow_image_id}. API response: {payload}"
        )
    return payload
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_labeling_batches(
    api_key: str, workspace_id: WorkspaceID, dataset_id: str
) -> dict:
    """Return the decoded JSON of ``{API_BASE_URL}/{workspace_id}/{dataset_id}/batches``."""
    batches_endpoint = f"{API_BASE_URL}/{workspace_id}/{dataset_id}/batches"
    request_url = _add_params_to_url(
        url=batches_endpoint, params=[("api_key", api_key)]
    )
    return _get_from_url(url=request_url)
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_roboflow_labeling_jobs(
    api_key: str, workspace_id: WorkspaceID, dataset_id: str
) -> dict:
    """Return the decoded JSON of ``{API_BASE_URL}/{workspace_id}/{dataset_id}/jobs``."""
    jobs_endpoint = f"{API_BASE_URL}/{workspace_id}/{dataset_id}/jobs"
    request_url = _add_params_to_url(url=jobs_endpoint, params=[("api_key", api_key)])
    return _get_from_url(url=request_url)
|
|
|
|
|
|
|
|
@wrap_roboflow_api_errors()
def get_from_url(
    url: str,
    json_response: bool = True,
) -> Union[Response, dict]:
    """Public GET helper: ``_get_from_url`` plus Roboflow error translation.

    Args:
        url: Fully built URL (query parameters included).
        json_response: When true return the decoded JSON body, otherwise the
            raw ``Response``.
    """
    result = _get_from_url(url=url, json_response=json_response)
    return result
|
|
|
|
|
|
|
|
def _get_from_url(url: str, json_response: bool = True) -> Union[Response, dict]:
    """GET ``wrap_url(url)`` and fail on unsuccessful status codes.

    Errors are raised via ``api_key_safe_raise_for_status``; this helper does
    not translate them (public callers are wrapped by
    ``wrap_roboflow_api_errors``).

    NOTE(review): no ``timeout`` is passed to ``requests.get``, so a stalled
    server can block indefinitely.

    Returns:
        The decoded JSON body when ``json_response`` is true, otherwise the
        raw ``Response``.
    """
    response = requests.get(wrap_url(url))
    api_key_safe_raise_for_status(response=response)
    return response.json() if json_response else response
|
|
|
|
|
|
|
|
def _add_params_to_url(url: str, params: List[Tuple[str, str]]) -> str: |
|
|
if len(params) == 0: |
|
|
return url |
|
|
params_chunks = [ |
|
|
f"{name}={urllib.parse.quote_plus(value)}" for name, value in params |
|
|
] |
|
|
parameters_string = "&".join(params_chunks) |
|
|
return f"{url}?{parameters_string}" |
|
|
|