import time
from typing import Dict, Optional
from fastapi import BackgroundTasks
from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry
# Name of the kwarg that marks an inference request as eligible for Active Learning registration.
ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
# Name of the kwarg carrying a FastAPI BackgroundTasks instance, when one is provided.
BACKGROUND_TASKS_PARAM = "background_tasks"
class ActiveLearningManager(ModelManager):
    """Model manager that registers predictions for Active Learning.

    Wraps the base ``ModelManager`` inference flow: when a request is marked
    eligible (via the ``active_learning_eligible`` kwarg), the prediction is
    registered against a per-model ``ActiveLearningMiddleware``, which is
    lazily initialised on first use. Registration is best-effort: any error
    is logged and suppressed so inference itself never fails because of it.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        """
        Args:
            model_registry: Registry resolving model ids to model classes.
            cache: Cache backend handed to each newly created middleware.
            middlewares: Optional pre-built middlewares keyed by model id;
                a fresh empty mapping is used when not provided.
        """
        super().__init__(model_registry=model_registry)
        self._cache = cache
        # Avoid a shared mutable default by building the dict per instance.
        self._middlewares = middlewares if middlewares is not None else {}

    def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and, when eligible, register the datapoint synchronously.

        Note: kwargs are deliberately NOT forwarded to ``super()`` — the
        Active Learning flags are consumed at this level.
        """
        prediction = super().infer_from_request(model_id=model_id, request=request)
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        if not active_learning_eligible:
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Register a prediction, suppressing any error so inference still succeeds."""
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Deliberate best-effort behaviour: Active Learning registration
            # must never break the normal operation of the API.
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of API."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        """Lazily create and cache the middleware for ``model_id`` (no-op if present)."""
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Normalise inputs/predictions to lists and hand them to the middleware.

        Returns early (with a warning) when the request carries no ``image``
        field, since there is nothing to register in that case.
        """
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        # Batch-normalise: a single image / single prediction becomes a one-element batch.
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        # Request-level flag wins; otherwise fall back to the env-level default.
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")
class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """Active Learning manager that defers registration to FastAPI background tasks.

    When a ``background_tasks`` kwarg is supplied, datapoint registration is
    scheduled after the response is sent; otherwise it falls back to
    synchronous registration (with a warning about the misuse).
    """

    def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference; schedule AL registration in the background when possible.

        Note: kwargs are deliberately NOT forwarded to ``super()`` so the
        parent class does not also trigger a (duplicate) registration.
        """
        prediction = super().infer_from_request(model_id=model_id, request=request)
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        if not active_learning_eligible:
            return prediction
        if BACKGROUND_TASKS_PARAM not in kwargs:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not "
                "provided making Active Learning registration running sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            # Use the module constant consistently (the membership check above
            # already relies on it) instead of a hard-coded string literal.
            background_tasks: BackgroundTasks = kwargs[BACKGROUND_TASKS_PARAM]
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction