File size: 3,539 Bytes
f89b1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import re
from typing import Any, Dict, List, Literal, Optional

from openenv.core.env_server import (
    Action as BaseAction,
)
from openenv.core.env_server import (
    Observation as BaseObservation,
)
from openenv.core.env_server import (
    State as BaseState,
)
from pydantic import BaseModel, Field, field_validator

# ── Action Payload Models (Pydantic-validated) ─────────────────────


class ExecuteSQLPayload(BaseModel):
    """Arguments for an ``ExecuteSQL`` action: one SQL query string."""

    # Required (no default given); must be 1-2000 characters long.
    query: str = Field(min_length=1, max_length=2000)


class ReadFilePayload(BaseModel):
    """Arguments for a ``ReadFile`` action."""

    # Path of the file to read; required, 1-255 characters.
    filepath: str = Field(min_length=1, max_length=255)


class WriteFilePayload(BaseModel):
    """Arguments for a ``WriteFile`` action: a destination path plus the text to write."""

    # Destination path; required, 1-255 characters.
    filepath: str = Field(min_length=1, max_length=255)
    # Required, but may be empty (no minimum length); capped at one million characters.
    content: str = Field(max_length=1000000)


class RunScriptPayload(BaseModel):
    """Arguments for a ``RunScript`` action.

    ``filepath`` must end in a plain ``<name>.py`` file name, and ``args`` is a
    list of at most 20 strings, each at most 500 characters long.
    """

    filepath: str = Field(..., min_length=1, max_length=255)
    args: List[str] = Field(default_factory=list, max_length=20)

    @field_validator("filepath")
    @classmethod
    def must_be_safe_script_name(cls, v: str) -> str:
        """Require the last path component to match ``[A-Za-z0-9_-]+.py``.

        Only the text after the final ``/`` is checked; the path is returned
        unchanged when it passes.

        Raises:
            ValueError: if the file name has other characters or lacks ``.py``.
        """
        basename = v.rsplit("/", 1)[-1]
        # fullmatch rather than match with "^...$": "$" also matches right
        # before a trailing newline, so "run.py\n" used to be accepted.
        # NOTE(review): leading directory components (including "..") are not
        # validated here; confirm the caller confines paths to its sandbox.
        if not re.fullmatch(r"[a-zA-Z0-9_\-]+\.py", basename):
            raise ValueError("Script name must be alphanumeric with .py extension.")
        return v

    @field_validator("args")
    @classmethod
    def args_must_be_safe(cls, v: List[str]) -> List[str]:
        """Reject any argument longer than 500 characters.

        Raises:
            ValueError: if an element is not a string or exceeds 500 characters.
        """
        for arg in v:
            # The isinstance test is defensive; the List[str] annotation is
            # validated before this "after" validator runs.
            if not isinstance(arg, str) or len(arg) > 500:
                raise ValueError("Each arg must be a string under 500 chars.")
        return v


class SendEmailPayload(BaseModel):
    """Arguments for a ``SendEmail`` action: recipient, subject and body."""

    to_email: str = Field(..., max_length=320)
    subject: str = Field(..., min_length=1, max_length=500)
    body: str = Field(..., min_length=1, max_length=100_000)

    @field_validator("to_email")
    @classmethod
    def must_look_like_email(cls, v: str) -> str:
        """Loosely check the ``local@domain.tld`` shape of the recipient.

        The whole string must contain exactly one ``@``, no whitespace, and at
        least one ``.`` after the ``@``. The address is returned unchanged.

        Raises:
            ValueError: if the address does not have that shape.
        """
        # fullmatch rather than match with "^...$": "$" also matches right
        # before a trailing newline, so "a@b.c\n" used to pass. A stray newline
        # could matter if the address is later placed into a mail header.
        if not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("Invalid email format.")
        return v


# Closed set of action names; the keys of PAYLOAD_MODELS below mirror it.
ACTION_TYPE = Literal[
    "ExecuteSQL",
    "ReadFile",
    "WriteFile",
    "RunScript",
    "SendEmail",
]

# Maps each action name to the pydantic model class for its payload.
PAYLOAD_MODELS: dict[str, type[BaseModel]] = dict(
    ExecuteSQL=ExecuteSQLPayload,
    ReadFile=ReadFilePayload,
    WriteFile=WriteFilePayload,
    RunScript=RunScriptPayload,
    SendEmail=SendEmailPayload,
)


# ── Action Model (extends OpenEnv Action) ──────────────────────────


class DataOpsAction(BaseAction):
    """One agent action: an action name plus its parameter dictionary.

    ``payload`` is a free-form ``Dict[str, Any]`` here; this class does not
    check it against the per-action payload models (presumably the
    environment does that when the action is executed; confirm).
    """

    # Required; restricted to the ACTION_TYPE literal values.
    action_type: ACTION_TYPE = Field(
        description="One of: ExecuteSQL, ReadFile, WriteFile, RunScript, SendEmail",
    )
    # Required; keys and values are not constrained at this level.
    payload: Dict[str, Any] = Field(
        description="Parameters for the chosen action type.",
    )


# ── Observation Model (extends OpenEnv Observation) ────────────────


class DataOpsObservation(BaseObservation):
    """Outcome reported for one DataOps action.

    Every field has a default, and ``status`` starts as ``"error"``, so an
    observation counts as a failure unless it is explicitly marked successful.
    """

    status: Literal["success", "error"] = Field(default="error")
    message: str = Field(default="")
    # Optional per-action outputs; each is None unless populated.
    stdout: Optional[str] = Field(default=None)
    stderr: Optional[str] = Field(default=None)
    # List of row dicts (presumably column name -> value; confirm with the executor).
    sql_results: Optional[List[Dict[str, Any]]] = Field(default=None)
    email_delivery_status: Optional[str] = Field(default=None)
    # Step counters, both defaulting to 0.
    step_count: int = Field(default=0)
    max_steps: int = Field(default=0)


# ── State Model (extends OpenEnv State) ────────────────────────────


class DataOpsState(BaseState):
    """Per-episode state for the DataOps environment.

    Every field has a default; ``max_steps`` defaults to 15 and
    ``actions_taken`` gets a fresh empty list for each instance.
    """

    # Task identification and description, empty until set.
    task_id: str = Field(default="")
    task_description: str = Field(default="")
    seed: int = Field(default=0)
    max_steps: int = Field(default=15)
    done: bool = Field(default=False)
    cumulative_reward: float = Field(default=0.0)
    # default_factory so instances never share one list object.
    actions_taken: List[str] = Field(default_factory=list)
    emails_sent: int = Field(default=0)