from typing import Any, Dict, List, Optional

from inference.core import logger
from inference.core.active_learning.entities import (
    ActiveLearningConfiguration,
    RoboflowProjectMetadata,
    SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
    initialize_close_to_threshold_sampling,
)
from inference.core.active_learning.samplers.contains_classes import (
    initialize_classes_based_sampling,
)
from inference.core.active_learning.samplers.number_of_detections import (
    initialize_detections_number_based_sampling,
)
from inference.core.active_learning.samplers.random import initialize_random_sampling
from inference.core.env import ACTIVE_LEARNING_ENABLED
from inference.core.exceptions import ActiveLearningConfigurationError
from inference.core.roboflow_api import (
    get_roboflow_active_learning_configuration,
    get_roboflow_dataset_type,
    get_roboflow_workspace,
)
from inference.core.utils.roboflow import get_model_id_chunks

# Maps the `type` field of a sampling strategy config onto the factory that
# builds the corresponding sampling method.
TYPE2SAMPLING_INITIALIZERS = {
    "random": initialize_random_sampling,
    "close_to_threshold": initialize_close_to_threshold_sampling,
    "classes_based": initialize_classes_based_sampling,
    "detections_number_based": initialize_detections_number_based_sampling,
}


def prepare_active_learning_configuration(
    api_key: str,
    model_id: str,
) -> Optional[ActiveLearningConfiguration]:
    """Build the active learning configuration for a model, if enabled.

    Args:
        api_key: API key forwarded to the Roboflow API helpers.
        model_id: Model identifier; its chunks give the dataset and version.

    Returns:
        An `ActiveLearningConfiguration`, or `None` when either the
        `ACTIVE_LEARNING_ENABLED` flag is falsy or the fetched configuration
        has no truthy `"enabled"` entry.

    Raises:
        KeyError: If the enabled configuration lacks `"sampling_strategies"`,
            or one of the strategies lacks `"type"`.
        ActiveLearningConfigurationError: If two initialised sampling methods
            share the same name.
    """
    if not ACTIVE_LEARNING_ENABLED:
        return None
    project_metadata = get_roboflow_project_metadata(
        api_key=api_key,
        model_id=model_id,
    )
    # NOTE(review): the configuration is used as a dict here (`.get`, `[...]`)
    # - confirm against the Roboflow API helper that produces it.
    al_configuration = project_metadata.active_learning_configuration
    if not al_configuration.get("enabled", False):
        return None
    logger.info(
        f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
        f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
        f"AL configuration: {project_metadata.active_learning_configuration}"
    )
    sampling_methods = initialize_sampling_methods(
        sampling_strategies_configs=al_configuration["sampling_strategies"],
    )
    # Data is registered in the model's own workspace / project unless the
    # configuration explicitly points somewhere else.
    target_workspace_id = al_configuration.get(
        "target_workspace", project_metadata.workspace_id
    )
    target_dataset_id = al_configuration.get(
        "target_project", project_metadata.dataset_id
    )
    return ActiveLearningConfiguration.init(
        roboflow_api_configuration=al_configuration,
        sampling_methods=sampling_methods,
        workspace_id=target_workspace_id,
        dataset_id=target_dataset_id,
        model_id=model_id,
    )


def get_roboflow_project_metadata(
    api_key: str,
    model_id: str,
) -> RoboflowProjectMetadata:
    """Collect the project metadata needed to configure active learning.

    Splits `model_id` into dataset and version identifiers, then queries the
    Roboflow API helpers for the workspace, the dataset type and the active
    learning configuration.

    Args:
        api_key: API key forwarded to every Roboflow API helper call.
        model_id: Model identifier passed to `get_model_id_chunks`.

    Returns:
        A `RoboflowProjectMetadata` bundling all the retrieved values.
    """
    logger.info("Fetching active learning configuration.")
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    dataset_type = get_roboflow_dataset_type(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    roboflow_api_configuration = get_roboflow_active_learning_configuration(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    return RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=dataset_type,
        active_learning_configuration=roboflow_api_configuration,
    )


def initialize_sampling_methods(
    sampling_strategies_configs: List[Dict[str, Any]]
) -> List[SamplingMethod]:
    """Instantiate a sampling method for every recognised strategy config.

    Configs whose `"type"` is not a key of `TYPE2SAMPLING_INITIALIZERS` are
    skipped with a warning instead of failing the whole initialisation.

    Args:
        sampling_strategies_configs: Strategy configs; each must carry a
            `"type"` key selecting the initializer.

    Returns:
        Initialised sampling methods, in the order of the recognised configs.

    Raises:
        KeyError: If a config has no `"type"` key.
        ActiveLearningConfigurationError: If two initialised methods share
            the same `name`.
    """
    result = []
    for sampling_strategy_config in sampling_strategies_configs:
        sampling_type = sampling_strategy_config["type"]
        if sampling_type not in TYPE2SAMPLING_INITIALIZERS:
            # `warning` replaces the deprecated `warn` alias.
            logger.warning(
                f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
            )
            continue
        initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
        result.append(initializer(sampling_strategy_config))
    # Names must be unique: a set smaller than the list means a duplicate.
    names = {method.name for method in result}
    if len(names) != len(result):
        raise ActiveLearningConfigurationError(
            "Detected duplication of Active Learning strategies names."
        )
    return result