File size: 7,742 Bytes
82996ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python3
"""Demo script — runs the heuristic baseline against all 7 tasks and prints results.

No API key required. Deterministic, bit-exact reproducible.

Usage:
    python demo.py
"""

from __future__ import annotations

import json
import sys
import time

from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment

# Tasks exercised by the demo, in run order, as (task_id, display name, difficulty).
# main() iterates this list and divides by its length to compute the average score.
ALL_TASKS = [
    ("task_001", "Exploding Gradients", "easy"),
    ("task_002", "Vanishing Gradients", "easy"),
    ("task_003", "Data Leakage", "medium"),
    ("task_004", "Overfitting", "medium"),
    ("task_005", "BatchNorm Eval Mode", "hard"),
    ("task_006", "Code Bug", "hard"),
    ("task_007", "Scheduler Misconfigured", "medium-hard"),
]


def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def _episode_result(env: MLTrainingEnvironment, start: float) -> tuple[float, list[str], float]:
    """Package the session's last score, its action log and the time elapsed since *start*."""
    # Relies on the environment's private session accessor.
    session = env._get_session()
    return (session.last_score or 0.0, session.state.actions_taken, time.perf_counter() - start)


def _fix_and_diagnose(
    env: MLTrainingEnvironment,
    fixes: list[MLTrainingAction],
    diagnosis: str,
    start: float,
) -> tuple[float, list[str], float]:
    """Apply *fixes*, restart the run if any fix was applied, then submit *diagnosis*.

    Returns the same ``(score, actions_taken, elapsed_seconds)`` triple as
    ``run_demo_episode``.
    """
    for action in fixes:
        env.step(action)
    if fixes:
        # Only restart when something was changed; a bare diagnosis needs no rerun.
        env.step(MLTrainingAction(action_type="restart_run"))
    env.step(MLTrainingAction(action_type="mark_diagnosed", diagnosis=diagnosis))
    return _episode_result(env, start)


def _detect_overfitting(obs) -> bool:
    """Return True when *obs*'s loss curves match the heuristic overfitting pattern.

    Needs at least 10 validation-loss points. Training loss must have halved
    (first-5 vs last-5 mean) or ended below 0.15, while validation loss stalled
    (late mean >= 95% of early mean) or the train/val gap widened. The data
    batch must also report ``class_overlap_score < 0.3``.
    """
    train = obs.training_loss_history
    val = obs.val_loss_history
    if not (val and train and len(val) >= 10):
        return False
    # Means use the real slice length, so a short training history is not
    # silently averaged over 5 as if it had 5 points.
    early_train, late_train = _mean(train[:5]), _mean(train[-5:])
    early_val, late_val = _mean(val[:5]), _mean(val[-5:])
    train_dropped = late_train < early_train * 0.5
    train_loss_low = late_train < 0.15
    val_not_improving = late_val >= early_val * 0.95
    gap_widening = (late_val - late_train) > (early_val - early_train)
    if not ((train_dropped or train_loss_low) and (val_not_improving or gap_widening)):
        return False
    stats = obs.data_batch_stats
    return bool(stats and stats.class_overlap_score < 0.3)


