Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| """Base class for modular reward components.""" | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict | |
class BaseRewardComponent(ABC):
    """
    Abstract base class for reward components.

    Each reward component calculates a specific aspect of the reward signal
    (exploration, badges, levels, etc.) and can be enabled/weighted via config.

    Subclasses implement :meth:`calculate` (the raw, unweighted value);
    callers normally use :meth:`get_reward`, which applies ``weight``,
    honours ``enabled`` and tracks the per-episode running total.

    Attributes:
        name: Unique identifier for this reward component.
        weight: Scaling factor applied to calculated reward.
        enabled: Whether this component is active.

    Example:
        >>> class ScoreDelta(BaseRewardComponent):
        ...     def calculate(self, state, prev_state):
        ...         return state["score"] - prev_state.get("score", 0)
        >>> reward = ScoreDelta(name="score", weight=2.0)
        >>> reward.calculate({"score": 100}, {"score": 50})
        50
        >>> reward.get_reward({"score": 100}, {"score": 50})
        100.0
        >>> reward.cumulative_reward()
        100.0
    """

    def __init__(self, name: str, weight: float = 1.0, enabled: bool = True):
        """
        Initialize reward component.

        Args:
            name: Unique identifier for this component.
            weight: Scaling factor for reward (default 1.0).
            enabled: Whether component is active (default True).
        """
        self.name = name
        self.weight = weight
        self.enabled = enabled
        # Running sum of the weighted rewards handed out by get_reward()
        # since construction or the last reset().
        self._cumulative = 0.0

    def __repr__(self) -> str:
        """Return a debugging representation showing the configuration."""
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"weight={self.weight!r}, enabled={self.enabled!r})"
        )

    @abstractmethod
    def calculate(
        self, state: Dict[str, Any], prev_state: Dict[str, Any]
    ) -> float:
        """
        Calculate reward for this component.

        Must be overridden by subclasses; a class that does not override it
        cannot be instantiated.

        Args:
            state: Current game state dictionary.
            prev_state: Previous game state dictionary.

        Returns:
            Unweighted reward value for this component.
        """

    def get_reward(
        self, state: Dict[str, Any], prev_state: Dict[str, Any]
    ) -> float:
        """
        Get weighted reward if enabled.

        When the component is enabled, the weighted reward is also added to
        the running total reported by :meth:`cumulative_reward`. When it is
        disabled, :meth:`calculate` is not called and the total is unchanged.

        Args:
            state: Current game state.
            prev_state: Previous game state.

        Returns:
            Weighted reward if enabled, 0.0 otherwise.
        """
        if not self.enabled:
            return 0.0
        reward = self.calculate(state, prev_state) * self.weight
        self._cumulative += reward
        return reward

    def reset(self) -> None:
        """Reset component state for new episode."""
        self._cumulative = 0.0

    def cumulative_reward(self) -> float:
        """Total reward from this component in current episode."""
        return self._cumulative