andnetdeboer commited on
Commit
47ef73d
·
verified ·
1 Parent(s): 54fa9aa

Upload scripts/eval_policy/registry.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/eval_policy/registry.py +136 -0
scripts/eval_policy/registry.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Policy Registry for LeHome Challenge
3
+
4
+ This module provides a registry system for policies, allowing participants
5
+ to register their custom policies without modifying core evaluation code.
6
+ """
7
+
8
+ from typing import Dict, Type, Optional
9
+ from .base_policy import BasePolicy
10
+
11
+
12
+ class PolicyRegistry:
13
+ """
14
+ Global registry for all available policies.
15
+
16
+ Usage:
17
+ # Register a policy
18
+ @PolicyRegistry.register("my_policy")
19
+ class MyPolicy(BasePolicy):
20
+ pass
21
+
22
+ # Or register manually
23
+ PolicyRegistry.register_policy("my_policy", MyPolicy)
24
+
25
+ # Get available policies
26
+ available = PolicyRegistry.list_policies()
27
+
28
+ # Create policy instance
29
+ policy = PolicyRegistry.create("my_policy", model_path="...")
30
+ """
31
+
32
+ _registry: Dict[str, Type[BasePolicy]] = {}
33
+
34
+ @classmethod
35
+ def register(cls, name: str):
36
+ """
37
+ Decorator to register a policy class.
38
+
39
+ Args:
40
+ name: Unique identifier for the policy.
41
+
42
+ Example:
43
+ @PolicyRegistry.register("my_policy")
44
+ class MyPolicy(BasePolicy):
45
+ pass
46
+ """
47
+ def decorator(policy_cls: Type[BasePolicy]):
48
+ cls.register_policy(name, policy_cls)
49
+ return policy_cls
50
+ return decorator
51
+
52
+ @classmethod
53
+ def register_policy(cls, name: str, policy_cls: Type[BasePolicy]):
54
+ """
55
+ Manually register a policy class.
56
+
57
+ Args:
58
+ name: Unique identifier for the policy.
59
+ policy_cls: Policy class (must inherit from BasePolicy).
60
+
61
+ Raises:
62
+ ValueError: If policy name already exists or class doesn't inherit BasePolicy.
63
+ """
64
+ if name in cls._registry:
65
+ raise ValueError(f"Policy '{name}' is already registered!")
66
+
67
+ if not issubclass(policy_cls, BasePolicy):
68
+ raise ValueError(f"Policy class must inherit from BasePolicy, got {policy_cls}")
69
+
70
+ cls._registry[name] = policy_cls
71
+ print(f"[PolicyRegistry] Registered policy: '{name}' -> {policy_cls.__name__}")
72
+
73
+ @classmethod
74
+ def get_policy_class(cls, name: str) -> Type[BasePolicy]:
75
+ """
76
+ Get policy class by name.
77
+
78
+ Args:
79
+ name: Policy identifier.
80
+
81
+ Returns:
82
+ Policy class.
83
+
84
+ Raises:
85
+ KeyError: If policy name not found.
86
+ """
87
+ if name not in cls._registry:
88
+ available = ", ".join(cls._registry.keys())
89
+ raise KeyError(
90
+ f"Policy '{name}' not found in registry. "
91
+ f"Available policies: {available}"
92
+ )
93
+ return cls._registry[name]
94
+
95
+ @classmethod
96
+ def create(cls, name: str, **kwargs) -> BasePolicy:
97
+ """
98
+ Create a policy instance by name.
99
+
100
+ Args:
101
+ name: Policy identifier.
102
+ **kwargs: Arguments to pass to policy constructor.
103
+
104
+ Returns:
105
+ Policy instance.
106
+ """
107
+ policy_cls = cls.get_policy_class(name)
108
+ return policy_cls(**kwargs)
109
+
110
+ @classmethod
111
+ def list_policies(cls) -> list:
112
+ """
113
+ Get list of all registered policy names.
114
+
115
+ Returns:
116
+ List of policy names.
117
+ """
118
+ return list(cls._registry.keys())
119
+
120
+ @classmethod
121
+ def is_registered(cls, name: str) -> bool:
122
+ """
123
+ Check if a policy is registered.
124
+
125
+ Args:
126
+ name: Policy identifier.
127
+
128
+ Returns:
129
+ True if registered, False otherwise.
130
+ """
131
+ return name in cls._registry
132
+
133
+ @classmethod
134
+ def clear(cls):
135
+ """Clear all registered policies (mainly for testing)."""
136
+ cls._registry.clear()