File size: 8,783 Bytes
c1060df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# AdaptShield — Pydantic Data Models
#
# CRITICAL DESIGN DECISION: Phase1Action and Phase2Action are SEPARATE classes.
# A single combined class with optional fields causes 500 errors when the
# evaluator sends a Phase 2 payload and Pydantic tries to validate Phase 1 fields.

from enum import Enum
from typing import Any, Dict, List, Optional

from openenv.core.env_server.types import Action, Observation
from pydantic import Field, model_validator


class DefenseAction(str, Enum):
    """Closed action space for the Tactical Executor (Phase 2).

    Declaring the actions as an Enum keeps hallucinated LLM output
    from reaching the grader: only the values listed here are members.
    """

    # Light — throttles traffic, keeps service online.
    RATE_LIMIT = "rate_limit"
    # Heavy — takes node offline, stops spread.
    ISOLATE = "isolate"
    # Strategic — redirects attacker to decoy.
    HONEYPOT = "honeypot"
    # Targeted — fixes supply chain vulnerability.
    PATCH = "patch"
    # Passive — gather info, risk escalation.
    MONITOR = "monitor"


class ThreatType(str, Enum):
    """Attack strategies the Threat Analyst is able to classify."""

    BRUTE_FORCE = "brute_force"
    LATERAL_MOVEMENT = "lateral_movement"
    EXFILTRATION = "exfiltration"
    SUPPLY_CHAIN = "supply_chain"
    BENIGN = "benign"


class Phase1Action(Action):
    """Threat Analyst output — pure reasoning, no defensive action.

    The agent reads the raw network state and emits a structured threat
    assessment. That assessment is graded independently for classification
    accuracy before Phase 2 acts on it.
    """

    # NOTE(review): typed as plain ``str`` rather than ``ThreatType``, so values
    # outside the enum are not rejected here — confirm this is intentional.
    threat_type: str = Field(
        ...,
        description=(
            "Identified attack strategy: brute_force, lateral_movement, "
            "exfiltration, supply_chain, or benign"
        ),
    )
    # Bounded to the closed interval [0.0, 1.0] by the ge/le constraints.
    confidence: float = Field(
        ...,
        description="Confidence in the threat classification (0.0 to 1.0)",
        ge=0.0,
        le=1.0,
    )
    target_node: str = Field(
        ...,
        description=(
            "Primary affected node: auth_service, payment_service, "
            "database, or api_gateway"
        ),
    )
    recommended_action: DefenseAction = Field(
        ...,
        description="Recommended defense action for Phase 2 to execute",
    )
    # The only optional field; defaults to None when omitted.
    reasoning: Optional[str] = Field(
        default=None,
        description="Chain of thought. Not graded. Helps training stability.",
    )


class Phase2Action(Action):
    """Tactical Executor output — defensive action based ONLY on the Phase 1 assessment.

    The Phase 2 agent is deliberately blind to the raw network state: it is
    handed only the Phase 1 threat assessment and must act on that.
    """

    action: DefenseAction = Field(..., description="Defense action to execute")
    target_node: str = Field(
        ...,
        description=(
            "Node to apply action to: auth_service, payment_service, "
            "database, or api_gateway"
        ),
    )
    # Optional free-text rationale; defaults to None when omitted.
    reasoning: Optional[str] = Field(
        default=None,
        description="Chain of thought. Not graded.",
    )


