# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import importlib import inspect from typing import Any, Callable from warnings import warn from physicsnemo.active_learning.protocols import ActiveLearningProtocol __all__ = ["registry"] class ActiveLearningRegistry: """ Registry for active learning protocols. This class provides a centralized registry for user-defined active learning protocols that implement the `ActiveLearningProtocol`. It enables string-based lookups for checkpointing and provides argument validation when constructing protocol instances. The registry supports two primary modes of interaction: 1. Registration via decorator: `@registry.register("my_strategy")` 2. Construction with validation: `registry.construct("my_strategy", **kwargs)` Attributes ---------- _registry : dict[str, type[ActiveLearningProtocol]] Internal dictionary mapping protocol names to their class types. Methods ------- register(cls_name: str) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]] Decorator to register a protocol class with a given name. construct(cls_name: str, **kwargs) -> ActiveLearningProtocol Construct an instance of a registered protocol with argument validation. is_registered(cls_name: str) -> bool Check if a protocol name is registered. Properties ---------- registered_names : list[str] A list of all registered protocol names, sorted alphabetically. Examples -------- Register a custom strategy: >>> from physicsnemo.active_learning._registry import registry >>> @registry.register("my_custom_strategy") ... class MyCustomStrategy: ... def __init__(self, param1: int, param2: str): ... self.param1 = param1 ... self.param2 = param2 Construct an instance with validation: >>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test") """ def __init__(self) -> None: """Initialize an empty registry.""" self._registry: dict[str, type[ActiveLearningProtocol]] = {} def register( self, cls_name: str ) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]: """ Decorator to register an active learning protocol class. This decorator registers a class implementing the `ActiveLearningProtocol` under the given name, allowing it to be retrieved and constructed later using the `construct` method. Parameters ---------- cls_name : str The name to register the protocol under. This will be used as the key for later retrieval. Returns ------- Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]] A decorator function that registers the class and returns it unchanged. Raises ------ ValueError If a protocol with the same name is already registered. Examples -------- >>> @registry.register("my_new_strategy") ... class MyStrategy: ... def __init__(self, param: int): ... self.param = param """ def decorator( cls: type[ActiveLearningProtocol], ) -> type[ActiveLearningProtocol]: """ Method for decorating a class to registry it with the registry. """ if cls_name in self._registry: raise ValueError( f"Protocol '{cls_name}' is already registered. " f"Existing class: {self._registry[cls_name].__name__}" ) self._registry[cls_name] = cls return cls return decorator def construct( self, cls_name: str, module_path: str | None = None, **kwargs: Any ) -> ActiveLearningProtocol: """ Construct an instance of a registered protocol with argument validation. This method retrieves a registered protocol class by name, validates that the provided keyword arguments match the class's constructor signature, and returns a new instance of the class. Parameters ---------- cls_name : str The name of the registered protocol to construct. module_path: str | None The path to the module to get the class from. **kwargs : Any Keyword arguments to pass to the protocol's constructor. Returns ------- ActiveLearningProtocol A new instance of the requested protocol class. Raises ------ KeyError If the protocol name is not registered. TypeError If the provided keyword arguments do not match the constructor signature. This includes missing required parameters or unexpected parameters. Examples -------- >>> from physicsnemo.active_learning._registry import registry >>> @registry.register("my_latest_strategy") ... class MyStrategy: ... def __init__(self, param: int): ... self.param = param >>> strategy = registry.construct("my_latest_strategy", param=42) """ cls = self.get_class(cls_name, module_path) # Validate arguments against the class signature try: sig = inspect.signature(cls.__init__) except (ValueError, TypeError) as e: raise TypeError( f"Could not inspect signature of {cls.__name__}.__init__: {e}" ) # Get parameters, excluding 'self' params = { name: param for name, param in sig.parameters.items() if name != "self" } # Check if the signature accepts **kwargs has_var_keyword = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() ) # Check for missing required parameters missing = [] for name, param in params.items(): if ( param.kind not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) and param.default is inspect.Parameter.empty and name not in kwargs ): missing.append(name) if missing: raise TypeError( f"Missing required arguments for {cls.__name__}: {', '.join(missing)}" ) # Check for unexpected parameters (unless **kwargs is present) if not has_var_keyword: param_names = { name for name, param in params.items() if param.kind not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) } unexpected = [name for name in kwargs if name not in param_names] if unexpected: warn( f"Unexpected arguments for {cls.__name__}: {', '.join(unexpected)}. " f"Valid parameters: {', '.join(sorted(param_names))}" ) return cls(**kwargs) def __getitem__(self, cls_name: str) -> type[ActiveLearningProtocol]: """ Retrieve a registered protocol class by name using dict-like access. This method allows accessing registered protocol classes using square bracket notation, e.g., `registry['my_strategy']`. Parameters ---------- cls_name : str The name of the registered protocol to retrieve. Returns ------- type[ActiveLearningProtocol] The class type of the registered protocol. Raises ------ KeyError If the protocol name is not registered. Examples -------- >>> from physicsnemo.active_learning._registry import registry >>> @registry.register("my_strategy") ... class MyStrategy: ... def __init__(self, param: int): ... self.param = param >>> RetrievedClass = registry['my_strategy'] >>> instance = RetrievedClass(param=42) """ if cls_name not in self._registry: available = ", ".join(self._registry.keys()) if self._registry else "none" raise KeyError( f"Protocol '{cls_name}' is not registered. " f"Available protocols: {available}" ) return self._registry[cls_name] def is_registered(self, cls_name: str) -> bool: """ Check if a protocol name is registered. Parameters ---------- cls_name : str The name of the protocol to check. Returns ------- bool True if the protocol is registered, False otherwise. """ return cls_name in self._registry @property def registered_names(self) -> list[str]: """ A list of all registered protocol names, sorted alphabetically. Returns ------- list[str] A list of all registered protocol names, sorted alphabetically. """ return sorted(self._registry.keys()) def get_class(self, cls_name: str, module_path: str | None = None) -> type: """ Get a class by name from the registry or from a module path. Parameters ---------- cls_name: str The name of the class to get. module_path: str | None The path to the module to get the class from. Returns ------- type The class. Raises ------ NameError: If the class is not found in the registry or module. ModuleNotFoundError: If the module is not found with the specified module path. """ if cls_name in self.registered_names: return self._registry[cls_name] else: if module_path: module = importlib.import_module(module_path) cls = getattr(module, cls_name, None) if not cls: raise NameError( f"Class {cls_name} not found in module {module_path}" ) return cls else: raise NameError( f"Class {cls_name} not found in registry, and no module path was provided." ) # Module-level registry instance for global access registry = ActiveLearningRegistry()