Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Base Rubric class for reward computation. | |
| Rubrics compute rewards from actions and observations. The API is modeled | |
| after PyTorch's nn.Module: users implement forward(), and the framework | |
| handles child registration and hooks. | |
| See RFC 004 for full design: rfcs/004-rubrics.md | |
| """ | |
| import inspect | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, Iterator, List, Optional, Tuple, Callable | |
class Rubric(ABC):
    """Abstract base class for reward computation.

    A Rubric computes a reward signal from an action and observation.
    Subclasses implement forward() to define the reward logic; forward()
    may be a regular function or an ``async def`` coroutine function.

    Usage:
        class MyRubric(Rubric):
            def forward(self, action, observation) -> float:
                return 1.0 if action.valid else 0.0

        rubric = MyRubric()
        reward = rubric(action, observation)

    Child rubrics are auto-registered when assigned as attributes,
    enabling hierarchical composition and introspection. Re-assigning a
    registered attribute to a non-Rubric value, or deleting it, removes
    the registration again.
    """

    # Attribute name -> child rubric, in assignment order.
    _rubric_children: Dict[str, "Rubric"]
    # Callables invoked after forward() as hook(rubric, action, observation, result).
    _forward_hooks: List[Callable]
    # Callables invoked before forward() as hook(rubric, action, observation).
    _forward_pre_hooks: List[Callable]
    # Result of the most recent call; None until the rubric has been called.
    last_score: Optional[float]

    def __init__(self):
        # Use object.__setattr__ so the bookkeeping attributes are created
        # without going through the child-registration logic in __setattr__.
        object.__setattr__(self, "_rubric_children", {})
        object.__setattr__(self, "_forward_hooks", [])
        object.__setattr__(self, "_forward_pre_hooks", [])
        object.__setattr__(self, "last_score", None)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, keeping the child-rubric registry in sync.

        Args:
            name: Attribute name.
            value: New value. Rubric instances are registered as children
                under ``name``; any other value unregisters a child that
                was previously stored under ``name``.

        Raises:
            AttributeError: If a Rubric is assigned before
                Rubric.__init__() has created the registry.
        """
        children = getattr(self, "_rubric_children", None)
        if isinstance(value, Rubric):
            if children is None:
                raise AttributeError(
                    "cannot assign child rubric before Rubric.__init__() call"
                )
            children[name] = value
        elif children is not None:
            # The attribute no longer holds a rubric: drop the stale entry so
            # children()/rubrics()/get_rubric() stop reporting the old child.
            children.pop(name, None)
        object.__setattr__(self, name, value)

    def __delattr__(self, name: str) -> None:
        """Delete an attribute and unregister it if it was a child rubric."""
        object.__delattr__(self, name)
        children = getattr(self, "_rubric_children", None)
        if children is not None:
            children.pop(name, None)

    def __call__(self, action: Any, observation: Any):
        """Evaluate the rubric with hooks.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value (typically 0.0 to 1.0) when forward() is a regular
            function. When forward() is a coroutine function, an awaitable
            that runs the hooks and forward() and resolves to the reward.
        """
        # Decide on the path BEFORE calling forward().
        if inspect.iscoroutinefunction(self.forward):
            # Calling an async forward() only creates the coroutine; its body
            # runs when _call_async awaits it, i.e. after the pre-hooks.
            pending = self.forward(action, observation)
            return self._call_async(action, observation, pending)

        # Sync path: hooks are called directly and are not awaited.
        for pre_hook in self._forward_pre_hooks:
            pre_hook(self, action, observation)
        score = self.forward(action, observation)
        return self._call_sync(action, observation, score)

    def _call_sync(self, action: Any, observation: Any, result: float) -> float:
        """Synchronous call path: record the result and run post-forward hooks."""
        self.last_score = result
        for post_hook in self._forward_hooks:
            post_hook(self, action, observation, result)
        return result

    async def _call_async(self, action: Any, observation: Any, result_coro) -> float:
        """Asynchronous call path.

        Runs the pre-forward hooks, awaits ``result_coro`` (the not yet
        started forward() coroutine), stores the result in ``last_score``
        and runs the post-forward hooks. Hooks that are coroutine functions
        are awaited; other hooks are called directly.
        """
        try:
            for pre_hook in self._forward_pre_hooks:
                if inspect.iscoroutinefunction(pre_hook):
                    await pre_hook(self, action, observation)
                else:
                    pre_hook(self, action, observation)
        except BaseException:
            # forward() will never be awaited now; close its coroutine so it
            # is not reported as "coroutine ... was never awaited".
            close = getattr(result_coro, "close", None)
            if close is not None:
                close()
            raise

        result = await result_coro
        self.last_score = result

        for post_hook in self._forward_hooks:
            if inspect.iscoroutinefunction(post_hook):
                await post_hook(self, action, observation, result)
            else:
                post_hook(self, action, observation, result)
        return result

    def forward(self, action: Any, observation: Any) -> float:
        """Compute the reward. Implement this in subclasses.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value (typically 0.0 to 1.0).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def register_forward_hook(
        self, hook: Callable[["Rubric", Any, Any, float], None]
    ) -> None:
        """Register a hook called after forward().

        Args:
            hook: Callable with signature (rubric, action, observation, result).
        """
        self._forward_hooks.append(hook)

    def register_forward_pre_hook(
        self, hook: Callable[["Rubric", Any, Any], None]
    ) -> None:
        """Register a hook called before forward().

        Args:
            hook: Callable with signature (rubric, action, observation).
        """
        self._forward_pre_hooks.append(hook)

    def children(self) -> Iterator["Rubric"]:
        """Iterate over immediate child rubrics."""
        yield from self._rubric_children.values()

    def named_children(self) -> Iterator[Tuple[str, "Rubric"]]:
        """Iterate over immediate child rubrics as (name, rubric) pairs."""
        yield from self._rubric_children.items()

    def rubrics(self) -> Iterator["Rubric"]:
        """Iterate over all descendant rubrics (depth-first, parent first)."""
        for child in self._rubric_children.values():
            yield child
            yield from child.rubrics()

    def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]:
        """Iterate over all descendant rubrics with dot-separated names.

        Args:
            prefix: Path prepended (with a ".") to every yielded name;
                empty for no prefix.
        """
        for name, child in self._rubric_children.items():
            full_name = f"{prefix}.{name}" if prefix else name
            yield full_name, child
            yield from child.named_rubrics(full_name)

    def get_rubric(self, path: str) -> "Rubric":
        """Access a nested rubric by dot-separated path.

        Args:
            path: Dot-separated path (e.g., "code.syntax").

        Returns:
            The rubric at the specified path.

        Raises:
            KeyError: If the path does not exist.
        """
        current = self
        for part in path.split("."):
            try:
                current = current._rubric_children[part]
            except KeyError:
                # Report the full requested path, not just the missing part.
                raise KeyError(f"Rubric path not found: {path}") from None
        return current

    def reset(self) -> None:
        """Reset any internal state. Override in subclasses if needed."""

    def state_dict(self) -> Dict[str, Any]:
        """Serialize rubric configuration for checkpointing.

        The base implementation returns an empty dict.
        """
        return {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load rubric configuration from checkpoint.

        The base implementation ignores ``state``.
        """