class AdaptShieldAction(Action):
    """Single action model accepted by the OpenEnv HTTP server.

    The environment alternates between two phases, so the transport layer has
    to accept either a Threat Analyst payload or a Tactical Executor payload.
    Every field is optional at the schema level; ``validate_phase_shape``
    then enforces that a payload matches exactly one of the two shapes, which
    keeps them distinct while still fitting the single action model expected
    by `create_app`.
    """

    threat_type: Optional[str] = Field(
        default=None,
        description="Phase 1 only: identified attack strategy",
    )
    confidence: Optional[float] = Field(
        default=None,
        description="Phase 1 only: confidence in the threat classification",
        ge=0.0,
        le=1.0,
    )
    # Shared by both phases; required by whichever shape the payload takes.
    target_node: Optional[str] = Field(
        default=None,
        description="Target node for either phase",
    )
    recommended_action: Optional[DefenseAction] = Field(
        default=None,
        description="Phase 1 only: recommended follow-up action",
    )
    action: Optional[DefenseAction] = Field(
        default=None,
        description="Phase 2 only: defensive action to execute",
    )
    reasoning: Optional[str] = Field(
        default=None,
        description="Optional one-sentence rationale",
    )

    @model_validator(mode="after")
    def validate_phase_shape(self) -> "AdaptShieldAction":
        """Check that the payload is a complete Phase 1 or Phase 2 action.

        A payload counts as Phase 1 when any of ``threat_type``,
        ``confidence`` or ``recommended_action`` is set, and as Phase 2 when
        ``action`` is set. ``target_node`` and ``reasoning`` do not select a
        phase on their own.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the payload mixes both phases, matches neither, or
                is missing a field required by the phase it matches.
        """
        is_phase1 = not (
            self.threat_type is None
            and self.confidence is None
            and self.recommended_action is None
        )
        is_phase2 = self.action is not None

        if is_phase1 and is_phase2:
            raise ValueError(
                "Action payload must be either Phase 1 or Phase 2, not both."
            )
        if not (is_phase1 or is_phase2):
            raise ValueError(
                "Action payload must contain Phase 1 fields or a Phase 2 action."
            )

        # Exactly one phase is active from here on; choose its required fields.
        if is_phase1:
            required = (
                "threat_type",
                "confidence",
                "target_node",
                "recommended_action",
            )
        else:
            required = ("action", "target_node")

        absent = [name for name in required if getattr(self, name) is None]
        if absent:
            raise ValueError(
                f"Missing required fields for this phase: {', '.join(absent)}"
            )

        return self


class AdaptShieldObservation(Observation):
    """Observation returned after each step.

    Phase 1 observation: carries the full network state (``network_nodes``,
    ``active_alerts``).
    Phase 2 observation: ``network_nodes`` and ``active_alerts`` are EMPTY and
    ``phase1_assessment`` holds the Phase 1 output.

    The episode number is NEVER included — the agent must rely on signals only.
    """

    # --- Identity ---
    scenario_id: str = Field(default="")
    task_name: str = Field(default="")
    phase: int = Field(
        default=1,
        description="1 = Threat Analyst turn, 2 = Tactical Executor turn",
    )
    turn: int = Field(default=0)
    max_turns: int = Field(default=5)

    # --- Network state — populated in Phase 1, EMPTY in Phase 2 ---
    network_nodes: Dict[str, Any] = Field(default_factory=dict)
    active_alerts: List[str] = Field(default_factory=list)
    attack_stage: str = Field(
        default="none",
        description="Current attack progression stage: recon, exploit, exfiltration, none",
    )

    # --- Rolling history of the last 3 turns ---
    history: List[Dict[str, str]] = Field(default_factory=list)

    # --- Phase 2 only — Phase 1 output handed to the executor ---
    phase1_assessment: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Populated only in Phase 2. Phase 2 agent sees ONLY this.",
    )

    # --- Context ---
    system_context: str = Field(default="")
    available_actions: List[str] = Field(default_factory=list)

    # --- Feedback ---
    last_action_result: Optional[str] = Field(default=None)
    reward: float = Field(default=0.0)
    done: bool = Field(default=False)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Serialize the observation while keeping ``metadata`` in the output.

        OpenEnv's serializer excludes metadata from the nested observation by
        default. AdaptShield exposes normalized_score there, so only that one
        exclusion is dropped; every other entry in ``exclude`` (such as the
        serializer's reward/done handling) is passed through untouched.

        Args:
            *args: Forwarded unchanged to the parent ``model_dump``.
            **kwargs: Forwarded to the parent ``model_dump``. When ``exclude``
                is a ``set`` or ``dict`` containing ``"metadata"``, a copy
                without that key is forwarded instead.

        Returns:
            Whatever the parent ``model_dump`` returns for the adjusted
            arguments.
        """
        excluded = kwargs.get("exclude")
        if isinstance(excluded, (set, dict)) and "metadata" in excluded:
            # Build a fresh container so the caller's object is never mutated.
            if isinstance(excluded, set):
                kwargs["exclude"] = {
                    entry for entry in excluded if entry != "metadata"
                }
            else:
                kwargs["exclude"] = {
                    name: spec
                    for name, spec in excluded.items()
                    if name != "metadata"
                }
        return super().model_dump(*args, **kwargs)


# Backward-compatible aliases for earlier package names: the lowercase-"s"
# spellings are bound to the same classes, so both names refer to one type.
AdaptshieldAction = AdaptShieldAction
AdaptshieldObservation = AdaptShieldObservation