# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

"""Reward manager for aggregating modular reward components."""

from typing import Any, Dict, List, Optional, Type

from .base import BaseRewardComponent


class RewardManager:
    """
    Manages multiple reward components and aggregates their outputs.

    Supports dynamic registration, per-component weighting, and
    detailed reward breakdowns for debugging/analysis.

    Attributes:
        components: List of registered reward components, in
            registration order.
        global_scale: Global scaling factor applied to total reward.

    Example:
        >>> manager = RewardManager(global_scale=1.0)
        >>> manager.register(ExplorationReward(weight=2.0))
        >>> manager.register(BadgeReward(weight=5.0))
        >>>
        >>> reward = manager.calculate(current_state, prev_state)
        >>> breakdown = manager.get_breakdown()
    """

    def __init__(self, global_scale: float = 1.0):
        """
        Initialize reward manager.

        Args:
            global_scale: Global scaling factor for total reward.
        """
        self.components: List[BaseRewardComponent] = []
        self.global_scale = global_scale
        # Per-component rewards recorded by the most recent calculate() call.
        self._last_breakdown: Dict[str, float] = {}

    def register(self, component: BaseRewardComponent) -> "RewardManager":
        """
        Register a reward component.

        Args:
            component: Reward component to register.

        Returns:
            Self for method chaining.
        """
        self.components.append(component)
        return self

    def register_defaults(
        self, config: Optional[Dict[str, Any]] = None
    ) -> "RewardManager":
        """
        Register default reward components with optional config.

        Registers, in order, ExplorationReward, BadgeReward, LevelUpReward
        and EventReward. Each weight is read from ``config`` under the keys
        ``exploration_weight``, ``badge_weight``, ``level_weight`` and
        ``event_weight``; missing keys fall back to 0.02, 5.0, 1.0 and 0.1
        respectively.

        Args:
            config: Optional config dict with component weights.

        Returns:
            Self for method chaining.
        """
        # Imported at call time rather than at module load.
        from .exploration import ExplorationReward
        from .badge import BadgeReward
        from .level import LevelUpReward
        from .event import EventReward

        config = config or {}

        self.register(ExplorationReward(
            weight=config.get("exploration_weight", 0.02)
        ))
        self.register(BadgeReward(
            weight=config.get("badge_weight", 5.0)
        ))
        self.register(LevelUpReward(
            weight=config.get("level_weight", 1.0)
        ))
        self.register(EventReward(
            weight=config.get("event_weight", 0.1)
        ))

        return self

    def calculate(
        self,
        state: Dict[str, Any],
        prev_state: Dict[str, Any]
    ) -> float:
        """
        Calculate total reward from all components.

        Each component's ``get_reward(state, prev_state)`` result is summed
        and the sum is multiplied by ``global_scale``. The per-component
        values are stored, unscaled, for ``get_breakdown()``.

        Args:
            state: Current game state.
            prev_state: Previous game state.

        Returns:
            Sum of all component rewards multiplied by ``global_scale``.
        """
        # NOTE(review): no weight is applied here; presumably each
        # component's get_reward() already returns a weighted value --
        # confirm against BaseRewardComponent.
        breakdown: Dict[str, float] = {}
        self._last_breakdown = breakdown
        total = 0.0

        for component in self.components:
            reward = component.get_reward(state, prev_state)
            name = component.name
            # Accumulate rather than overwrite so that components sharing a
            # name do not hide each other and the breakdown always sums to
            # the unscaled total.
            if name in breakdown:
                breakdown[name] += reward
            else:
                breakdown[name] = reward
            total += reward

        return total * self.global_scale

    def reset(self) -> None:
        """Reset all components for new episode."""
        self._last_breakdown = {}
        for component in self.components:
            component.reset()

    def get_breakdown(self) -> Dict[str, float]:
        """Get reward breakdown by component from last calculation.

        Returns:
            A copy of the unscaled per-component rewards, keyed by component
            name; empty if calculate() has not run since construction or the
            last reset().
        """
        return self._last_breakdown.copy()

    def get_cumulative_breakdown(self) -> Dict[str, float]:
        """Get cumulative rewards by component for current episode.

        Returns:
            Each component's ``cumulative_reward`` keyed by component name.
            Components that share a name are summed into a single entry.
        """
        totals: Dict[str, float] = {}
        for comp in self.components:
            name = comp.name
            if name in totals:
                totals[name] += comp.cumulative_reward
            else:
                totals[name] = comp.cumulative_reward
        return totals