# dc_ops_env/rewards/reward_function.py
# (Uploaded with huggingface_hub; revision aedaf74.)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Multi-objective reward function for DC-Ops environment.
Research-informed design:
- Softplus barrier functions for safety constraints
(Google/DeepMind 2017, ICLR 2025 DC Cooling)
- Delta-based progress rewards for credit assignment
(process reward model literature)
- Normalized components in [-1, 1] via tanh
- Scenario-type-aware weight profiles
All components are bounded to [-1, 1]. Total reward is the weighted sum,
clamped to [-1, 1].
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Optional
from ..config import ASHRAE_CLASSES
from ..simulation.thermal import ThermalSimulation
from ..simulation.power import PowerSimulation
from ..simulation.types import UPSMode
from ..actions.parser import CommandResult
from ..scenarios.base import ScenarioResult
# ---------------------------------------------------------------------------
# Numerically stable softplus
# ---------------------------------------------------------------------------
def softplus(x: float) -> float:
    """Return ln(1 + exp(x)) without overflowing for large inputs.

    Outside the band [-20, 20] the exact value is replaced by its
    asymptote:

    - ``x > 20``  -> ``x`` (skips ``exp`` so it cannot overflow)
    - ``x < -20`` -> ``0.0`` (drops the negligible tail)

    Inside the band, ``log1p(exp(x))`` is evaluated directly.
    """
    if x < -20.0:
        return 0.0
    return x if x > 20.0 else math.log1p(math.exp(x))
# ---------------------------------------------------------------------------
# Reward components dataclass
# ---------------------------------------------------------------------------
@dataclass
class RewardComponents:
    """Per-step breakdown of the reward, kept for logging and analysis.

    Every field defaults to 0.0. ``RewardFunction.compute`` fills in the
    six component scores and ``total``.
    """

    # Component scores, one per RewardFunction component method.
    thermal_safety: float = 0.0
    power_safety: float = 0.0
    efficiency: float = 0.0
    scenario_progress: float = 0.0
    procedure: float = 0.0
    action_quality: float = 0.0
    # Not set by RewardFunction.compute in this module; stays at its default there.
    speed_bonus: float = 0.0
    # Weighted sum of the components after clamping to [-1, 1].
    total: float = 0.0
# ---------------------------------------------------------------------------
# Weight profiles
# ---------------------------------------------------------------------------
@dataclass
class RewardWeights:
    """Per-component multipliers applied when forming the total reward.

    The weights are expected to sum to 1.0 (the defaults below do); this
    is a convention and is not enforced here.
    """

    # Safety terms
    thermal_safety: float = 0.30
    power_safety: float = 0.10
    # Energy efficiency (PUE) term
    efficiency: float = 0.15
    # Scenario-driven terms
    scenario_progress: float = 0.25
    procedure: float = 0.15
    # Per-command quality term
    action_quality: float = 0.05
