| """Reward function — 7 components, separate from graders. |
| |
| Returns a float per step for RL training signal. Hard cap at [-1.0, 1.0]. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
|
|
| from ml_training_debugger.models import EpisodeState, MLTrainingAction |
| from ml_training_debugger.scenarios import ScenarioParams |
|
|
| |
# Per-step time penalty: discourages needlessly long episodes.
STEP_PENALTY = -0.01
# Bonus for performing an investigation action for the first time.
INVESTIGATION_BONUS = 0.05
# Penalty for add_callback when gradients were already inspected and found normal.
CONTEXT_GATED_PENALTY = -0.20
# Penalty for taking an action that is not currently available.
INVALID_ACTION_PENALTY = -0.05
# Penalty for a fix_code attempt that is known to be incorrect.
WRONG_CODE_FIX_PENALTY = -0.10
# Reward for diagnosing the scenario's true root cause.
CORRECT_DIAGNOSIS_REWARD = 0.50
# Penalty for committing to a wrong diagnosis.
WRONG_DIAGNOSIS_PENALTY = -0.30
# Reward for restart_run after a fix, once convergence is confirmed.
TERMINAL_CONVERGENCE_REWARD = 0.40
|
|
| INVESTIGATION_ACTIONS = frozenset( |
| { |
| "inspect_gradients", |
| "inspect_data_batch", |
| "inspect_model_modes", |
| "inspect_model_weights", |
| "inspect_code", |
| } |
| ) |
|
|
| _INSPECTION_STATE_MAP = { |
| "inspect_gradients": "gradients_inspected", |
| "inspect_data_batch": "data_inspected", |
| "inspect_model_modes": "model_modes_inspected", |
| "inspect_model_weights": "model_weights_inspected", |
| "inspect_code": "code_inspected", |
| } |
|
|
|
|
def compute_reward(
    action: MLTrainingAction,
    state: EpisodeState,
    scenario: ScenarioParams,
    is_valid_action: bool = True,
    is_correct_fix: bool | None = None,
    convergence_confirmed: bool = False,
) -> float:
    """Compute the reward for one step of an episode.

    Args:
        action: The action taken.
        state: Episode state BEFORE the action is applied.
        scenario: Current scenario params.
        is_valid_action: Whether the action is in available_actions.
        is_correct_fix: For fix_code — True/False/None.
        convergence_confirmed: Whether restart showed convergence.

    Returns:
        Reward float, capped at [-1.0, 1.0].
    """

    def _clamp(value: float) -> float:
        # Hard cap promised by the module contract.
        return max(-1.0, min(1.0, value))

    # Every step pays the time penalty, valid or not.
    total = STEP_PENALTY

    # Invalid actions short-circuit: penalty only, nothing else applies.
    if not is_valid_action:
        return _clamp(total + INVALID_ACTION_PENALTY)

    kind = action.action_type

    # First-time investigation bonus — repeat inspections earn nothing.
    if kind in INVESTIGATION_ACTIONS:
        flag = _INSPECTION_STATE_MAP.get(kind)
        if flag is not None and not getattr(state, flag):
            total += INVESTIGATION_BONUS

    # Context-gated penalty: adding a callback after gradients were checked
    # and found normal is wasted effort.
    if kind == "add_callback" and state.gradients_inspected and state.gradients_were_normal:
        total += CONTEXT_GATED_PENALTY

    # Only an explicitly-wrong fix is penalized; None (unknown) is neutral.
    if kind == "fix_code" and is_correct_fix is False:
        total += WRONG_CODE_FIX_PENALTY

    # Diagnosis is all-or-nothing against the scenario's root cause.
    if kind == "mark_diagnosed":
        total += (
            CORRECT_DIAGNOSIS_REWARD
            if action.diagnosis == scenario.root_cause.value
            else WRONG_DIAGNOSIS_PENALTY
        )

    # Terminal reward requires both a prior fix and confirmed convergence.
    if kind == "restart_run" and state.fix_action_taken and convergence_confirmed:
        total += TERMINAL_CONVERGENCE_REWARD

    return _clamp(total)
|
|