File size: 2,657 Bytes
ff8ce5f
25fff92
ff8ce5f
b37875f
 
 
5bf3c8c
 
25fff92
 
 
5bf3c8c
 
 
dfbe1fe
 
 
ff8ce5f
 
 
 
 
 
 
 
 
 
 
 
 
 
6b279f6
5bf3c8c
 
 
 
 
 
 
 
 
 
 
 
d3b224f
 
5bf3c8c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Any, Dict, Literal

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field


class WhyDidItFailAction(Action):
    """Agent's diagnostic action.

    ``action_type`` selects one of three inspection actions or the final
    ``submit_diagnosis``. The ``diagnosis``, ``suggested_fix`` and
    ``reasoning`` fields are described as required only for
    ``submit_diagnosis``; they default to None.
    """

    action_type: Literal["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"] = Field(
        ..., description="One of: inspect_logs | inspect_config | inspect_gradients | submit_diagnosis"
    )

    # NOTE(review): the "Required when action_type=submit_diagnosis" contract
    # on the three fields below is documentation only -- this class defines no
    # validator enforcing it, so a submit_diagnosis action with these left as
    # None still validates. Presumably the environment handles that case; confirm.
    diagnosis: str | None = Field(None, description=
        "Required when action_type=submit_diagnosis. It's the agent's conclusion about what is wrong.")
    suggested_fix: str | None = Field(None, description=
        "Required when action_type=submit_diagnosis. Exact fix to apply.")
    reasoning: str | None = Field(None, description=
        "Required when action_type=submit_diagnosis. Explain what evidence led to this diagnosis.")
    # default_factory gives every action instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata.")


class WhyDidItFailState(State):
    """Full episode state exposed via GET /state and WSStateMessage.

    Adds the active scenario's identity and the agent's inspection
    progress on top of the fields inherited from ``State``.
    """

    # Scenario identity; both default to None (the description notes
    # scenario_key is None before reset).
    scenario_key: str | None = Field(
        default=None,
        description="Key of the active scenario (e.g. 'exploding_gradients'). None before reset.",
    )
    difficulty: str | None = Field(
        default=None,
        description="Difficulty tier of the active scenario: easy, medium, or hard.",
    )

    # Inspection bookkeeping; default_factory gives each instance a fresh list.
    inspection_order: list[str] = Field(
        default_factory=list,
        description="Sources inspected so far this episode, in the order they were first visited.",
    )
    required_sources: list[str] = Field(
        default_factory=list,
        description="Sources the agent must inspect before submitting a valid diagnosis.",
    )

    # Episode step budget; the model default is 0.
    max_steps: int = Field(
        default=0,
        description="Hard step ceiling for this episode. Exceeding it terminates with score 0.10.",
    )


class WhyDidItFailObservation(Observation):
    """What the agent sees after each action.

    Carries the task statement, the data returned by the last action, the
    currently valid action types, step/score bookkeeping and feedback.
    """

    # Required fields (no default) describing the task and last action's result.
    task_description: str = Field(
        ...,
        description="The problem the agent must diagnose.",
    )
    visible_data: dict = Field(
        ...,
        description="Data returned by the last action (logs, config, gradients, etc.).",
    )
    available_actions: list[str] = Field(
        ...,
        description="Which action_types are valid on this step.",
    )
    steps_taken: int = Field(
        ...,
        description="Number of actions taken so far in this episode.",
    )

    # Redeclared with concrete defaults; the ignore silences the type checker's
    # complaint about narrowing the inherited attribute's type.
    reward: float = Field(  # type: ignore[override]
        default=0.10,
        description="Score for the current step. 0.90 = max.",
    )
    done: bool = Field(
        default=False,
        description="True when the episode has ended.",
    )

    feedback: str = Field(
        ...,
        description="Partial progress hint from the environment.",
    )