File size: 3,890 Bytes
47ef73d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Policy Registry for LeHome Challenge

This module provides a registry system for policies, allowing participants
to register their custom policies without modifying core evaluation code.
"""

from typing import Dict, Type, Optional
from .base_policy import BasePolicy


class PolicyRegistry:
    """
    Global registry mapping policy names to policy classes.

    The registry is a single class-level dict, so every method is a
    classmethod and no instance of ``PolicyRegistry`` is ever needed.

    Usage:
        # Register a policy
        @PolicyRegistry.register("my_policy")
        class MyPolicy(BasePolicy):
            pass

        # Or register manually
        PolicyRegistry.register_policy("my_policy", MyPolicy)

        # Get available policies
        available = PolicyRegistry.list_policies()

        # Create policy instance
        policy = PolicyRegistry.create("my_policy", model_path="...")
    """

    # name -> policy class; shared class-level state mutated by
    # register_policy() and clear().
    _registry: Dict[str, Type[BasePolicy]] = {}

    @classmethod
    def register(cls, name: str):
        """
        Decorator to register a policy class.

        The decorated class is returned unchanged, so it stays usable under
        its own name after registration.

        Args:
            name: Unique identifier for the policy.

        Returns:
            A decorator that registers the class it is applied to via
            ``register_policy`` and then returns that class.

        Raises:
            ValueError: When the decorator is applied, under the same
                conditions as ``register_policy``.

        Example:
            @PolicyRegistry.register("my_policy")
            class MyPolicy(BasePolicy):
                pass
        """
        def decorator(policy_cls: Type[BasePolicy]):
            cls.register_policy(name, policy_cls)
            return policy_cls
        return decorator

    @classmethod
    def register_policy(cls, name: str, policy_cls: Type[BasePolicy]) -> None:
        """
        Manually register a policy class.

        Args:
            name: Unique identifier for the policy.
            policy_cls: Policy class (must inherit from BasePolicy).

        Raises:
            ValueError: If policy name already exists, or if ``policy_cls``
                is not a class inheriting from BasePolicy.
        """
        if name in cls._registry:
            raise ValueError(f"Policy '{name}' is already registered!")

        # issubclass() raises TypeError when its first argument is not a
        # class (e.g. an instance or a decorated function). Check for a
        # class first so every invalid input yields the documented ValueError.
        if not (isinstance(policy_cls, type) and issubclass(policy_cls, BasePolicy)):
            raise ValueError(f"Policy class must inherit from BasePolicy, got {policy_cls}")

        cls._registry[name] = policy_cls
        print(f"[PolicyRegistry] Registered policy: '{name}' -> {policy_cls.__name__}")

    @classmethod
    def get_policy_class(cls, name: str) -> Type[BasePolicy]:
        """
        Get policy class by name.

        Args:
            name: Policy identifier.

        Returns:
            Policy class.

        Raises:
            KeyError: If policy name not found. The message lists the names
                that are currently registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            available = ", ".join(cls._registry)
            # "from None" drops the bare lookup error from the traceback;
            # the descriptive message below carries all the information.
            raise KeyError(
                f"Policy '{name}' not found in registry. "
                f"Available policies: {available}"
            ) from None

    @classmethod
    def create(cls, name: str, **kwargs) -> BasePolicy:
        """
        Create a policy instance by name.

        Args:
            name: Policy identifier.
            **kwargs: Arguments to pass to policy constructor.

        Returns:
            Policy instance.

        Raises:
            KeyError: If policy name not found (from ``get_policy_class``).
        """
        policy_cls = cls.get_policy_class(name)
        return policy_cls(**kwargs)

    @classmethod
    def list_policies(cls) -> list:
        """
        Get list of all registered policy names.

        Returns:
            New list of policy names; mutating it does not affect the
            registry.
        """
        return list(cls._registry)

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """
        Check if a policy is registered.

        Args:
            name: Policy identifier.

        Returns:
            True if registered, False otherwise.
        """
        return name in cls._registry

    @classmethod
    def clear(cls) -> None:
        """Clear all registered policies (mainly for testing)."""
        cls._registry.clear()