"""
Policy Registry for LeHome Challenge

This module provides a registry system for policies, allowing participants
to register their custom policies without modifying core evaluation code.
"""

from typing import Callable, Dict, List, Optional, Type

from .base_policy import BasePolicy


class PolicyRegistry:
    """
    Global registry for all available policies.

    Policy classes are stored in a single class-level dictionary keyed by
    name, so every caller that imports this class sees the same set of
    registered policies.

    Usage:
        # Register a policy
        @PolicyRegistry.register("my_policy")
        class MyPolicy(BasePolicy):
            pass

        # Or register manually
        PolicyRegistry.register_policy("my_policy", MyPolicy)

        # Get available policies
        available = PolicyRegistry.list_policies()

        # Create policy instance
        policy = PolicyRegistry.create("my_policy", model_path="...")
    """

    # Maps policy name -> policy class. Shared at class level; never
    # instantiated per registry object.
    _registry: Dict[str, Type[BasePolicy]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[BasePolicy]], Type[BasePolicy]]:
        """
        Decorator to register a policy class.

        Args:
            name: Unique identifier for the policy.

        Returns:
            A decorator that registers the decorated class under ``name``
            and returns the class unchanged.

        Raises:
            ValueError: Raised when the decorator is applied, under the same
                conditions as ``register_policy``.

        Example:
            @PolicyRegistry.register("my_policy")
            class MyPolicy(BasePolicy):
                pass
        """

        def decorator(policy_cls: Type[BasePolicy]) -> Type[BasePolicy]:
            cls.register_policy(name, policy_cls)
            # Hand the class back untouched so the decorated name still
            # refers to the original class.
            return policy_cls

        return decorator

    @classmethod
    def register_policy(cls, name: str, policy_cls: Type[BasePolicy]) -> None:
        """
        Manually register a policy class.

        Args:
            name: Unique identifier for the policy.
            policy_cls: Policy class (must inherit from BasePolicy).

        Raises:
            ValueError: If policy name already exists, or if ``policy_cls``
                is not a class inheriting from BasePolicy.
        """
        if name in cls._registry:
            raise ValueError(f"Policy '{name}' is already registered!")

        # issubclass() raises TypeError when its first argument is not a
        # class (e.g. an instance or a function). Check for that first so the
        # documented ValueError is raised for every invalid argument.
        if not (isinstance(policy_cls, type) and issubclass(policy_cls, BasePolicy)):
            raise ValueError(f"Policy class must inherit from BasePolicy, got {policy_cls}")

        cls._registry[name] = policy_cls
        print(f"[PolicyRegistry] Registered policy: '{name}' -> {policy_cls.__name__}")

    @classmethod
    def get_policy_class(cls, name: str) -> Type[BasePolicy]:
        """
        Get policy class by name.

        Args:
            name: Policy identifier.

        Returns:
            Policy class.

        Raises:
            KeyError: If policy name not found. The message lists the names
                that are currently registered.
        """
        if name not in cls._registry:
            # Fall back to a placeholder so the message is not left with a
            # dangling "Available policies: " when nothing is registered.
            available = ", ".join(cls._registry) or "(none)"
            raise KeyError(
                f"Policy '{name}' not found in registry. "
                f"Available policies: {available}"
            )
        return cls._registry[name]

    @classmethod
    def create(cls, name: str, **kwargs) -> BasePolicy:
        """
        Create a policy instance by name.

        Args:
            name: Policy identifier.
            **kwargs: Arguments to pass to policy constructor.

        Returns:
            Policy instance.

        Raises:
            KeyError: If policy name not found.
        """
        policy_cls = cls.get_policy_class(name)
        return policy_cls(**kwargs)

    @classmethod
    def list_policies(cls) -> List[str]:
        """
        Get list of all registered policy names.

        Returns:
            A new list of policy names; mutating it does not affect the
            registry.
        """
        return list(cls._registry)

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """
        Check if a policy is registered.

        Args:
            name: Policy identifier.

        Returns:
            True if registered, False otherwise.
        """
        return name in cls._registry

    @classmethod
    def clear(cls) -> None:
        """Clear all registered policies (mainly for testing)."""
        cls._registry.clear()