# Scenario-type -> weight profile. Each profile's weights sum to 1.0.
# RewardFunction falls back to "default" for an unknown scenario type.
WEIGHT_PROFILES: dict[str, RewardWeights] = {
    # Thermal scenarios: thermal safety and scenario progress carry the
    # most weight (0.30 each).
    "thermal": RewardWeights(
        thermal_safety=0.30, power_safety=0.05, efficiency=0.10,
        scenario_progress=0.30, procedure=0.20, action_quality=0.05,
    ),
    # Power scenarios: progress leads (0.30), then power safety and
    # procedure (0.25 each).
    "power": RewardWeights(
        thermal_safety=0.10, power_safety=0.25, efficiency=0.05,
        scenario_progress=0.30, procedure=0.25, action_quality=0.05,
    ),
    # No scenario-specific terms: progress and procedure are weighted 0.
    "default": RewardWeights(
        thermal_safety=0.30, power_safety=0.15, efficiency=0.25,
        scenario_progress=0.0, procedure=0.0, action_quality=0.30,
    ),
}
# ---------------------------------------------------------------------------
# Softplus barrier constants
# ---------------------------------------------------------------------------
# Thermal barriers
# Used by RewardFunction._thermal_safety. Each zone contributes
#   softplus((T - recommended_max) / _ALPHA_RECOMMENDED)
#   + _ALLOWABLE_WEIGHT * softplus((T - allowable_max) / _ALPHA_ALLOWABLE)
# and the zone-averaged sum is squashed with tanh(penalty / _THERMAL_NORM).
_ALPHA_RECOMMENDED = 2.0  # °C transition width at recommended limit
_ALPHA_ALLOWABLE = 1.5  # °C transition width at allowable limit
_ALLOWABLE_WEIGHT = 3.0  # Allowable violations 3x worse per degree
_THERMAL_NORM = 8.0  # Normalization so T=40°C (A2) → R≈-0.97
# Thermal safety positive baseline — small reward for being well within limits
# Based on DCRL-Green (ICLR 2025): agents learn faster with a positive signal
# for maintaining safe state, not just penalties for violations.
# A zone counts as "safe" only while its max inlet temperature is at or
# below (recommended_max - _SAFE_MARGIN_C).
_SAFE_MARGIN_C = 3.0  # °C below recommended max to qualify as "safe"
_SAFE_BASELINE = 0.1  # Small positive reward when all zones safe
# Power barriers
# Used by RewardFunction._power_safety. A UPS on battery contributes
#   softplus((_SOC_THRESHOLD - battery_soc) / _SOC_ALPHA)
# and a UPS in FAULT contributes the fixed _UPS_FAULT_PENALTY; the sum is
# squashed with tanh(penalty / _POWER_NORM). A single faulted UPS therefore
# yields -tanh(5.0 / 4.0) ≈ -0.85.
_SOC_THRESHOLD = 0.5  # Concern increases below 50% SOC
_SOC_ALPHA = 0.15  # Sharp transition around threshold
_UPS_FAULT_PENALTY = 5.0  # Fixed penalty for UPS fault
_POWER_NORM = 4.0  # Normalization constant
# Efficiency
# Used by RewardFunction._efficiency as -tanh((PUE - 1.0) / _PUE_NORM).
_PUE_NORM = 2.0  # PUE sensitivity: PUE=3.0 → R≈-0.76
# Action quality
# Command names exempt from the repeated-command penalty in
# RewardFunction._action_quality.
_REPEAT_WHITELIST = frozenset({"wait", "check_status"})
# ---------------------------------------------------------------------------
# Main reward function
# ---------------------------------------------------------------------------
class RewardFunction:
    """Composable, research-informed reward function for DC operations.

    Six per-step components (thermal safety, power safety, efficiency,
    scenario progress, procedure, action quality) are combined into one
    weighted total, clamped to [-1, 1].

    The only state carried between steps is the previous scenario
    progress value, which the delta-based progress component needs.
    Call ``reset()`` at episode start to clear it.

    Usage:
        rf = RewardFunction(scenario_type="thermal")
        rf.reset()  # Call at episode start
        # Each step:
        components = rf.compute(thermal_sim, power_sim, cmd_result,
                                action_command, action_history, scenario_result)
        reward = components.total
    """

    def __init__(
        self,
        scenario_type: str = "default",
        weights: Optional[RewardWeights] = None,
    ) -> None:
        """Select the weight profile for this reward function.

        Args:
            scenario_type: Key into ``WEIGHT_PROFILES``; unknown keys fall
                back to the ``"default"`` profile.
            weights: Explicit weights. When given, they take precedence
                over the profile selected by ``scenario_type``.
        """
        self._scenario_type = scenario_type
        if weights is not None:
            self._weights = weights
        else:
            self._weights = WEIGHT_PROFILES.get(
                scenario_type, WEIGHT_PROFILES["default"]
            )
        self._prev_progress: float = 0.0

    def reset(self) -> None:
        """Reset state between episodes."""
        self._prev_progress = 0.0

    def compute(
        self,
        thermal_sim: ThermalSimulation,
        power_sim: Optional[PowerSimulation],
        cmd_result: CommandResult,
        action_command: str,
        action_history: list[str],
        scenario_result: Optional[ScenarioResult],
    ) -> RewardComponents:
        """Compute all reward components and the weighted total.

        Args:
            thermal_sim: Thermal simulation whose ``state`` is scored.
            power_sim: Power simulation, or ``None`` when not modelled.
            cmd_result: Parse/execution result of the current command.
            action_command: Raw command text of the current action.
            action_history: Commands issued so far in the episode.
            scenario_result: Scenario evaluation for this step, or ``None``.

        Returns:
            RewardComponents with per-component values and ``total``.
            ``total`` is the weighted sum clamped to [-1, 1].

        Side effect: updates the stored previous-progress value (via
        ``_scenario_progress``), so call this once per step.
        """
        r_thermal = self._thermal_safety(thermal_sim)
        r_power = self._power_safety(power_sim)
        r_efficiency = self._efficiency(thermal_sim, power_sim)
        r_progress = self._scenario_progress(scenario_result)
        r_procedure = self._procedure(scenario_result)
        r_action = self._action_quality(
            cmd_result, action_command, action_history, thermal_sim, power_sim,
        )

        w = self._weights
        total = (
            w.thermal_safety * r_thermal
            + w.power_safety * r_power
            + w.efficiency * r_efficiency
            + w.scenario_progress * r_progress
            + w.procedure * r_procedure
            + w.action_quality * r_action
        )
        total = max(-1.0, min(1.0, total))

        return RewardComponents(
            thermal_safety=r_thermal,
            power_safety=r_power,
            efficiency=r_efficiency,
            scenario_progress=r_progress,
            procedure=r_procedure,
            action_quality=r_action,
            total=total,
        )

    # -------------------------------------------------------------------
    # Component implementations
    # -------------------------------------------------------------------

    @staticmethod
    def _thermal_safety(thermal_sim: ThermalSimulation) -> float:
        """ASHRAE compliance via dual softplus barriers.

        Returns a value in [-1, _SAFE_BASELINE]; 0.0 when there are no zones.

        Two barriers per zone: recommended (gentle) and allowable (steep).
        They are averaged across zones so the signal is independent of
        zone count. Zones whose ASHRAE class is not in ``ASHRAE_CLASSES``
        add no penalty, but still count in the averaging denominator.

        Positive baseline (``_SAFE_BASELINE``) when ALL evaluated zones are
        well within the safe range (at least ``_SAFE_MARGIN_C`` below the
        recommended max). This gives a signal for maintaining a good
        state, not just for avoiding violations.
        (Informed by DCRL-Green, ICLR 2025.)
        """
        zones = thermal_sim.state.zones
        if not zones:
            return 0.0

        n_zones = len(zones)
        penalty = 0.0
        all_safe = True
        for zone in zones:
            ashrae = ASHRAE_CLASSES.get(zone.ashrae_class)
            if not ashrae:
                continue
            t = zone.max_inlet_temp_c
            rec_max = ashrae.recommended_max_c
            allow_max = ashrae.allowable_max_c

            # Written as "not <=" so a NaN temperature counts as unsafe
            # instead of silently earning the positive baseline.
            if not t <= rec_max - _SAFE_MARGIN_C:
                all_safe = False

            # Soft barrier at recommended limit
            penalty += softplus((t - rec_max) / _ALPHA_RECOMMENDED) / n_zones
            # Harder barrier at allowable limit
            penalty += (
                _ALLOWABLE_WEIGHT
                * softplus((t - allow_max) / _ALPHA_ALLOWABLE)
                / n_zones
            )

        # The baseline is gated on the safe margin alone. Softplus is
        # strictly positive (about 0.2 at the margin), so an additional
        # "penalty < 1e-6" requirement could never be met at realistic
        # temperatures and would make this branch unreachable.
        if all_safe:
            return _SAFE_BASELINE
        return -math.tanh(penalty / _THERMAL_NORM)

    @staticmethod
    def _power_safety(power_sim: Optional[PowerSimulation]) -> float:
        """UPS battery and fault condition penalty.

        Returns a value in [-1, 0]; 0.0 when ``power_sim`` is ``None``.

        Each UPS on battery adds a softplus term that grows as its state
        of charge drops below ``_SOC_THRESHOLD``; each UPS in FAULT adds
        the fixed ``_UPS_FAULT_PENALTY``. Contributions are summed across
        units before tanh, so multiple failing units compound.
        """
        if power_sim is None:
            return 0.0

        penalty = 0.0
        for ups in power_sim.state.ups_units:
            if ups.mode == UPSMode.ON_BATTERY:
                penalty += softplus((_SOC_THRESHOLD - ups.battery_soc) / _SOC_ALPHA)
            elif ups.mode == UPSMode.FAULT:
                penalty += _UPS_FAULT_PENALTY
        return -math.tanh(penalty / _POWER_NORM)

    @staticmethod
    def _efficiency(
        thermal_sim: ThermalSimulation,
        power_sim: Optional[PowerSimulation],
    ) -> float:
        """PUE-based energy efficiency penalty.

        Returns a value in [-1, 0].
        PUE 1.0 (ideal) → 0, PUE 2.0 → -0.46, PUE 3.0 → -0.76.

        During power emergencies (any UPS on battery or in fault) the
        signal is suppressed to zero, so the agent is not penalized for
        load shedding that raises PUE but preserves battery life.

        A PUE below 1.0 is not physically meaningful and is scored as 0.0
        rather than as a positive reward.
        """
        # Suppress efficiency signal during power emergencies
        if power_sim is not None:
            for ups in power_sim.state.ups_units:
                if ups.mode in (UPSMode.ON_BATTERY, UPSMode.FAULT):
                    return 0.0

        pue = thermal_sim.state.pue
        # Keep the documented [-1, 0] range: without this guard a PUE
        # under 1.0 would produce a positive value.
        if pue < 1.0:
            return 0.0
        return -math.tanh((pue - 1.0) / _PUE_NORM)

    def _scenario_progress(self, scenario_result: Optional[ScenarioResult]) -> float:
        """Delta-based progress toward scenario resolution.

        Returns a value in [-1, 1]; 0.0 when ``scenario_result`` is ``None``
        (the stored previous progress is then left untouched).

        Rewards the CHANGE in progress since the previous call, which
        credits the action that actually caused the progress. Regressions
        yield a negative value.
        """
        if scenario_result is None:
            return 0.0
        current = scenario_result.progress
        delta = current - self._prev_progress
        self._prev_progress = current
        return max(-1.0, min(1.0, delta))

    @staticmethod
    def _procedure(scenario_result: Optional[ScenarioResult]) -> float:
        """Procedural correctness from scenario rules.

        Returns ``scenario_result.procedure_reward`` clamped to [-1, 1],
        or 0.0 when ``scenario_result`` is ``None``.
        """
        if scenario_result is None:
            return 0.0
        return max(-1.0, min(1.0, scenario_result.procedure_reward))

    @staticmethod
    def _action_quality(
        cmd_result: CommandResult,
        action_command: str,
        action_history: list[str],
        thermal_sim: ThermalSimulation,
        power_sim: Optional[PowerSimulation],
    ) -> float:
        """Action quality assessment.

        Returns one of a fixed set of scores, all between -0.5 and 0.3.
        Considers: validity, repetition, action type, urgency context.

        Scores, checked in this order:
            -0.5  command failed
            -0.2  exact repeat of an earlier command (not whitelisted)
            wait: 0.1 while the generator is starting during a concern,
                  -0.2 during any other active concern, 0.0 otherwise
             0.3  diagnose / check_status
             0.2  active interventions
             0.1  acknowledge_alarm
            -0.1  escalate
             0.1  any other valid command
        """
        if not cmd_result.success:
            return -0.5

        cmd_lower = action_command.strip().lower()
        name = cmd_result.command_name

        # Check for exact repeated command — but whitelist commands that
        # are legitimately repeatable (wait, check_status).
        # The last history entry is excluded from the comparison;
        # presumably it is the command being scored (TODO confirm against
        # the caller).
        if name not in _REPEAT_WHITELIST:
            if any(h.strip().lower() == cmd_lower for h in action_history[:-1]):
                return -0.2

        # "wait" quality depends on whether there's an active concern
        if name == "wait":
            if _has_active_concern(thermal_sim, power_sim):
                # Waiting while the generator works through its startup
                # sequence is acceptable.
                if power_sim is not None and _generator_starting(power_sim):
                    return 0.1  # Waiting for gen to warm up is reasonable
                return -0.2  # Waiting while a thermal or power concern is active
            return 0.0  # Nothing wrong, waiting is fine

        # Information-gathering actions are valuable
        if name in ("diagnose", "check_status"):
            return 0.3

        # Active interventions
        if name in (
            "adjust_setpoint", "set_fan_speed", "set_rack_load",
            "migrate_workload", "start_generator", "stop_generator",
            "set_ups_mode", "start_crac", "stop_crac", "refuel_generator",
        ):
            return 0.2

        # Administrative
        if name == "acknowledge_alarm":
            return 0.1

        # Escalation carries a small fixed penalty here.
        # NOTE(review): an earlier comment said escalation is handled solely
        # by scenario procedure rules with no extra penalty, yet -0.1 is
        # still returned — confirm which is intended.
        if name == "escalate":
            return -0.1

        return 0.1  # Other valid commands
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _has_active_concern(
    thermal_sim: ThermalSimulation,
    power_sim: Optional[PowerSimulation],
) -> bool:
    """Report whether any thermal or power condition currently needs attention.

    True when some zone with a known ASHRAE class has a max inlet
    temperature above that class's recommended maximum, or when any UPS
    is running on battery. Zones with an unknown ASHRAE class are ignored.

    NOTE(review): a UPS in FAULT is not treated as a concern here, while
    ``_power_safety`` and ``_efficiency`` both special-case it — confirm
    this is intended.
    """
    # Thermal side: any zone above its recommended limit.
    for zone in thermal_sim.state.zones:
        limits = ASHRAE_CLASSES.get(zone.ashrae_class)
        if not limits:
            continue
        if zone.max_inlet_temp_c > limits.recommended_max_c:
            return True

    # Power side: any UPS discharging its battery.
    if not power_sim:
        return False
    return any(
        unit.mode == UPSMode.ON_BATTERY for unit in power_sim.state.ups_units
    )
def _generator_starting(power_sim: PowerSimulation) -> bool:
    """Tell whether the generator is partway through starting up.

    True while the generator state is START_DELAY, CRANKING or WARMING.
    ``_action_quality`` uses this to treat "wait" as a reasonable action
    during startup.
    """
    from ..simulation.types import GeneratorState

    startup_states = (
        GeneratorState.START_DELAY,
        GeneratorState.CRANKING,
        GeneratorState.WARMING,
    )
    current_state = power_sim.state.generator.state
    return current_state in startup_states