"""All Pydantic models, enums, and typed data structures.

No business logic. Pure data definitions.
Spec reference: Section 10 — Data Models.
"""
from __future__ import annotations
import enum
from typing import Optional, Union
import torch # noqa: F401 — PyTorch-native project, required import
from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field
class RootCauseDiagnosis(str, enum.Enum):
    """Fixed vocabulary of root-cause labels for an ML training failure.

    The ``str`` mix-in makes each member compare equal to its raw string
    value, so a member and its wire string are interchangeable in
    comparisons. Spec reference: Section 10.
    """

    LR_TOO_HIGH = "lr_too_high"
    VANISHING_GRADIENTS = "vanishing_gradients"
    DATA_LEAKAGE = "data_leakage"
    OVERFITTING = "overfitting"
    BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
    CODE_BUG = "code_bug"
    SCHEDULER_MISCONFIGURED = "scheduler_misconfigured"
# Raw string values of every RootCauseDiagnosis member, for plain-string
# membership checks.
VALID_DIAGNOSES: set[str] = {
    diagnosis.value for diagnosis in RootCauseDiagnosis
}
class TrainingConfig(BaseModel):
    """Typed bundle of training hyperparameters.

    Every field has a default, so ``TrainingConfig()`` is valid with no
    arguments. Spec reference: Section 10.
    """

    # Optimisation settings.
    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    batch_size: int = 64

    # Model shape.
    hidden_dim: int = 64
    num_layers: int = 3

    # Optimizer identifier; any string is accepted by this model.
    optimizer: str = "adam"

    dropout_rate: float = 0.0

    # NOTE(review): None presumably means "no gradient clipping" — confirm
    # against the code that consumes this config.
    gradient_clip_norm: Optional[float] = None
# Names of all fields declared on TrainingConfig (iterating the
# ``model_fields`` mapping yields its keys).
VALID_CONFIG_KEYS: set[str] = set(TrainingConfig.model_fields)
class GradientStats(BaseModel):
    """Gradient summary for a single layer, taken from real torch.autograd.

    All fields are required. Spec reference: Section 10.
    """

    layer_name: str
    norm_history: list[float]
    mean_norm: float
    max_norm: float

    # The two flags are stored as given; this model does not derive them.
    # Producers are expected to follow the thresholds below.
    is_exploding: bool  # True when mean_norm > 10.0
    is_vanishing: bool  # True when mean_norm < 1e-6
class ModelWeightStats(BaseModel):
    """Weight summary for a single layer, taken from a real state_dict().

    Spec reference: Section 10.
    """

    layer_name: str

    # Required scalar statistics of the layer's weights.
    weight_norm: float
    weight_mean: float
    weight_std: float
    weight_min: float
    weight_max: float

    # Optional health indicators; the defaults describe a healthy layer.
    dead_neuron_pct: float = 0.0
    has_nan: bool = False
    has_inf: bool = False
class DataBatchStats(BaseModel):
    """Result of inspecting one data batch.

    Spec reference: Section 10.
    """

    # NOTE(review): presumably maps class label -> share of the batch;
    # confirm against the code that fills it in.
    label_distribution: dict[int, float]

    feature_mean: float
    feature_std: float
    null_count: int = 0
    class_overlap_score: float
    batch_size: int
    duplicate_ratio: float = 0.0

    # Left as None when no confusion matrix is supplied.
    confusion_matrix: Optional[list[list[float]]] = None
class CodeSnippet(BaseModel):
    """PyTorch source code handed out for Task 6 inspection.

    Spec reference: Section 10.
    """

    code: str
    filename: str = "train.py"

    # Supplied by the caller; this model does not compute it from ``code``.
    line_count: int

    imports: list[str]

    # Optional hint text; None when no hint is given.
    hint: Optional[str] = None
