File size: 6,746 Bytes
edc5871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Base Rubric class for reward computation.

Rubrics compute rewards from actions and observations. The API is modeled
after PyTorch's nn.Module: users implement forward(), and the framework
handles child registration and hooks.

See RFC 004 for full design: rfcs/004-rubrics.md
"""

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Tuple, Callable


class Rubric(ABC):
    """Abstract base class for reward computation.

    A Rubric computes a reward signal from an action and observation.
    Subclasses implement forward() to define the reward logic; forward()
    may be a regular method or an ``async def`` coroutine function.

    Usage:
        class MyRubric(Rubric):
            def forward(self, action, observation) -> float:
                return 1.0 if action.valid else 0.0

        rubric = MyRubric()
        reward = rubric(action, observation)

    Child rubrics are auto-registered when assigned as attributes,
    enabling hierarchical composition and introspection. Overwriting such
    an attribute with a non-Rubric value, or deleting it, unregisters the
    child again so the registry always mirrors the live attributes.
    """

    # Immediate child rubrics, keyed by the attribute name they were assigned to.
    _rubric_children: Dict[str, "Rubric"]
    # Callables invoked after forward() as hook(rubric, action, observation, result).
    _forward_hooks: List[Callable]
    # Callables invoked before forward() runs as hook(rubric, action, observation).
    _forward_pre_hooks: List[Callable]
    # Result of the most recent completed call; None until the first call finishes.
    last_score: Optional[float]

    def __init__(self):
        # Use object.__setattr__ to avoid triggering __setattr__ during init
        object.__setattr__(self, "_rubric_children", {})
        object.__setattr__(self, "_forward_hooks", [])
        object.__setattr__(self, "_forward_pre_hooks", [])
        object.__setattr__(self, "last_score", None)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, keeping the child-rubric registry in sync.

        Args:
            name: Attribute name.
            value: New value. Rubric instances are registered as children
                under ``name``; any other value unregisters a child that
                was previously stored under ``name``.

        Raises:
            AttributeError: If a Rubric is assigned before
                Rubric.__init__() has created the registry.
        """
        registry = self.__dict__.get("_rubric_children")
        if isinstance(value, Rubric):
            if registry is None:
                raise AttributeError(
                    f"cannot assign child rubric {name!r} before "
                    "Rubric.__init__() call; call super().__init__() first"
                )
            registry[name] = value
        elif registry is not None:
            # The attribute no longer holds a rubric: drop the stale entry so
            # children()/rubrics()/get_rubric() stop reporting it.
            registry.pop(name, None)
        object.__setattr__(self, name, value)

    def __delattr__(self, name: str) -> None:
        """Delete an attribute and unregister it if it was a child rubric.

        Args:
            name: Attribute name.

        Raises:
            AttributeError: If the attribute does not exist.
        """
        object.__delattr__(self, name)
        registry = self.__dict__.get("_rubric_children")
        if registry is not None:
            registry.pop(name, None)

    def __call__(self, action: Any, observation: Any):
        """Evaluate the rubric with hooks.

        When forward() is a coroutine function this returns a coroutine
        that the caller must await; otherwise the reward is returned
        directly.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value (typically 0.0 to 1.0), or a coroutine resolving
            to it when forward() is async.
        """
        # Check if forward method is async BEFORE calling it
        if inspect.iscoroutinefunction(self.forward):
            # Async path - calling forward() only creates the coroutine; its
            # body runs inside _call_async, after the pre-hooks.
            pending = self.forward(action, observation)
            return self._call_async(action, observation, pending)

        # Sync path - call pre-hooks BEFORE forward(). Hooks are invoked
        # plainly here, so coroutine-function hooks are not awaited.
        for hook in self._forward_pre_hooks:
            hook(self, action, observation)
        result = self.forward(action, observation)
        return self._call_sync(action, observation, result)

    def _call_sync(self, action: Any, observation: Any, result: float) -> float:
        """Finish a synchronous call: record the score and run post-hooks.

        Args:
            action: The action passed to the rubric.
            observation: The observation passed to the rubric.
            result: The value already returned by forward().

        Returns:
            ``result`` unchanged.
        """
        self.last_score = result

        # Post-forward hooks
        for hook in self._forward_hooks:
            hook(self, action, observation, result)

        return result

    async def _call_async(self, action: Any, observation: Any, result_coro) -> float:
        """Asynchronous call path: pre-hooks, await forward(), post-hooks.

        Coroutine-function hooks are awaited; other hooks are called
        directly.

        Args:
            action: The action passed to the rubric.
            observation: The observation passed to the rubric.
            result_coro: The not-yet-awaited awaitable returned by forward().

        Returns:
            The awaited result of ``result_coro``.
        """
        # Pre-forward hooks
        try:
            for hook in self._forward_pre_hooks:
                if inspect.iscoroutinefunction(hook):
                    await hook(self, action, observation)
                else:
                    hook(self, action, observation)
        except BaseException:
            # forward() will never run: close its pending coroutine so Python
            # does not emit a "coroutine ... was never awaited" warning.
            close = getattr(result_coro, "close", None)
            if close is not None:
                close()
            raise

        # Await the forward result
        result = await result_coro
        self.last_score = result

        # Post-forward hooks
        for hook in self._forward_hooks:
            if inspect.iscoroutinefunction(hook):
                await hook(self, action, observation, result)
            else:
                hook(self, action, observation, result)

        return result

    @abstractmethod
    def forward(self, action: Any, observation: Any) -> float:
        """Compute the reward. Implement this in subclasses.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value (typically 0.0 to 1.0).
        """
        raise NotImplementedError

    def register_forward_hook(
        self, hook: Callable[["Rubric", Any, Any, float], None]
    ) -> None:
        """Register a hook called after forward().

        Hooks run in registration order. Coroutine-function hooks are only
        awaited when forward() itself is async.

        Args:
            hook: Callable with signature (rubric, action, observation, result).
        """
        self._forward_hooks.append(hook)

    def register_forward_pre_hook(
        self, hook: Callable[["Rubric", Any, Any], None]
    ) -> None:
        """Register a hook called before forward().

        Hooks run in registration order. Coroutine-function hooks are only
        awaited when forward() itself is async.

        Args:
            hook: Callable with signature (rubric, action, observation).
        """
        self._forward_pre_hooks.append(hook)

    def children(self) -> Iterator["Rubric"]:
        """Iterate over immediate child rubrics."""
        yield from self._rubric_children.values()

    def named_children(self) -> Iterator[Tuple[str, "Rubric"]]:
        """Iterate over immediate child rubrics with names."""
        yield from self._rubric_children.items()

    def rubrics(self) -> Iterator["Rubric"]:
        """Iterate over all descendant rubrics (depth-first, parent first)."""
        for child in self._rubric_children.values():
            yield child
            yield from child.rubrics()

    def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]:
        """Iterate over all descendant rubrics with dot-separated names.

        Args:
            prefix: Path prepended (with a dot) to every yielded name; empty
                for names relative to this rubric.

        Yields:
            (dotted_name, rubric) pairs, depth-first with parents first.
        """
        for name, child in self._rubric_children.items():
            full_name = f"{prefix}.{name}" if prefix else name
            yield full_name, child
            yield from child.named_rubrics(full_name)

    def get_rubric(self, path: str) -> "Rubric":
        """Access a nested rubric by dot-separated path.

        Args:
            path: Dot-separated path (e.g., "code.syntax").

        Returns:
            The rubric at the specified path.

        Raises:
            KeyError: If the path does not exist.
        """
        current = self
        for part in path.split("."):
            try:
                current = current._rubric_children[part]
            except KeyError:
                # Report the full requested path, not just the missing segment.
                raise KeyError(f"Rubric path not found: {path}") from None
        return current

    def reset(self) -> None:
        """Reset any internal state. Override in subclasses if needed."""
        pass

    def state_dict(self) -> Dict[str, Any]:
        """Serialize rubric configuration for checkpointing.

        The base implementation returns an empty dict.
        """
        return {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load rubric configuration from checkpoint.

        The base implementation ignores ``state``.
        """
        pass