"""Reward function — 7 components, separate from graders.
Returns a float per step for RL training signal. Hard cap at [-1.0, 1.0].
"""
from __future__ import annotations
import torch # noqa: F401
from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.scenarios import ScenarioParams

# Reward constants
STEP_PENALTY = -0.01
INVESTIGATION_BONUS = 0.05
CONTEXT_GATED_PENALTY = -0.20
INVALID_ACTION_PENALTY = -0.05
WRONG_CODE_FIX_PENALTY = -0.10
CORRECT_DIAGNOSIS_REWARD = 0.50
WRONG_DIAGNOSIS_PENALTY = -0.30
TERMINAL_CONVERGENCE_REWARD = 0.40
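
# Note on magnitudes: with the current constants, no single step can exceed
# the clamp in compute_reward. The largest possible step reward is
# STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD (0.49) and the smallest is
# STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY (-0.31), so the [-1.0, 1.0] cap
# only binds if these constants are retuned.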

INVESTIGATION_ACTIONS = frozenset(
    {
        "inspect_gradients",
        "inspect_data_batch",
        "inspect_model_modes",
        "inspect_model_weights",
        "inspect_code",
    }
)

_INSPECTION_STATE_MAP = {
    "inspect_gradients": "gradients_inspected",
    "inspect_data_batch": "data_inspected",
    "inspect_model_modes": "model_modes_inspected",
    "inspect_model_weights": "model_weights_inspected",
    "inspect_code": "code_inspected",
}
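
# The first-time bonus logic assumes every investigation action has a
# matching inspection flag on EpisodeState; keep these two structures in sync.
assert frozenset(_INSPECTION_STATE_MAP) == INVESTIGATION_ACTIONS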


def compute_reward(
    action: MLTrainingAction,
    state: EpisodeState,
    scenario: ScenarioParams,
    is_valid_action: bool = True,
    is_correct_fix: bool | None = None,
    convergence_confirmed: bool = False,
) -> float:
    """Compute the reward for a single step.

    Args:
        action: The action taken.
        state: Episode state BEFORE the action is applied.
        scenario: Current scenario params.
        is_valid_action: Whether the action is in available_actions.
        is_correct_fix: For fix_code actions: True if the fix was correct,
            False if it was wrong, None if no fix was attempted.
        convergence_confirmed: Whether the restarted run showed convergence.

    Returns:
        Reward as a float, clamped to [-1.0, 1.0].
    """
    reward = 0.0

    # Step penalty (unconditional)
    reward += STEP_PENALTY

    if not is_valid_action:
        reward += INVALID_ACTION_PENALTY
        return max(-1.0, min(1.0, reward))

    action_type = action.action_type

    # Investigation bonus (first-time only)
    if action_type in INVESTIGATION_ACTIONS:
        state_field = _INSPECTION_STATE_MAP.get(action_type)
        if state_field and not getattr(state, state_field):
            reward += INVESTIGATION_BONUS

    # Context-gated penalty: adding gradient clipping after seeing normal gradients
    if action_type == "add_callback":
        if state.gradients_inspected and state.gradients_were_normal:
            reward += CONTEXT_GATED_PENALTY

    # Wrong code fix
    if action_type == "fix_code" and is_correct_fix is False:
        reward += WRONG_CODE_FIX_PENALTY

    # Diagnosis
    if action_type == "mark_diagnosed":
        if action.diagnosis == scenario.root_cause.value:
            reward += CORRECT_DIAGNOSIS_REWARD
        else:
            reward += WRONG_DIAGNOSIS_PENALTY

    # Convergence after fix + restart
    if action_type == "restart_run":
        if state.fix_action_taken and convergence_confirmed:
            reward += TERMINAL_CONVERGENCE_REWARD

    return max(-1.0, min(1.0, reward))
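

# Illustrative usage (a sketch: the constructor calls below are assumptions,
# not the verified ml_training_debugger API):
#
#     action = MLTrainingAction(action_type="inspect_gradients")
#     state = EpisodeState()  # assumes inspection flags default to False
#     reward = compute_reward(action, state, scenario)  # scenario: ScenarioParams
#     # First-time inspection: STEP_PENALTY + INVESTIGATION_BONUS = 0.04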