"""
Policy Registry for LeHome Challenge.

This module provides a registry system for policies, allowing participants
to register their custom policies without modifying core evaluation code.
"""
from typing import Dict, Type, Optional
from .base_policy import BasePolicy
class PolicyRegistry:
    """
    Global registry for all available policies.

    Registrations are stored in a single class-level dict, so every caller in
    the process sees the same set of policies.

    Usage:
        # Register a policy
        @PolicyRegistry.register("my_policy")
        class MyPolicy(BasePolicy):
            pass

        # Or register manually
        PolicyRegistry.register_policy("my_policy", MyPolicy)

        # Get available policies
        available = PolicyRegistry.list_policies()

        # Create policy instance
        policy = PolicyRegistry.create("my_policy", model_path="...")
    """

    # Maps policy name -> policy class. This is a class attribute, so it is
    # shared process-wide (and by any subclass of PolicyRegistry).
    _registry: Dict[str, Type[BasePolicy]] = {}

    @classmethod
    def register(cls, name: str):
        """
        Decorator to register a policy class.

        Args:
            name: Unique identifier for the policy.

        Returns:
            A decorator that registers the decorated class via
            ``register_policy`` and returns the class unchanged.

        Raises:
            TypeError: If used as a bare ``@PolicyRegistry.register`` without
                a name. Without this check the decorated class would be
                silently replaced by the inner decorator function and never
                registered.

        Example:
            @PolicyRegistry.register("my_policy")
            class MyPolicy(BasePolicy):
                pass
        """
        if isinstance(name, type):
            # Bare-decorator misuse: ``name`` is actually the decorated class.
            raise TypeError(
                "PolicyRegistry.register() requires a policy name, e.g. "
                f'@PolicyRegistry.register("my_policy"); got class {name.__name__}'
            )

        def decorator(policy_cls: Type[BasePolicy]) -> Type[BasePolicy]:
            cls.register_policy(name, policy_cls)
            return policy_cls

        return decorator

    @classmethod
    def register_policy(cls, name: str, policy_cls: Type[BasePolicy]) -> None:
        """
        Manually register a policy class.

        Args:
            name: Unique identifier for the policy.
            policy_cls: Policy class (must inherit from BasePolicy).

        Raises:
            ValueError: If the policy name already exists, or ``policy_cls``
                is not a class inheriting from BasePolicy.
        """
        if name in cls._registry:
            raise ValueError(f"Policy '{name}' is already registered!")
        # issubclass() raises TypeError for non-class arguments (an instance,
        # a function, ...). Checking isinstance(..., type) first makes every
        # invalid policy_cls produce the documented ValueError instead.
        if not (isinstance(policy_cls, type) and issubclass(policy_cls, BasePolicy)):
            raise ValueError(f"Policy class must inherit from BasePolicy, got {policy_cls}")
        cls._registry[name] = policy_cls
        print(f"[PolicyRegistry] Registered policy: '{name}' -> {policy_cls.__name__}")

    @classmethod
    def get_policy_class(cls, name: str) -> Type[BasePolicy]:
        """
        Get policy class by name.

        Args:
            name: Policy identifier.

        Returns:
            Policy class.

        Raises:
            KeyError: If policy name not found. The message lists the
                registered names, or "(none)" when the registry is empty.
        """
        try:
            return cls._registry[name]
        except KeyError:
            available = ", ".join(cls._registry) or "(none)"
            # ``from None``: the internal dict-lookup KeyError adds no information.
            raise KeyError(
                f"Policy '{name}' not found in registry. "
                f"Available policies: {available}"
            ) from None

    @classmethod
    def create(cls, name: str, **kwargs) -> BasePolicy:
        """
        Create a policy instance by name.

        Args:
            name: Policy identifier.
            **kwargs: Arguments to pass to policy constructor.

        Returns:
            Policy instance.

        Raises:
            KeyError: If policy name not found (see ``get_policy_class``).
        """
        policy_cls = cls.get_policy_class(name)
        return policy_cls(**kwargs)

    @classmethod
    def list_policies(cls) -> list:
        """
        Get list of all registered policy names.

        Returns:
            List of policy names, in registration order.
        """
        return list(cls._registry.keys())

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """
        Check if a policy is registered.

        Args:
            name: Policy identifier.

        Returns:
            True if registered, False otherwise.
        """
        return name in cls._registry

    @classmethod
    def clear(cls) -> None:
        """Clear all registered policies (mainly for testing)."""
        cls._registry.clear()