"""ScenarioParams and scenario sampling.
Internal scenario configuration — not exposed to the agent.
"""
from __future__ import annotations
import dataclasses
from typing import Optional
import torch
from ml_training_debugger.models import RootCauseDiagnosis
@dataclasses.dataclass(frozen=True)
class ScenarioParams:
    """Immutable bundle of settings describing one debugging scenario.

    Instances are built by ``sample_scenario``. Only ``task_id``,
    ``root_cause`` and ``seed`` are required; every other field has a
    default and is overridden only by the tasks that vary it.
    """

    # --- Required identity fields ---
    task_id: str
    root_cause: RootCauseDiagnosis
    # sample_scenario stores the task-derived seed here, not the caller's seed.
    seed: int

    # --- Optimizer / data fault knobs ---
    learning_rate: float = 1e-3  # sampled for task_001 and task_002
    weight_decay: float = 1e-4  # sampled for task_004
    leakage_pct: float = 0.0  # sampled for task_003
    depth_multiplier: float = 1.0  # sampled for task_002
    divergence_epoch: int = 5  # sampled for task_004

    # --- Red-herring controls (sampled for task_005) ---
    red_herring_intensity: float = 1.0
    red_herring_spike_layer: str = "fc"

    # --- Free-text context ---
    bug_type: Optional[str] = None  # set for task_006
    notes: Optional[str] = None
    error_log: Optional[str] = None

    # --- Episode / environment settings ---
    gpu_memory_used_gb: float = 6.2
    max_steps: int = 20  # presumably the per-episode step budget; confirm with caller
    model_type: str = "cnn"  # sample_scenario uses "cnn" or "mlp"
    difficulty_level: int = 3  # documented by sample_scenario as 1 (easy) to 5 (ambiguous)

    # --- LR scheduler settings (sampled for task_007) ---
    scheduler_gamma: float = 0.1
    scheduler_step_size: int = 10
def _task_seed(task_id: str, seed: int) -> int:
"""Derive a deterministic seed from task_id and provided seed."""
task_num = int(task_id.split("_")[1])
return seed * 1000 + task_num
def _choose(options: list, rng: torch.Generator) -> object:
"""Choose a random element from a list using torch RNG."""
idx = int(torch.randint(0, len(options), (1,), generator=rng).item())
return options[idx]
def _pick_model_type(rng: torch.Generator) -> str:
    """Draw the model architecture name, either ``"cnn"`` or ``"mlp"``.

    Consumes exactly one ``_choose`` draw from ``rng``.
    """
    architectures = ["cnn", "mlp"]
    picked = _choose(architectures, rng)
    return str(picked)
def sample_scenario(
    task_id: str, seed: int = 42, difficulty_level: int = 3
) -> ScenarioParams:
    """Sample a ScenarioParams for the given task.

    A torch generator is seeded from ``_task_seed(task_id, seed)`` and every
    random draw goes through it, so the same ``(task_id, seed)`` pair always
    yields the same scenario. The derived seed, not the caller's ``seed``, is
    stored on the returned object.

    Args:
        task_id: One of task_001 through task_007.
        seed: Base seed for reproducibility.
        difficulty_level: 1 (easy signals) to 5 (max ambiguity). Default 3.
            Passed through unchanged; it does not affect sampling here.

    Returns:
        ScenarioParams with randomized fault parameters.

    Raises:
        ValueError: If task_id is unknown or malformed.
    """
    try:
        effective_seed = _task_seed(task_id, seed)
    except (IndexError, ValueError) as err:
        # Seed derivation parses the id and can fail before the lookup below;
        # report that as the documented ValueError.
        raise ValueError(f"Unknown task_id: {task_id}") from err

    rng = torch.Generator()
    rng.manual_seed(effective_seed)

    # Keyword arguments identical across all tasks.
    shared = {
        "task_id": task_id,
        "seed": effective_seed,
        "difficulty_level": difficulty_level,
    }

    # NOTE: within each branch the order of RNG draws (fault parameters first,
    # model type last) determines the sampled values; do not reorder.

    if task_id == "task_001":
        lr = _choose([0.05, 0.08, 0.10, 0.15, 0.30], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
            learning_rate=float(lr),
            error_log=f"RuntimeError: Loss is NaN at epoch 12 (lr={lr})",
            max_steps=20,
            model_type=_pick_model_type(rng),
            **shared,
        )

    if task_id == "task_002":
        lr = _choose([1e-6, 5e-6, 1e-5], rng)
        depth_mult = _choose([1.0, 1.5, 2.0], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.VANISHING_GRADIENTS,
            learning_rate=float(lr),
            depth_multiplier=float(depth_mult),
            notes=(
                "Training resumed from a checkpoint saved at epoch 0 — "
                "early learning rate warmup may still be in effect."
            ),
            max_steps=20,
            model_type=_pick_model_type(rng),
            **shared,
        )

    if task_id == "task_003":
        leakage = _choose([0.12, 0.18, 0.22, 0.28], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.DATA_LEAKAGE,
            leakage_pct=float(leakage),
            notes=(
                "Model architecture upgraded from 2-layer to 4-layer CNN "
                "at epoch 2. Performance improvement may reflect increased "
                "model capacity."
            ),
            max_steps=25,
            model_type=_pick_model_type(rng),
            **shared,
        )

    if task_id == "task_004":
        wd = _choose([0.0, 0.0001, 0.001], rng)
        div_epoch = _choose([5, 8, 12], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.OVERFITTING,
            weight_decay=float(wd),
            divergence_epoch=int(div_epoch),
            notes=(
                "Dataset augmentation was disabled for this run to speed "
                "up training. Re-enabling may improve generalization."
            ),
            max_steps=25,
            model_type=_pick_model_type(rng),
            **shared,
        )

    if task_id == "task_005":
        intensity = torch.empty(1).uniform_(0.8, 2.5, generator=rng).item()
        spike_layer = _choose(["fc", "conv1"], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.BATCHNORM_EVAL_MODE,
            red_herring_intensity=float(intensity),
            red_herring_spike_layer=str(spike_layer),
            gpu_memory_used_gb=14.56,
            error_log=(
                "Warning: GPU memory pressure detected, consider reducing "
                "batch size or enabling gradient checkpointing"
            ),
            max_steps=30,
            # CNN always for BatchNorm eval — MLP BatchNorm1d behaves differently
            model_type="cnn",
            **shared,
        )

    if task_id == "task_006":
        bug = _choose(
            ["eval_mode", "detach_loss", "zero_grad_missing", "inplace_relu"], rng
        )
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.CODE_BUG,
            bug_type=str(bug),
            notes="Try adjusting the learning rate schedule.",
            max_steps=30,
            # Code templates reference CNN training — keep CNN for consistency
            model_type="cnn",
            **shared,
        )

    if task_id == "task_007":
        gamma = _choose([0.01, 0.001, 0.0001], rng)
        step_size = _choose([2, 3, 5], rng)
        return ScenarioParams(
            root_cause=RootCauseDiagnosis.SCHEDULER_MISCONFIGURED,
            scheduler_gamma=float(gamma),
            scheduler_step_size=int(step_size),
            notes="LR scheduler was recently added to improve convergence.",
            max_steps=25,
            model_type=_pick_model_type(rng),
            **shared,
        )

    raise ValueError(f"Unknown task_id: {task_id}")
|