| import os |
| from typing import Optional, Tuple, Union |
|
|
| from inference.core.cache import cache |
| from inference.core.devices.utils import GLOBAL_DEVICE_ID |
| from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID |
| from inference.core.env import LAMBDA, MODEL_CACHE_DIR |
| from inference.core.exceptions import ( |
| MissingApiKeyError, |
| ModelArtefactError, |
| ModelNotRecognisedError, |
| ) |
| from inference.core.logger import logger |
| from inference.core.models.base import Model |
| from inference.core.registries.base import ModelRegistry |
| from inference.core.roboflow_api import ( |
| MODEL_TYPE_DEFAULTS, |
| MODEL_TYPE_KEY, |
| PROJECT_TASK_TYPE_KEY, |
| ModelEndpointType, |
| get_roboflow_dataset_type, |
| get_roboflow_model_data, |
| get_roboflow_workspace, |
| ) |
| from inference.core.utils.file_system import dump_json, read_json |
| from inference.core.utils.roboflow import get_model_id_chunks |
| from inference.models.aliases import resolve_roboflow_model_alias |
|
|
# Models that are not tied to a Roboflow project. Keyed by the dataset part of
# the model ID; each value is the (task type, model type) pair that
# get_model_type returns for that ID directly, without consulting the local
# cache or the Roboflow API.
GENERIC_MODELS = dict(
    clip=("embed", "clip"),
    sam=("embed", "sam"),
    gaze=("gaze", "l2cs"),
    doctr=("ocr", "doctr"),
    grounding_dino=("object-detection", "grounding-dino"),
    cogvlm=("llm", "cogvlm"),
    yolo_world=("object-detection", "yolo-world"),
)


# Version ID that marks a "stub" model: get_model_type requires an API key for
# it, reads the task type from the dataset and reports the model type "stub".
STUB_VERSION_ID = "0"

# Passed as ``expire=`` to ``cache.lock`` around metadata cache reads/writes
# (presumably seconds — confirm against the cache implementation).
CACHE_METADATA_LOCK_TIMEOUT = 1.0
|
|
|
|
class RoboflowModelRegistry(ModelRegistry):
    """Model registry for Roboflow model IDs.

    A model ID is first resolved to its ``(task_type, model_type)`` pair via
    ``get_model_type``; that pair is then used as the key into
    ``self.registry_dict`` to find the model class to return.
    """

    def get_model(self, model_id: str, api_key: str) -> Model:
        """Look up the model class registered for ``model_id``.

        Args:
            model_id (str): The ID of the model to be retrieved.
            api_key (str): The API key, forwarded to ``get_model_type``.

        Returns:
            Model: The entry stored in ``self.registry_dict`` under the
            model's ``(task_type, model_type)`` key.

        Raises:
            ModelNotRecognisedError: If no entry is registered for that key.
        """
        model_key = get_model_type(model_id, api_key)
        if model_key in self.registry_dict:
            return self.registry_dict[model_key]
        raise ModelNotRecognisedError(f"Model type not supported: {model_key}")
