Sankie005's picture
Upload 434 files
c446951
from typing import Any, Dict, List, Optional
from inference.core import logger
from inference.core.active_learning.entities import (
ActiveLearningConfiguration,
RoboflowProjectMetadata,
SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
initialize_close_to_threshold_sampling,
)
from inference.core.active_learning.samplers.contains_classes import (
initialize_classes_based_sampling,
)
from inference.core.active_learning.samplers.number_of_detections import (
initialize_detections_number_based_sampling,
)
from inference.core.active_learning.samplers.random import initialize_random_sampling
from inference.core.env import ACTIVE_LEARNING_ENABLED
from inference.core.exceptions import ActiveLearningConfigurationError
from inference.core.roboflow_api import (
get_roboflow_active_learning_configuration,
get_roboflow_dataset_type,
get_roboflow_workspace,
)
from inference.core.utils.roboflow import get_model_id_chunks
# Registry keyed by the `type` field of a sampling-strategy config; each value
# is the factory that is called with that config to build the sampling method.
TYPE2SAMPLING_INITIALIZERS = {
    "random": initialize_random_sampling,
    "close_to_threshold": initialize_close_to_threshold_sampling,
    "classes_based": initialize_classes_based_sampling,
    "detections_number_based": initialize_detections_number_based_sampling,
}
def prepare_active_learning_configuration(
    api_key: str,
    model_id: str,
) -> Optional[ActiveLearningConfiguration]:
    """Build the Active Learning configuration for a model, if AL is active.

    Args:
        api_key: Roboflow API key, forwarded to the project metadata lookup.
        model_id: Model identifier, forwarded to the project metadata lookup
            and to the resulting configuration.

    Returns:
        ``None`` when ``ACTIVE_LEARNING_ENABLED`` is falsy, or when the fetched
        configuration has no truthy ``"enabled"`` entry. Otherwise the result
        of ``ActiveLearningConfiguration.init``.

    Note:
        Once the configuration is enabled, its ``"sampling_strategies"`` entry
        is read with plain indexing, so it must be present.
    """
    if not ACTIVE_LEARNING_ENABLED:
        return None

    project_metadata = get_roboflow_project_metadata(
        api_key=api_key,
        model_id=model_id,
    )
    al_config = project_metadata.active_learning_configuration
    is_enabled = al_config.get("enabled", False)
    if not is_enabled:
        return None

    logger.info(
        f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
        f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
        f"AL configuration: {project_metadata.active_learning_configuration}"
    )

    strategies_configs = al_config["sampling_strategies"]
    samplers = initialize_sampling_methods(
        sampling_strategies_configs=strategies_configs,
    )
    # Collected samples go to the model's own workspace / project unless the
    # configuration names an explicit target.
    destination_workspace = al_config.get(
        "target_workspace", project_metadata.workspace_id
    )
    destination_dataset = al_config.get(
        "target_project", project_metadata.dataset_id
    )
    return ActiveLearningConfiguration.init(
        roboflow_api_configuration=al_config,
        sampling_methods=samplers,
        workspace_id=destination_workspace,
        dataset_id=destination_dataset,
        model_id=model_id,
    )
def get_roboflow_project_metadata(
    api_key: str,
    model_id: str,
) -> RoboflowProjectMetadata:
    """Collect the project details needed to configure Active Learning.

    Splits ``model_id`` into dataset and version ids, resolves the workspace
    for ``api_key``, then fetches the dataset type and the Active Learning
    configuration for that workspace / dataset pair.

    Args:
        api_key: Roboflow API key passed to every API helper called here.
        model_id: Model identifier passed to ``get_model_id_chunks``.

    Returns:
        A ``RoboflowProjectMetadata`` carrying the dataset id, version id,
        workspace id, dataset type and the fetched configuration.
    """
    logger.info(f"Fetching active learning configuration.")
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    # Both project-level lookups below take exactly the same keyword arguments.
    project_locator = {
        "api_key": api_key,
        "workspace_id": workspace_id,
        "dataset_id": dataset_id,
    }
    return RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=get_roboflow_dataset_type(**project_locator),
        active_learning_configuration=get_roboflow_active_learning_configuration(
            **project_locator
        ),
    )
def initialize_sampling_methods(
    sampling_strategies_configs: List[Dict[str, Any]]
) -> List[SamplingMethod]:
    """Instantiate sampling methods from their configuration entries.

    Each entry's ``"type"`` value selects an initializer from
    ``TYPE2SAMPLING_INITIALIZERS``, which is called with the whole entry.
    Entries whose type is not registered are skipped with a warning.

    Args:
        sampling_strategies_configs: Configuration entries, one per strategy.
            Every entry must contain a ``"type"`` key.

    Returns:
        The initialized sampling methods, in the order of their entries.

    Raises:
        ActiveLearningConfigurationError: If two initialized methods share
            the same ``name``.
    """
    result = []
    for sampling_strategy_config in sampling_strategies_configs:
        sampling_type = sampling_strategy_config["type"]
        initializer = TYPE2SAMPLING_INITIALIZERS.get(sampling_type)
        if initializer is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported call.
            logger.warning(
                f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
            )
            continue
        result.append(initializer(sampling_strategy_config))
    # A set collapses repeated names, so a size mismatch means duplicates.
    names = {method.name for method in result}
    if len(names) != len(result):
        raise ActiveLearningConfigurationError(
            "Detected duplication of Active Learning strategies names."
        )
    return result