def run_demo_episode(task_id: str, seed: int = 42) -> tuple[float, list[str], float]:
    """Run one heuristic episode. Returns (score, actions_taken, elapsed_seconds).

    The heuristic inspects gradients, the data batch, model modes and code, in
    that order, and stops at the first check that matches a known pattern. It
    applies the corresponding fix, restarts the run when a fix was applied, and
    submits a diagnosis. The scheduler and overfitting checks run only when none
    of the earlier checks matched.

    Args:
        task_id: Task identifier passed to ``env.reset`` (e.g. ``"task_001"``).
        seed: Seed passed to ``env.reset``.

    Returns:
        ``(score, actions_taken, elapsed_seconds)``:
        - ``score`` is the session's ``last_score`` (0.0 if unset or falsy).
        - ``elapsed_seconds`` is measured with the monotonic ``time.perf_counter``.
    """
    env = MLTrainingEnvironment()
    start = time.perf_counter()
    env.reset(seed=seed, episode_id=f"demo_{task_id}", task_id=task_id)

    # Step 1: gradient statistics.
    obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
    if obs.gradient_stats:
        if any(g.is_exploding for g in obs.gradient_stats):
            return _fix_and_diagnose(
                env,
                [MLTrainingAction(action_type="modify_config", target="learning_rate", value=0.001)],
                "lr_too_high",
                start,
            )
        if any(g.is_vanishing for g in obs.gradient_stats):
            return _fix_and_diagnose(
                env,
                [MLTrainingAction(action_type="modify_config", target="learning_rate", value=0.01)],
                "vanishing_gradients",
                start,
            )

    # Step 2: data batch. A class_overlap_score above 0.5 is treated as data leakage.
    obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
    if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
        return _fix_and_diagnose(
            env, [MLTrainingAction(action_type="patch_data_loader")], "data_leakage", start
        )

    # Overfitting is detected on this observation but acted on last (see the
    # fallback below), so the code and scheduler checks get a chance to match first.
    looks_like_overfitting = _detect_overfitting(obs)

    # Step 3: any model_mode_info entry equal to "eval" is treated as the BatchNorm eval-mode bug.
    obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
    if obs.model_mode_info and any(mode == "eval" for mode in obs.model_mode_info.values()):
        return _fix_and_diagnose(
            env, [MLTrainingAction(action_type="fix_model_mode")], "batchnorm_eval_mode", start
        )

    # Step 4: source inspection. Any returned snippet is diagnosed as a code bug.
    # A fix (and restart) is attempted only for the two recognised patterns.
    obs = env.step(MLTrainingAction(action_type="inspect_code"))
    if obs.code_snippet:
        code = obs.code_snippet.code
        fixes: list[MLTrainingAction] = []
        # NOTE(review): the line numbers are hard-coded and must match the snippet
        # the environment serves; confirm against the task definitions.
        if "model.eval()" in code and "model.train()" not in code:
            fixes.append(MLTrainingAction(action_type="fix_code", line=5, replacement="model.train()"))
        elif ".detach()" in code:
            fixes.append(
                MLTrainingAction(
                    action_type="fix_code", line=14, replacement="        loss = criterion(output, batch_y)"
                )
            )
        return _fix_and_diagnose(env, fixes, "code_bug", start)

    # Step 5: scheduler. The pattern is that loss fell early and then plateaued.
    # +inf late values are ignored. If all three are +inf, the late mean is 0.0.
    # NOTE(review): in that case the plateau test passes whenever mid_loss < 0.3.
    history = obs.training_loss_history
    if history and len(history) >= 10:
        early_loss = _mean(history[:3])
        mid_loss = _mean(history[5:8])
        late_loss = _mean([v for v in history[-3:] if v != float("inf")])
        if early_loss > mid_loss and abs(late_loss - mid_loss) < 0.3:
            return _fix_and_diagnose(
                env,
                [MLTrainingAction(action_type="modify_config", target="learning_rate", value=0.001)],
                "scheduler_misconfigured",
                start,
            )

    # Overfitting fallback: add weight decay, restart, diagnose.
    if looks_like_overfitting:
        return _fix_and_diagnose(
            env,
            [MLTrainingAction(action_type="modify_config", target="weight_decay", value=0.01)],
            "overfitting",
            start,
        )

    # Nothing matched: submit "overfitting" without any fix or restart.
    return _fix_and_diagnose(env, [], "overfitting", start)


def main() -> None:
    """Run the heuristic baseline on every task in ``ALL_TASKS`` and print a report.

    For each task it prints a summary line and the action trace. It then
    prints the average score, the number of tasks scoring at least 0.95, and a
    JSON mapping of task_id to score rounded to 4 decimal places.
    """
    solved_threshold = 0.95  # a score at or above this counts as "solved"
    num_tasks = len(ALL_TASKS)

    print("=" * 70)
    print("PyTorch Training Run Debugger — Demo")
    # Derive the count from ALL_TASKS so the banner cannot drift from the task list.
    print(f"Running heuristic baseline on all {num_tasks} tasks")
    print("=" * 70)
    print()

    results: list[tuple[str, float]] = []
    for task_id, name, difficulty in ALL_TASKS:
        score, actions, elapsed = run_demo_episode(task_id)
        results.append((task_id, score))

        trace = " -> ".join(a.replace("mark_diagnosed:", "diag:") for a in actions)
        print(f"  {task_id} | {name:<25} | {difficulty:<10} | score={score:.2f} | {len(actions)} steps | {elapsed:.2f}s")
        print(f"           actions: {trace}")
        print()

    total_score = sum(score for _, score in results)
    # Guard against an empty task list instead of raising ZeroDivisionError.
    avg = total_score / num_tasks if num_tasks else 0.0
    solved = sum(1 for _, score in results if score >= solved_threshold)

    print("-" * 70)
    print(f"  Average score: {avg:.2f}")
    print(f"  Tasks solved:  {solved}/{num_tasks}")
    print()

    # Also output machine-readable JSON (task_id -> score rounded to 4 places).
    scores = {task_id: round(score, 4) for task_id, score in results}
    print("JSON output:")
    print(json.dumps(scores, indent=2))


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()