File size: 7,025 Bytes
6be6d8e
 
206438f
6be6d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Rule-based heuristic baseline for the /baseline endpoint.

Extracted from app.py to keep route definitions clean.
"""

from __future__ import annotations

import math

from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment

# The seven benchmark task ids, in canonical order: "task_001" .. "task_007".
ALL_TASK_IDS = [f"task_{i:03d}" for i in range(1, 8)]


def run_baseline_all_tasks() -> dict[str, float]:
    """Run the rule-based baseline on all tasks. Returns {task_id: score}."""

    def _score_task(tid: str) -> float:
        # Each task gets a fresh environment, reset with a fixed seed so
        # baseline runs are reproducible.
        environment = MLTrainingEnvironment()
        environment.reset(seed=42, episode_id=f"baseline_{tid}", task_id=tid)
        return round(_run_heuristic_episode(environment), 4)

    return {tid: _score_task(tid) for tid in ALL_TASK_IDS}


def _run_heuristic_episode(
    env: MLTrainingEnvironment, task_id: str = "",
) -> float:
    """Run one heuristic baseline episode. Returns grader score.

    Probes the environment with a fixed sequence of diagnostic actions
    (gradients -> data batch -> model modes -> code) and, as soon as one
    check fires, applies a canned fix, restarts the run, marks a
    diagnosis, and returns the grader score. Falls back to scheduler and
    overfitting heuristics, then to an unconditional "overfitting"
    diagnosis, so a score is always produced.

    Args:
        env: A freshly reset training environment to diagnose.
        task_id: Unused in the body; presumably kept for interface
            symmetry with callers — TODO confirm before removing.

    Returns:
        The grader score for the episode (0.0 when none is available).
    """
    # Step 1: inspect_gradients
    obs = env.step(MLTrainingAction(action_type="inspect_gradients"))

    if obs.gradient_stats:
        # Exploding gradients: drop the learning rate and restart.
        if any(g.is_exploding for g in obs.gradient_stats):
            env.step(MLTrainingAction(
                action_type="modify_config", target="learning_rate", value=0.001,
            ))
            env.step(MLTrainingAction(action_type="restart_run"))
            env.step(MLTrainingAction(
                action_type="mark_diagnosed", diagnosis="lr_too_high",
            ))
            return _get_score(env)

        # Vanishing gradients: raise the learning rate and restart.
        if any(g.is_vanishing for g in obs.gradient_stats):
            env.step(MLTrainingAction(
                action_type="modify_config", target="learning_rate", value=0.01,
            ))
            env.step(MLTrainingAction(action_type="restart_run"))
            env.step(MLTrainingAction(
                action_type="mark_diagnosed", diagnosis="vanishing_gradients",
            ))
            return _get_score(env)

    # Step 2: inspect_data_batch
    obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
    # High class overlap is treated as data leakage: patch the loader.
    if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
        env.step(MLTrainingAction(action_type="patch_data_loader"))
        env.step(MLTrainingAction(action_type="restart_run"))
        env.step(MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="data_leakage",
        ))
        return _get_score(env)

    # Detect overfitting pattern
    # Evaluated now (on the data-batch observation) but only acted on
    # after the cheaper mode/code checks below have been ruled out.
    looks_like_overfitting = _detect_overfitting(obs)

    # Step 3: inspect_model_modes
    obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
    if obs.model_mode_info:
        # Any component stuck in eval mode during training gets fixed.
        if any(v == "eval" for v in obs.model_mode_info.values()):
            env.step(MLTrainingAction(action_type="fix_model_mode"))
            env.step(MLTrainingAction(action_type="restart_run"))
            env.step(MLTrainingAction(
                action_type="mark_diagnosed", diagnosis="batchnorm_eval_mode",
            ))
            return _get_score(env)

    # Step 4: inspect_code (for Task 6)
    obs = env.step(MLTrainingAction(action_type="inspect_code"))
    if obs.code_snippet:
        code = obs.code_snippet.code
        _try_code_fix(env, code)

        # Only restart if the fix attempt actually changed something;
        # NOTE(review): relies on the env's private session accessor.
        session = env._get_session()
        if session and session.state.fix_action_taken:
            env.step(MLTrainingAction(action_type="restart_run"))

        env.step(MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="code_bug",
        ))
        return _get_score(env)

    # Step 5: scheduler issue (loss stagnates after initial progress)
    if _detect_scheduler_issue(obs):
        env.step(MLTrainingAction(
            action_type="modify_config", target="learning_rate", value=0.001,
        ))
        env.step(MLTrainingAction(action_type="restart_run"))
        env.step(MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="scheduler_misconfigured",
        ))
        return _get_score(env)

    # Overfitting fallback
    if looks_like_overfitting:
        env.step(MLTrainingAction(
            action_type="modify_config", target="weight_decay", value=0.01,
        ))
        env.step(MLTrainingAction(action_type="restart_run"))
        env.step(MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="overfitting",
        ))
        return _get_score(env)

    # Final fallback: guess "overfitting" with no fix, so the episode
    # still ends with a diagnosis and yields a score.
    env.step(MLTrainingAction(
        action_type="mark_diagnosed", diagnosis="overfitting",
    ))
    return _get_score(env)


def _try_code_fix(env: MLTrainingEnvironment, code: str) -> None:
    """Attempt to fix a detected code bug."""
    if "model.eval()" in code and "model.train()" not in code:
        env.step(MLTrainingAction(
            action_type="fix_code", line=5, replacement="model.train()",
        ))
    elif ".detach()" in code:
        env.step(MLTrainingAction(
            action_type="fix_code", line=14,
            replacement="        loss = criterion(output, batch_y)",
        ))
    elif "inplace=True" in code:
        env.step(MLTrainingAction(
            action_type="fix_code", line=15,
            replacement="        output = F.relu(output)",
        ))
    elif "optimizer.zero_grad()" not in code and "optimizer.step()" in code:
        env.step(MLTrainingAction(
            action_type="fix_code", line=11,
            replacement="        optimizer.zero_grad()",
        ))


def _detect_overfitting(obs: object) -> bool:
    """Detect overfitting pattern from observation."""
    if not (obs.val_loss_history and obs.training_loss_history
            and len(obs.val_loss_history) >= 10):
        return False
    early_train = sum(obs.training_loss_history[:5]) / 5
    late_train = sum(obs.training_loss_history[-5:]) / 5
    early_val = sum(obs.val_loss_history[:5]) / 5
    late_val = sum(obs.val_loss_history[-5:]) / 5
    train_dropped = late_train < early_train * 0.5
    train_loss_low = late_train < 0.15
    val_not_improving = late_val >= early_val * 0.95
    gap_widening = (late_val - late_train) > (early_val - early_train)
    return (
        (train_dropped or train_loss_low)
        and (val_not_improving or gap_widening)
        and obs.data_batch_stats
        and obs.data_batch_stats.class_overlap_score < 0.3
    )


def _detect_scheduler_issue(obs: object) -> bool:
    """Detect scheduler misconfiguration from loss history."""
    if not (obs.training_loss_history and len(obs.training_loss_history) >= 10):
        return False
    early_loss = sum(obs.training_loss_history[:3]) / 3
    mid_loss = sum(obs.training_loss_history[5:8]) / 3
    finite_late = [v for v in obs.training_loss_history[-3:] if v != float("inf")]
    late_loss = sum(finite_late) / max(len(finite_late), 1)
    return early_loss > mid_loss and abs(late_loss - mid_loss) < 0.3


def _get_score(env: MLTrainingEnvironment) -> float:
    """Extract the grader score from the environment."""
    session = env._get_session()
    if session and session.last_score is not None:
        return session.last_score
    return 0.0