| import time |
| from typing import Dict, Optional |
|
|
| from fastapi import BackgroundTasks |
|
|
| from inference.core import logger |
| from inference.core.active_learning.middlewares import ActiveLearningMiddleware |
| from inference.core.cache.base import BaseCache |
| from inference.core.entities.requests.inference import InferenceRequest |
| from inference.core.entities.responses.inference import InferenceResponse |
| from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT |
| from inference.core.managers.base import ModelManager |
| from inference.core.registries.base import ModelRegistry |
|
|
# Name of the keyword argument read from `**kwargs` of `infer_from_request`;
# registration is attempted only when its value is truthy (missing -> False).
| ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible" |
# Name of the attribute looked up on the request object (via `getattr`, default
# False); a truthy value skips Active Learning registration for that request.
| DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning" |
# Name of the keyword argument that BackgroundTaskActiveLearningManager checks
# for in `**kwargs` to decide whether registration can be deferred.
| BACKGROUND_TASKS_PARAM = "background_tasks" |
|
|
|
|
class ActiveLearningManager(ModelManager):
    """Model manager that registers inference results for Active Learning.

    Inference is delegated to the parent manager. Once a prediction is
    available, eligible requests are handed to a per-model
    ``ActiveLearningMiddleware``, which is created on first use and kept in
    ``self._middlewares`` keyed by model id.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        """Set up the manager.

        Args:
            model_registry: Registry forwarded to the parent ``ModelManager``.
            cache: Cache passed to every middleware this manager creates.
            middlewares: Optional pre-built middlewares keyed by model id;
                a fresh empty dict is used when omitted.
        """
        super().__init__(model_registry=model_registry)
        self._cache = cache
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and, when eligible, register the datapoint.

        Registration happens only if all of the following hold:
        the ``active_learning_eligible`` keyword argument is truthy, the
        request does not carry a truthy ``disable_active_learning`` attribute,
        and ``request.api_key`` is not ``None``.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded unchanged to the parent manager.

        Returns:
            The prediction produced by the parent manager, unmodified.
        """
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Register a prediction for Active Learning on a best-effort basis.

        Any exception raised while initialising the middleware or registering
        the datapoint is logged as a warning and suppressed, so a registration
        failure never propagates to the caller.
        """
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Deliberately broad: registration must not break inference.
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                "Error is suppressed in favour of normal operations of API."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        """Create and store the middleware for ``model_id`` if it is missing.

        Does nothing when a middleware for ``model_id`` is already stored.
        Otherwise builds one with ``ActiveLearningMiddleware.init`` using the
        request's API key and this manager's cache, and logs the
        initialisation latency at debug level.
        """
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Pass the request image(s) and prediction(s) to the model's middleware.

        Logs a warning and returns without registering when the request has no
        ``image`` attribute (or it is ``None``). A single image or a single
        prediction is wrapped into a one-element list; predictions are dumped
        to dicts by alias with the ``visualization`` field excluded.

        Raises:
            KeyError: If no middleware is stored for ``model_id``; call
                ``ensure_middleware_initialised`` first.
        """
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        # A truthy per-request flag or the environment-level setting disables
        # auto-orient preprocessing.
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")
|
|
|
|
class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """Active Learning manager that defers registration to background tasks.

    Behaves like ``ActiveLearningManager`` but, when a ``background_tasks``
    keyword argument is supplied, schedules datapoint registration through it
    instead of running registration inline before returning the prediction.
    """

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and schedule Active Learning registration.

        Registration is attempted only when the ``active_learning_eligible``
        keyword argument is truthy, the request has no truthy
        ``disable_active_learning`` attribute, and ``request.api_key`` is not
        ``None``. If ``background_tasks`` is missing from ``kwargs`` or is
        ``None``, a warning is logged and registration runs sequentially.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded to the parent manager, with the eligibility
                flag forced to ``False``.

        Returns:
            The prediction produced by the parent manager, unmodified.
        """
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        # Force the flag off so the parent class does not register the
        # datapoint inline; this class decides how registration happens.
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        # `.get` treats an explicit `background_tasks=None` like a missing
        # argument, avoiding an AttributeError after inference has completed.
        background_tasks: Optional[BackgroundTasks] = kwargs.get(
            BACKGROUND_TASKS_PARAM
        )
        if background_tasks is None:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not "
                "provided making Active Learning registration running sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction
|
|