|
|
|
|
def get_model_type(
    model_id: str,
    api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Retrieves the task type and model type for the given model ID.

    Lookup order:
        1. The ID is passed through ``resolve_roboflow_model_alias``.
        2. Generic models listed in ``GENERIC_MODELS`` are answered from that
           table, without touching the cache or the Roboflow API.
        3. Metadata already stored in the local model cache is returned.
        4. For the stub version (``STUB_VERSION_ID``) the task type is fetched
           from the Roboflow workspace/dataset endpoints and the model type is
           ``"stub"``.
        5. Otherwise the ORT model endpoint of the Roboflow API is queried.
    Results obtained in steps 4 and 5 are written to the local cache.

    Args:
        model_id (str): The ID of the model (aliases are accepted).
        api_key (Optional[str]): The API key used to authenticate. Mandatory
            for stub models; otherwise forwarded to the Roboflow API as given
            (including None).

    Returns:
        Tuple[TaskType, ModelType]: The project task type and the model type.

    Raises:
        MissingApiKeyError: If the stub model version is requested without an
            API key.
        ModelArtefactError: If the API response has no ``"ort"`` section, or if
            neither the API nor ``MODEL_TYPE_DEFAULTS`` yields a model type.

        Not raised here directly, but may propagate from the Roboflow API
        helpers:
        WorkspaceLoadError: If the workspace could not be loaded or if the API key is invalid.
        DatasetLoadError: If the dataset could not be loaded due to invalid ID, workspace ID or version ID.
        MissingDefaultModelError: If default model is not configured and API does not provide this info.
        MalformedRoboflowAPIResponseError: Roboflow API responds in invalid format.
    """
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]
    cached_metadata = get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )
    if cached_metadata is not None:
        return cached_metadata[0], cached_metadata[1]
    if version_id == STUB_VERSION_ID:
        # Resolving a stub's task type goes through the workspace/dataset API,
        # so an API key is required up front.
        if api_key is None:
            raise MissingApiKeyError(
                "Stub model version provided but no API key was provided. API key is required to load stub models."
            )
        workspace_id = get_roboflow_workspace(api_key=api_key)
        project_task_type = get_roboflow_dataset_type(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        model_type = "stub"
        save_model_metadata_in_cache(
            dataset_id=dataset_id,
            version_id=version_id,
            project_task_type=project_task_type,
            model_type=model_type,
        )
        return project_task_type, model_type
    # Only the ORT section of the model-data response is used here.
    api_data = get_roboflow_model_data(
        api_key=api_key,
        model_id=model_id,
        endpoint_type=ModelEndpointType.ORT,
        device_id=GLOBAL_DEVICE_ID,
    ).get("ort")
    if api_data is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    # The task type defaults to object detection when the API omits it.
    project_task_type = api_data.get("type", "object-detection")
    model_type = api_data.get("modelType")
    if model_type is None or model_type == "ort":
        # A missing modelType, or one equal to the endpoint name "ort", carries
        # no architecture information (presumably legacy records), so fall back
        # to the configured default model type for this task.
        model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)
    if model_type is None or project_task_type is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )
    return project_task_type, model_type
|
|
|
|
def get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Look up cached ``(task_type, model_type)`` metadata for a model version.

    Outside Lambda the read happens while holding ``cache.lock`` on the
    per-model key ``lock:metadata:<dataset_id>:<version_id>`` with
    ``CACHE_METADATA_LOCK_TIMEOUT`` as its expiry. On Lambda the file is read
    without locking.

    Returns:
        The cached pair, or None when nothing usable is cached.
    """
    if not LAMBDA:
        lock_name = f"lock:metadata:{dataset_id}:{version_id}"
        with cache.lock(lock_name, expire=CACHE_METADATA_LOCK_TIMEOUT):
            return _get_model_metadata_from_cache(
                dataset_id=dataset_id, version_id=version_id
            )
    return _get_model_metadata_from_cache(dataset_id=dataset_id, version_id=version_id)
|
|
|
|
def _get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Read cached ``(task_type, model_type)`` metadata for a model version.

    Performs no locking itself. The cache is best-effort: None is returned
    (so callers such as ``get_model_type`` fall back to the Roboflow API) when
    the cache file is absent, cannot be read or decoded, or its content fails
    ``model_metadata_content_is_invalid``.

    Args:
        dataset_id (str): Dataset part of the model ID.
        version_id (str): Version part of the model ID.

    Returns:
        Optional[Tuple[TaskType, ModelType]]: The cached pair, or None.
    """
    model_type_cache_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    if not os.path.isfile(model_type_cache_path):
        return None
    try:
        model_metadata = read_json(path=model_type_cache_path)
    except ValueError as e:
        # Covers JSON decoding errors (json.JSONDecodeError subclasses ValueError).
        logger.warning(
            f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}."
        )
        return None
    except OSError as e:
        # The file may disappear or become unreadable between the isfile()
        # check and the read (on Lambda this runs without the cache lock);
        # treat it as a cache miss instead of failing the whole lookup.
        logger.warning(
            f"Could not load model description from cache under path: {model_type_cache_path} - read issue: {e}."
        )
        return None
    if model_metadata_content_is_invalid(content=model_metadata):
        return None
    return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY]
|
|
|
|
def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
    """Check whether cached model metadata is unusable, logging a warning if so.

    Args:
        content: Parsed content of a metadata cache file (as returned by
            ``read_json`` in ``_get_model_metadata_from_cache``).

    Returns:
        bool: True if ``content`` is None, is not a dict, or lacks either
        ``PROJECT_TASK_TYPE_KEY`` or ``MODEL_TYPE_KEY``; False otherwise.
    """
    if content is None:
        logger.warning("Empty model metadata file encountered in cache.")
        return True
    if not isinstance(content, dict):
        logger.warning("Malformed file encountered in cache.")
        return True
    if PROJECT_TASK_TYPE_KEY not in content or MODEL_TYPE_KEY not in content:
        logger.warning(
            f"Could not find one of required keys {PROJECT_TASK_TYPE_KEY} or {MODEL_TYPE_KEY} in cache."
        )
        return True
    return False
|
|
|
|
def save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Persist ``(project_task_type, model_type)`` for a model version.

    Outside Lambda the write is performed while holding ``cache.lock`` on
    ``lock:metadata:<dataset_id>:<version_id>``. This is the same key
    ``get_model_metadata_from_cache`` locks on, presumably so that reads and
    writes of one model's metadata file do not interleave. On Lambda the file
    is written without locking.
    """
    write_kwargs = dict(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )
    if not LAMBDA:
        lock_name = f"lock:metadata:{dataset_id}:{version_id}"
        with cache.lock(lock_name, expire=CACHE_METADATA_LOCK_TIMEOUT):
            _save_model_metadata_in_cache(**write_kwargs)
        return None
    _save_model_metadata_in_cache(**write_kwargs)
    return None
|
|
|
|
def _save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Write the metadata JSON file for a model version (no locking).

    The object written has two keys, ``PROJECT_TASK_TYPE_KEY`` and
    ``MODEL_TYPE_KEY``. It is dumped with ``indent=4`` to the path from
    ``construct_model_type_cache_path``, with ``allow_override=True``
    (presumably replacing any existing file).
    """
    dump_json(
        path=construct_model_type_cache_path(
            dataset_id=dataset_id, version_id=version_id
        ),
        content={
            PROJECT_TASK_TYPE_KEY: project_task_type,
            MODEL_TYPE_KEY: model_type,
        },
        allow_override=True,
        indent=4,
    )
|
|
|
|
def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
    """Return the model-type cache file path for a model version.

    The layout is ``<MODEL_CACHE_DIR>/<dataset_id>/<version_id>/model_type.json``.
    """
    return os.path.join(MODEL_CACHE_DIR, dataset_id, version_id, "model_type.json")
|
|