class EpisodeState(BaseModel):
    """Record of what the agent has done so far within one episode.

    A fresh instance has a zero step count, every flag False, and an empty
    action history. Spec reference: Section 10.
    """

    step_count: int = 0

    # Inspection flags.
    gradients_inspected: bool = False
    gradients_were_normal: bool = False
    data_inspected: bool = False
    model_modes_inspected: bool = False
    model_weights_inspected: bool = False
    code_inspected: bool = False

    # Fix / restart / diagnosis flags.
    fix_action_taken: bool = False
    restart_after_fix: bool = False
    diagnosis_submitted: bool = False

    # Action history; each instance gets its own list via default_factory.
    actions_taken: list[str] = Field(default_factory=list)

    def compute_available_actions(self) -> list[str]:
        """Return the action names currently open to the agent.

        Ten inspect/fix actions are always offered. Four more are gated on
        the episode flags (spec Section 10 — Dynamic available_actions):

        - ``fix_code``: offered once ``code_inspected`` is set
        - ``restart_run``: offered once ``fix_action_taken`` is set
        - ``rollback_checkpoint``: offered once ``restart_after_fix`` is set
        - ``mark_diagnosed``: offered until ``diagnosis_submitted`` is set

        Returns:
            A new list on every call: the always-on actions first, followed
            by the unlocked gated actions in the order listed above.
        """
        always_on: list[str] = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_modes",
            "inspect_model_weights",
            "inspect_code",
            "modify_config",
            "add_callback",
            "replace_optimizer",
            "patch_data_loader",
            "fix_model_mode",
        ]
        # (unlocked?, action name) pairs; tuple order fixes the output order.
        gated = (
            (self.code_inspected, "fix_code"),
            (self.fix_action_taken, "restart_run"),
            (self.restart_after_fix, "rollback_checkpoint"),
            (not self.diagnosis_submitted, "mark_diagnosed"),
        )
        return always_on + [name for unlocked, name in gated if unlocked]
# Every action name known to this module, built from three groups.
ALL_ACTION_TYPES: set[str] = (
    # Inspection actions.
    {
        "inspect_gradients",
        "inspect_data_batch",
        "inspect_model_modes",
        "inspect_model_weights",
        "inspect_code",
    }
    # Fix actions that EpisodeState.compute_available_actions always offers.
    | {
        "modify_config",
        "add_callback",
        "replace_optimizer",
        "patch_data_loader",
        "fix_model_mode",
    }
    # Actions that compute_available_actions offers only conditionally.
    | {
        "fix_code",
        "restart_run",
        "mark_diagnosed",
        "rollback_checkpoint",
    }
)
class MLTrainingAction(Action):
    """A single agent action; extends the openenv ``Action`` base.

    Only ``action_type`` is required. The remaining fields default to None.
    Spec reference: Section 10.
    """

    # Declared as a plain string: this model does not itself restrict the
    # value to ALL_ACTION_TYPES.
    action_type: str

    # NOTE(review): which action types read which of the fields below is
    # not visible in this module — confirm against the environment's
    # step handler.
    target: Optional[str] = None
    value: Optional[Union[float, int, str]] = None
    diagnosis: Optional[str] = None
    line: Optional[int] = None
    replacement: Optional[str] = None
class MLTrainingObservation(Observation):
    """Complete observation payload; extends the openenv ``Observation`` base.

    Per the original note on this class, the base already provides
    ``done`` (bool), ``reward`` (float | None) and ``metadata`` (dict), so
    those are not redeclared here. Every field below has a default, so no
    field is required at construction. Spec reference: Section 10.
    """

    # Run identity.
    run_id: str = ""
    framework: str = "pytorch"
    epoch: int = 20

    # Metric histories; each defaults to a fresh empty list.
    training_loss_history: list[float] = Field(default_factory=list)
    val_loss_history: list[float] = Field(default_factory=list)
    val_accuracy_history: list[float] = Field(default_factory=list)

    # Per-layer diagnostics.
    gradient_stats: list[GradientStats] = Field(default_factory=list)
    model_weight_stats: Optional[list[ModelWeightStats]] = None

    # Resource and hyperparameter snapshot.
    gpu_memory_used_gb: float = 6.2
    gpu_memory_total_gb: float = 16.0
    learning_rate: float = 0.001
    current_config: TrainingConfig = Field(default_factory=TrainingConfig)

    # Optional payloads; None until supplied.
    error_log: Optional[str] = None
    data_batch_stats: Optional[DataBatchStats] = None
    model_mode_info: Optional[dict[str, str]] = None
    code_snippet: Optional[CodeSnippet] = None

    # Agent-facing bookkeeping.
    available_actions: list[str] = Field(default_factory=list)
    episode_state: EpisodeState = Field(default_factory=EpisodeState)
    notes: Optional[str] = None
|