"""Reward function — 7 components, separate from graders.

Returns a float per step for RL training signal. Hard cap at [-1.0, 1.0].
"""

from __future__ import annotations

import torch  # noqa: F401

from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.scenarios import ScenarioParams

# Reward constants
STEP_PENALTY = -0.01
INVESTIGATION_BONUS = 0.05
CONTEXT_GATED_PENALTY = -0.20
INVALID_ACTION_PENALTY = -0.05
WRONG_CODE_FIX_PENALTY = -0.10
CORRECT_DIAGNOSIS_REWARD = 0.50
WRONG_DIAGNOSIS_PENALTY = -0.30
TERMINAL_CONVERGENCE_REWARD = 0.40

INVESTIGATION_ACTIONS = frozenset(
    {
        "inspect_gradients",
        "inspect_data_batch",
        "inspect_model_modes",
        "inspect_model_weights",
        "inspect_code",
    }
)

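# Maps each investigation action to the EpisodeState boolean that records
# whether it has already been performed; compute_reward consults this so the
# INVESTIGATION_BONUS is granted only on the first occurrence of each action.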
_INSPECTION_STATE_MAP = {
    "inspect_gradients": "gradients_inspected",
    "inspect_data_batch": "data_inspected",
    "inspect_model_modes": "model_modes_inspected",
    "inspect_model_weights": "model_weights_inspected",
    "inspect_code": "code_inspected",
}


def compute_reward(
    action: MLTrainingAction,
    state: EpisodeState,
    scenario: ScenarioParams,
    is_valid_action: bool = True,
    is_correct_fix: bool | None = None,
    convergence_confirmed: bool = False,
) -> float:
    """Compute reward for a single step.

    Args:
        action: The action taken.
        state: Episode state BEFORE the action is applied.
        scenario: Current scenario params.
        is_valid_action: Whether the action is in available_actions.
        is_correct_fix: For fix_code actions: True if the fix was correct,
            False if it was wrong, None when not applicable.
        convergence_confirmed: Whether the restarted run showed convergence.

    Returns:
        Reward float, capped at [-1.0, 1.0].
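
    Example (illustrative only): the SimpleNamespace stand-ins and the
    "bad_lr" root-cause string below are assumptions that mimic just the
    attributes this function reads; they are not the real
    ml_training_debugger models.

        >>> from types import SimpleNamespace
        >>> act = SimpleNamespace(action_type="mark_diagnosed", diagnosis="bad_lr")
        >>> st = SimpleNamespace()
        >>> sc = SimpleNamespace(root_cause=SimpleNamespace(value="bad_lr"))
        >>> round(compute_reward(act, st, sc), 2)  # -0.01 step + 0.50 diagnosis
        0.49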
    """
    reward = 0.0

    # Step penalty (unconditional)
    reward += STEP_PENALTY

    if not is_valid_action:
        reward += INVALID_ACTION_PENALTY
        return max(-1.0, min(1.0, reward))

    action_type = action.action_type

    # Investigation bonus — first-time only
    if action_type in INVESTIGATION_ACTIONS:
        state_field = _INSPECTION_STATE_MAP.get(action_type)
        if state_field and not getattr(state, state_field):
            reward += INVESTIGATION_BONUS

    # Context-gated penalty: adding a callback (e.g. gradient clipping) when
    # gradients were already inspected and found to be normal
    if action_type == "add_callback":
        if state.gradients_inspected and state.gradients_were_normal:
            reward += CONTEXT_GATED_PENALTY

    # Wrong code fix
    if action_type == "fix_code" and is_correct_fix is False:
        reward += WRONG_CODE_FIX_PENALTY

    # Diagnosis
    if action_type == "mark_diagnosed":
        if action.diagnosis == scenario.root_cause.value:
            reward += CORRECT_DIAGNOSIS_REWARD
        else:
            reward += WRONG_DIAGNOSIS_PENALTY

    # Convergence after fix+restart
    if action_type == "restart_run":
        if state.fix_action_taken and convergence_confirmed:
            reward += TERMINAL_CONVERGENCE_REWARD

    return max(-1.0, min(1.0, reward))
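

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training loop).
# The SimpleNamespace objects below duck-type MLTrainingAction, EpisodeState,
# and ScenarioParams, assuming only the attributes compute_reward actually
# reads; the field values and the "bad_lr" root cause are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    scenario = SimpleNamespace(root_cause=SimpleNamespace(value="bad_lr"))

    # First-time gradient inspection: step penalty plus investigation bonus.
    inspect = SimpleNamespace(action_type="inspect_gradients", diagnosis=None)
    fresh = SimpleNamespace(gradients_inspected=False)
    print(round(compute_reward(inspect, fresh, scenario), 2))  # -0.01 + 0.05 = 0.04

    # Adding a callback after gradients were inspected and looked normal
    # triggers the context-gated penalty on top of the step cost.
    callback = SimpleNamespace(action_type="add_callback", diagnosis=None)
    seen_normal = SimpleNamespace(
        gradients_inspected=True,
        gradients_were_normal=True,
    )
    print(round(compute_reward(callback, seen_normal, scenario), 2))  # -0.01 - 0.20 = -0.21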