File size: 6,749 Bytes
9256ec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
MLOps Pipeline Debugger — Pydantic Models

The agent acts as an ML engineer investigating a broken training run.
It has access to training artifacts (logs, configs, dataset stats, preprocessing code)
and must diagnose the root cause through systematic investigation.

Action Space:
    read_config           → Get training configuration (hyperparams, model arch, optimizer)
    read_logs             → Get training logs (filterable by keyword/epoch range)
    check_dataset_stats   → Get dataset split sizes, class distribution, feature statistics
    inspect_preprocessing → Read preprocessing pipeline code
    read_eval_results     → Get validation and test set evaluation metrics
    run_sanity_check      → Compute a specific diagnostic check (label overlap, class balance, etc.)
    query_artifact        → Fetch a specific field from any artifact
    submit_diagnosis      → Final answer — triggers grading

Observation Space:
    task_id, task_description
    available_artifacts   — list of artifacts the agent can inspect
    last_action_result    — result of the most recent action
    artifacts_read        — which artifacts have been read so far (exploration tracking)
    step_count, max_steps
    done
"""

from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field


# ─── Action ──────────────────────────────────────────────────────────────────

# Closed vocabularies used by MLOpsAction. They are private aliases so the
# field declarations stay short. Annotations in this module are evaluated
# lazily (``from __future__ import annotations``), so these names must stay
# at module scope for Pydantic to resolve them.
_ActionType = Literal[
    "read_config",
    "read_logs",
    "check_dataset_stats",
    "inspect_preprocessing",
    "read_eval_results",
    "run_sanity_check",
    "query_artifact",
    "submit_diagnosis",
]

_SanityCheckType = Literal[
    "label_consistency",      # Are train/eval label mappings identical?
    "data_leakage",           # Is there train/val sample overlap?
    "gradient_norms",         # Are gradient norms within normal range?
    "class_balance",          # Are classes balanced across splits?
    "feature_statistics",     # Do train/val feature distributions match?
    "encoder_version_match",  # Do all pipeline stages use the same encoder version?
    "loss_trajectory",        # Is the loss curve shape anomalous?
    "metric_gap_analysis",    # Is val vs test metric gap suspiciously large?
]

_ArtifactName = Literal[
    "config.yaml",
    "train.log",
    "dataset_stats.json",
    "preprocessing.py",
    "eval_results.json",
    "model_card.json",
]

_FailureCategory = Literal[
    "config_error",       # Wrong hyperparameter value
    "data_leakage",       # Train/val contamination
    "evaluation_bug",     # Eval pipeline uses wrong artifacts
    "preprocessing_bug",  # Data transformation applied incorrectly
    "label_mismatch",     # Label encoding inconsistency
    "architecture_bug",   # Model architecture misconfiguration
]


class MLOpsAction(BaseModel):
    """
    One action the agent can take per step.

    action_type determines which fields are used:
        read_config           → (no extra fields)
        read_logs             → log_filter (optional keyword or "epoch:N-M")
        check_dataset_stats   → (no extra fields)
        inspect_preprocessing → (no extra fields)
        read_eval_results     → (no extra fields)
        run_sanity_check      → sanity_check_type (required)
        query_artifact        → artifact_name + field_path (required)
        submit_diagnosis      → all diagnosis fields (required)
    """
    # Only action_type is mandatory at the model level. Every other field
    # defaults to None, so the per-action "required" notes in the docstring
    # are NOT enforced by this class.

    action_type: _ActionType = Field(..., description="Which action to perform")

    # --- read_logs ---
    log_filter: Optional[str] = Field(
        default=None,
        description="Filter logs by keyword (e.g. 'nan', 'warning', 'error') or epoch range 'epoch:1-5'",
    )

    # --- run_sanity_check ---
    sanity_check_type: Optional[_SanityCheckType] = Field(
        default=None,
        description="Type of sanity check to run",
    )

    # --- query_artifact ---
    artifact_name: Optional[_ArtifactName] = Field(
        default=None,
        description="Artifact to query a specific field from",
    )
    field_path: Optional[str] = Field(
        default=None,
        description="Dot-notation field path, e.g. 'optimizer.learning_rate' or 'metrics.val_accuracy'",
    )

    # --- submit_diagnosis ---
    failure_category: Optional[_FailureCategory] = Field(
        default=None,
        description="Category of the failure",
    )
    root_cause_file: Optional[str] = Field(
        default=None,
        description="Which artifact file contains the root cause",
    )
    root_cause_field: Optional[str] = Field(
        default=None,
        description="Specific parameter, function, or variable that is wrong",
    )
    diagnosis: Optional[str] = Field(
        default=None,
        description="Natural language explanation of what went wrong and why",
    )
    proposed_fix: Optional[str] = Field(
        default=None,
        description="Concrete change that would fix the issue",
    )


# ─── Observation ─────────────────────────────────────────────────────────────

class ArtifactMeta(BaseModel):
    # Descriptor for one inspectable artifact; MLOpsObservation.available_artifacts
    # carries a list of these. All four fields are required strings (no defaults).
    name: str
    description: str
    size_hint: str   # e.g. "47 lines", "12 fields"
    last_modified: str  # kept as a plain string; its format is not defined in this module


class MLOpsObservation(BaseModel):
    """Everything the agent sees after each step / reset."""
    # Task identification (required).
    task_id: str
    task_description: str

    # Run summary — always visible. Both fields are required.
    run_id: str
    run_summary: Dict[str, Any] = Field(
        ...,
        description="High-level run info: model, dataset, final loss, training status",
    )

    # Artifacts on offer, and the names of those already read
    # (the latter starts out as an empty list).
    available_artifacts: List[ArtifactMeta]
    artifacts_read: List[str] = Field(
        description="Names of artifacts the agent has already read",
        default_factory=list,
    )

    # Result of the most recent action; an empty dict when nothing is supplied.
    last_action_result: Dict[str, Any] = Field(default_factory=dict)

    # Step accounting and termination flag, with their defaults.
    step_count: int = 0
    max_steps: int = 30
    done: bool = False

    # Free-form messages; a fresh empty list per instance by default.
    messages: List[str] = Field(default_factory=list)


# ─── State ───────────────────────────────────────────────────────────────────

class MLOpsState(BaseModel):
    """Full internal state — for RL harness and debugging."""
    # Every field below is required: none declares a default.

    # Episode bookkeeping
    task_id: str
    seed: int  # presumably the seed the episode was generated from — TODO confirm
    step_count: int
    max_steps: int
    episode_done: bool

    # Planted bug ground truth
    bug_type: str
    # NOTE(review): likely mirrors MLOpsAction.failure_category's values, but is
    # typed as a plain str here, so nothing in this model constrains it.
    bug_category: str
    bug_file: str
    bug_field: str
    gold_fix: str

    # All generated artifacts (full text)
    # presumably keyed by artifact name — verify against the code that fills it
    artifacts: Dict[str, str]

    # Agent's investigation history
    artifacts_read: List[str]
    sanity_checks_run: List[str]
    duplicate_queries: int  # presumably a count of repeated queries — TODO confirm

    current_score: float  # scoring rule and range are not defined in this module