# Snapshot metadata from the page this file was copied from: commit 9256ec9, 6,749 bytes.
"""
MLOps Pipeline Debugger — Pydantic Models

The agent acts as an ML engineer investigating a broken training run.
It has access to training artifacts (logs, configs, dataset stats, preprocessing code)
and must diagnose the root cause through systematic investigation.
Action Space:
read_config           — Get training configuration (hyperparams, model arch, optimizer)
read_logs             — Get training logs (filterable by keyword/epoch range)
check_dataset_stats   — Get dataset split sizes, class distribution, feature statistics
inspect_preprocessing — Read preprocessing pipeline code
read_eval_results     — Get validation and test set evaluation metrics
run_sanity_check      — Compute a specific diagnostic check (label overlap, class balance, etc.)
query_artifact        — Fetch a specific field from any artifact
submit_diagnosis      — Final answer (triggers grading)
Observation Space:
task_id, task_description
run_id, run_summary — always-visible, high-level summary of the training run
available_artifacts — list of artifacts the agent can inspect
last_action_result — result of the most recent action
artifacts_read — which artifacts have been read so far (exploration tracking)
step_count, max_steps
done
messages
"""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
# ─── Action ──────────────────────────────────────────────────────────────────
class MLOpsAction(BaseModel):
    """
    One action the agent can take per step.

    ``action_type`` selects the action. Every other field is optional and is
    only meaningful for particular action types:

        read_config           → (no extra fields)
        read_logs             → log_filter (optional keyword or "epoch:N-M")
        check_dataset_stats   → (no extra fields)
        inspect_preprocessing → (no extra fields)
        read_eval_results     → (no extra fields)
        run_sanity_check      → sanity_check_type (required)
        query_artifact        → artifact_name + field_path (required)
        submit_diagnosis      → all diagnosis fields (required)

    NOTE(review): "required" above is a contract for the caller only. This
    model does not enforce it, because every field except ``action_type``
    defaults to None. For example, ``MLOpsAction(action_type="run_sanity_check")``
    validates without a ``sanity_check_type``. Presumably the environment
    rejects or penalizes incomplete actions — confirm before relying on it.
    """
    # The only field without a default (``...`` marks it required).
    action_type: Literal[
        "read_config",
        "read_logs",
        "check_dataset_stats",
        "inspect_preprocessing",
        "read_eval_results",
        "run_sanity_check",
        "query_artifact",
        "submit_diagnosis",
    ] = Field(..., description="Which action to perform")
    # --- used by read_logs ---
    log_filter: Optional[str] = Field(
        None,
        description="Filter logs by keyword (e.g. 'nan', 'warning', 'error') or epoch range 'epoch:1-5'"
    )
    # --- used by run_sanity_check ---
    sanity_check_type: Optional[Literal[
        "label_consistency",      # Are train/eval label mappings identical?
        "data_leakage",           # Is there train/val sample overlap?
        "gradient_norms",         # Are gradient norms within normal range?
        "class_balance",          # Are classes balanced across splits?
        "feature_statistics",     # Do train/val feature distributions match?
        "encoder_version_match",  # Do all pipeline stages use the same encoder version?
        "loss_trajectory",        # Is the loss curve shape anomalous?
        "metric_gap_analysis",    # Is val vs test metric gap suspiciously large?
    ]] = Field(None, description="Type of sanity check to run")
    # --- used by query_artifact ---
    artifact_name: Optional[Literal[
        "config.yaml",
        "train.log",
        "dataset_stats.json",
        "preprocessing.py",
        "eval_results.json",
        "model_card.json",
    ]] = Field(None, description="Artifact to query a specific field from")
    # Only the path format is declared here; resolving the path against the
    # artifact presumably happens in the environment — confirm.
    field_path: Optional[str] = Field(
        None,
        description="Dot-notation field path, e.g. 'optimizer.learning_rate' or 'metrics.val_accuracy'"
    )
    # --- used by submit_diagnosis ---
    failure_category: Optional[Literal[
        "config_error",       # Wrong hyperparameter value
        "data_leakage",       # Train/val contamination
        "evaluation_bug",     # Eval pipeline uses wrong artifacts
        "preprocessing_bug",  # Data transformation applied incorrectly
        "label_mismatch",     # Label encoding inconsistency
        "architecture_bug",   # Model architecture misconfiguration
    ]] = Field(None, description="Category of the failure")
    # Free-form str: unlike artifact_name above, it is not restricted to a
    # fixed set of file names.
    root_cause_file: Optional[str] = Field(
        None, description="Which artifact file contains the root cause"
    )
    root_cause_field: Optional[str] = Field(
        None, description="Specific parameter, function, or variable that is wrong"
    )
    diagnosis: Optional[str] = Field(
        None, description="Natural language explanation of what went wrong and why"
    )
    proposed_fix: Optional[str] = Field(
        None, description="Concrete change that would fix the issue"
    )
# ─── Observation ─────────────────────────────────────────────────────────────
class ArtifactMeta(BaseModel):
    """Metadata describing one artifact listed in an observation's available_artifacts."""
    name: str  # presumably an artifact file name such as "config.yaml" — confirm
    description: str
    size_hint: str  # human-readable size, e.g. "47 lines", "12 fields"
    last_modified: str  # free-form string; this model does not fix a date format
class MLOpsObservation(BaseModel):
    """
    Everything the agent sees after each step / reset.

    task_id, task_description, run_id, run_summary and available_artifacts
    have no defaults and must be supplied on construction. The remaining
    fields have defaults.
    """
    task_id: str
    task_description: str
    # Run summary — always visible
    run_id: str
    # Required: Field() is given a description but no default.
    run_summary: Dict[str, Any] = Field(
        description="High-level run info: model, dataset, final loss, training status"
    )
    available_artifacts: List[ArtifactMeta]
    # default_factory gives each instance its own fresh list/dict.
    artifacts_read: List[str] = Field(
        default_factory=list,
        description="Names of artifacts the agent has already read"
    )
    # Result payload of the most recent action; defaults to an empty dict.
    last_action_result: Dict[str, Any] = Field(default_factory=dict)
    step_count: int = 0
    max_steps: int = 30
    done: bool = False
    # Free-form messages; their meaning is set by the environment — TODO confirm.
    messages: List[str] = Field(default_factory=list)
# ─── State ───────────────────────────────────────────────────────────────────
class MLOpsState(BaseModel):
    """
    Full internal state — for the RL harness and debugging.

    Unlike MLOpsObservation, this includes the planted-bug ground truth and
    the full text of every generated artifact. Every field is required
    (there are no defaults).
    """
    task_id: str
    seed: int
    step_count: int
    max_steps: int
    episode_done: bool
    # Planted bug ground truth (presumably the answer key that the
    # submit_diagnosis fields are graded against — confirm in the grader)
    bug_type: str
    bug_category: str  # presumably one of the failure_category values — confirm
    bug_file: str  # presumably an artifact name, e.g. "config.yaml" — confirm
    bug_field: str
    gold_fix: str
    # All generated artifacts (full text)
    artifacts: Dict[str, str]  # assumed: artifact name -> full text — TODO confirm
    # Agent's investigation history
    artifacts_read: List[str]
    sanity_checks_run: List[str]
    duplicate_queries: int  # presumably a count of repeated queries — confirm
    current_score: float