File size: 11,154 Bytes
0b89610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""
Typed Pydantic models for the incident operations OpenEnv environment.
"""

from __future__ import annotations

from typing import Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


class ChunkSummary(BaseModel):
    # Compact descriptor of a single artifact: its id, source domain,
    # approximate token cost, and the keyword hints attached to it.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "chunk_id": "support_003",
                    "domain": "Customer Support Operations",
                    "tokens": 132,
                    "keywords": ["refund policy", "incident timeline", "billing ledger"],
                }
            ]
        }
    )

    chunk_id: str = Field(..., description="Unique artifact identifier exposed to the agent.")
    domain: str = Field(..., description="High-level source domain for the artifact.")
    tokens: int = Field(..., ge=1, description="Approximate token cost for including the artifact.")
    keywords: list[str] = Field(..., min_length=1, description="Important artifact hints available before inspection.")

    @field_validator("chunk_id", "domain")
    @classmethod
    def validate_non_empty_text(cls, value: str) -> str:
        """Return ``value`` without surrounding whitespace.

        Raises:
            ValueError: if nothing is left after stripping.
        """
        stripped = value.strip()
        if stripped == "":
            raise ValueError("Value must not be empty.")
        return stripped

    @field_validator("keywords")
    @classmethod
    def validate_keywords(cls, value: list[str]) -> list[str]:
        """Strip every keyword and discard the ones that are blank.

        Raises:
            ValueError: if no keyword survives the cleanup.
        """
        non_blank = [text for text in map(str.strip, value) if text]
        if len(non_blank) == 0:
            raise ValueError("keywords must contain at least one non-empty entry.")
        return non_blank


class RagObservation(BaseModel):
    # Observation handed to the agent for the active incident case.
    #
    # Besides the primary fields, three legacy "mirror" fields are carried for
    # older clients: ``query`` (mirrors ``objective``), ``available_chunks``
    # (mirrors ``available_artifacts``) and ``selected_chunks`` (mirrors
    # ``prioritized_artifacts``). ``validate_budget_and_aliases`` enforces the
    # mirrors, so the schema example below must satisfy them as well.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "case_id": "case-refund-triage-001",
                    "case_summary": "A business customer requests a refund after a confirmed outage.",
                    "objective": "Prepare a refund triage memo grounded in support evidence.",
                    "workflow_stage": "triage",
                    "customer_tier": "business",
                    "incident_severity": "sev2",
                    "available_artifacts": [
                        {
                            "chunk_id": "support_003",
                            "domain": "Customer Support Operations",
                            "tokens": 132,
                            "keywords": ["refund policy", "incident timeline", "billing ledger"],
                        }
                    ],
                    "reviewed_artifacts": [],
                    "prioritized_artifacts": [],
                    "plan_draft": None,
                    "report_requirements": ["State whether the case should proceed to refund review."],
                    "total_tokens_used": 0,
                    "token_budget": 850,
                    "step_number": 0,
                    "task_name": "refund_triage_easy",
                    "last_action_feedback": None,
                    "query": "Prepare an incident-linked refund triage memo.",
                    # Must have the same number of entries as
                    # available_artifacts, otherwise this example is rejected
                    # by validate_budget_and_aliases.
                    "available_chunks": [
                        {
                            "chunk_id": "support_003",
                            "domain": "Customer Support Operations",
                            "tokens": 132,
                            "keywords": ["refund policy", "incident timeline", "billing ledger"],
                        }
                    ],
                    "selected_chunks": [],
                }
            ]
        }
    )

    case_id: str = Field(..., description="Unique identifier for the active simulated incident case.")
    case_summary: str = Field(..., description="Short real-world case summary presented to the agent.")
    objective: str = Field(..., description="The operational deliverable the agent must produce.")
    workflow_stage: Literal["triage", "analysis", "resolution", "submitted"] = Field(
        ..., description="Current workflow stage in the incident operations process."
    )
    customer_tier: Literal["standard", "business", "enterprise"] = Field(
        ..., description="Customer tier for the active case."
    )
    incident_severity: Literal["sev3", "sev2", "sev1"] = Field(
        ..., description="Severity of the active incident."
    )
    available_artifacts: list[ChunkSummary] = Field(
        ..., description="Artifacts that can be inspected, prioritized, or summarized."
    )
    reviewed_artifacts: list[str] = Field(
        default_factory=list,
        description="Artifact ids the agent has inspected so far.",
    )
    prioritized_artifacts: list[str] = Field(
        default_factory=list,
        description="Artifact ids currently included in the working resolution set.",
    )
    plan_draft: Optional[str] = Field(
        default=None,
        description="Current draft of the resolution plan or operational recommendation.",
    )
    report_requirements: list[str] = Field(
        default_factory=list,
        description="Deterministic requirements the final report must satisfy.",
    )
    progress_signals: dict[str, float] = Field(
        default_factory=dict,
        description="Normalized progress metrics for artifact coverage, planning, and workflow readiness.",
    )
    total_tokens_used: int = Field(..., ge=0, description="Current token cost of the prioritized working set.")
    token_budget: int = Field(..., ge=1, description="Maximum allowed token budget for the current task.")
    step_number: int = Field(..., ge=0, description="Current step number in the episode.")
    task_name: str = Field(..., description="Active task identifier.")
    last_action_feedback: Optional[str] = Field(default=None, description="Outcome of the previous action.")

    # Legacy compatibility mirrors (checked in validate_budget_and_aliases).
    query: str = Field(..., description="Compatibility mirror of objective for legacy clients.")
    available_chunks: list[ChunkSummary] = Field(
        default_factory=list,
        description="Compatibility mirror of available_artifacts for legacy clients.",
    )
    selected_chunks: list[str] = Field(
        default_factory=list,
        description="Compatibility mirror of prioritized_artifacts for legacy clients.",
    )

    @field_validator("case_id", "case_summary", "objective", "task_name", "query")
    @classmethod
    def validate_required_strings(cls, value: str) -> str:
        """Strip surrounding whitespace from a required text field.

        Raises:
            ValueError: if the stripped text is empty.
        """
        value = value.strip()
        if not value:
            raise ValueError("Value must not be empty.")
        return value

    @field_validator("reviewed_artifacts", "prioritized_artifacts", "selected_chunks")
    @classmethod
    def validate_ids(cls, value: list[str]) -> list[str]:
        """Strip each id, drop blank ids, and return the cleaned list.

        Raises:
            ValueError: if the cleaned list contains the same id twice.
        """
        cleaned = [artifact_id.strip() for artifact_id in value if artifact_id.strip()]
        # Duplicates are detected after stripping, so "a" and " a " collide.
        if len(cleaned) != len(set(cleaned)):
            raise ValueError("Artifact id lists must not contain duplicates.")
        return cleaned

    @field_validator("report_requirements")
    @classmethod
    def validate_report_requirements(cls, value: list[str]) -> list[str]:
        """Strip each requirement and drop the ones that are blank."""
        return [item.strip() for item in value if item.strip()]

    @field_validator("last_action_feedback")
    @classmethod
    def validate_feedback(cls, value: Optional[str]) -> Optional[str]:
        """Strip the feedback text; blank text is normalized to ``None``."""
        if value is None:
            return value
        value = value.strip()
        return value or None

    @model_validator(mode="after")
    def validate_budget_and_aliases(self) -> "RagObservation":
        """Check the token budget and the legacy mirror fields.

        Raises:
            ValueError: if ``total_tokens_used`` exceeds ``token_budget``, if
                ``query`` differs from ``objective``, if ``selected_chunks``
                differs from ``prioritized_artifacts``, or if
                ``available_chunks`` and ``available_artifacts`` have
                different lengths.
        """
        if self.total_tokens_used > self.token_budget:
            raise ValueError("total_tokens_used cannot exceed token_budget.")
        if self.query != self.objective:
            raise ValueError("query must mirror objective.")
        if self.selected_chunks != self.prioritized_artifacts:
            raise ValueError("selected_chunks must mirror prioritized_artifacts.")
        # Only the number of entries is compared here, not their contents.
        if len(self.available_chunks) != len(self.available_artifacts):
            raise ValueError("available_chunks must mirror available_artifacts.")
        return self


class RagAction(BaseModel):
    # A single action request from the agent. Both the current action names
    # and the legacy chunk-based names are accepted, and ``chunk_id`` is
    # accepted as a legacy alias for ``artifact_id``.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {"action_type": "inspect_artifact", "artifact_id": "support_003"},
                {"action_type": "summarize_artifact", "artifact_id": "support_003", "compression_ratio": 0.55},
                {"action_type": "set_resolution_plan", "plan": "Verify outage evidence and route manual exceptions to finance review."},
                {"action_type": "submit_report", "answer": "Proceed to refund review only after outage and billing evidence are confirmed. [support_001] [support_003]"},
            ]
        }
    )

    action_type: Literal[
        "inspect_artifact",
        "prioritize_artifact",
        "summarize_artifact",
        "set_resolution_plan",
        "submit_report",
        "select_chunk",
        "deselect_chunk",
        "compress_chunk",
        "submit_answer",
    ] = Field(..., description="The environment action the agent wants to perform.")
    artifact_id: Optional[str] = Field(default=None, description="Target artifact id for artifact actions.")
    chunk_id: Optional[str] = Field(default=None, description="Legacy alias for artifact_id.")
    compression_ratio: Optional[float] = Field(default=None, ge=0.3, le=0.9)
    plan: Optional[str] = Field(default=None, description="Draft of the current operational resolution plan.")
    answer: Optional[str] = Field(default=None, description="Final report or resolution memo to submit.")

    @field_validator("artifact_id", "chunk_id", "plan", "answer")
    @classmethod
    def normalize_optional_strings(cls, value: Optional[str]) -> Optional[str]:
        """Strip optional text; ``None`` and blank text both come back as ``None``."""
        if value is not None:
            trimmed = value.strip()
            if trimmed:
                return trimmed
        return None

    @model_validator(mode="after")
    def validate_action_semantics(self) -> "RagAction":
        """Ensure the fields required by ``action_type`` are present.

        Raises:
            ValueError: if an artifact action has neither ``artifact_id`` nor
                ``chunk_id``, if a summarize action lacks
                ``compression_ratio``, if ``set_resolution_plan`` lacks
                ``plan``, or if a submit action lacks ``answer``.
        """
        kind = self.action_type
        # artifact_id takes precedence; chunk_id is only the fallback.
        target = self.artifact_id if self.artifact_id else self.chunk_id

        if kind in ("inspect_artifact", "prioritize_artifact", "select_chunk", "deselect_chunk"):
            if target is None:
                raise ValueError("artifact_id or chunk_id is required for artifact selection actions.")
            return self

        if kind in ("summarize_artifact", "compress_chunk"):
            # The missing target is reported before the missing ratio.
            if target is None:
                raise ValueError("artifact_id or chunk_id is required for summarize actions.")
            if self.compression_ratio is None:
                raise ValueError("compression_ratio is required for summarize actions.")
            return self

        if kind == "set_resolution_plan":
            if self.plan is None:
                raise ValueError("plan is required for set_resolution_plan.")
            return self

        if kind in ("submit_report", "submit_answer") and self.answer is None:
            raise ValueError("answer is required for submit_report/submit_answer.")
        return self


class RagReward(BaseModel):
    # Reward breakdown; every component is constrained to the closed
    # interval [0.0, 1.0] by its field definition.
    total: float = Field(..., ge=0.0, le=1.0)
    token_efficiency: float = Field(..., ge=0.0, le=1.0)
    answer_quality: float = Field(..., ge=0.0, le=1.0)
    retrieval_precision: float = Field(..., ge=0.0, le=1.0)
    penalty: float = Field(..., ge=0.0, le=1.0)

    @model_validator(mode="after")
    def validate_total_bound(self) -> "RagReward":
        """Re-check that ``total`` lies within [0.0, 1.0].

        Raises:
            ValueError: if ``total`` is below 0.0 or above 1.0.
        """
        score = self.total
        below_floor = score < 0.0
        above_ceiling = score > 1.0
        if below_floor or above_ceiling:
            raise ValueError("total must remain within [0.0, 1.0].")
        return self