# File size: 3,844 Bytes
# ec8c511
from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field
# Standard OpenEnv base types. openenv-core is treated as optional: when the
# import fails, empty pydantic stand-ins with the same names are defined so
# the Action/Observation subclasses in this module can still be declared.
try:
    from openenv.core.env_server.types import Action, Observation
except ImportError:
    # Fallback if not installed.
    # NOTE(review): these stand-ins declare no fields of their own; anything
    # the real openenv base classes provide is absent in this mode - confirm
    # that callers do not rely on it when running without openenv-core.
    class Action(BaseModel):
        pass
    class Observation(BaseModel):
        pass

# --- Custom Action/Observation classes as seen in video ---

class FirewallAction(Action):
    """Action for the AI Firewall environment."""
    # The description documents exactly six action indices (0-5). The ge/le
    # bounds make pydantic reject anything outside that range at validation
    # time instead of letting an invalid index reach the environment.
    action: int = Field(
        ...,
        ge=0,
        le=5,
        description="Action index: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE",
    )
    # Optional; defaults to None when no specific session is given.
    session_id: Optional[str] = Field(None, description="Specific session to act upon")

class FirewallObservation(Observation):
    """Observation for the AI Firewall environment."""

    # Required: the feature vector itself. The length stated in the
    # description is documentation only; no length constraint is declared.
    features: List[float] = Field(
        default=...,
        description="22-dimensional normalized feature vector",
    )
    # Optional: None when no session ID is supplied.
    focus_session_id: Optional[str] = Field(
        default=None,
        description="ID of the session currently in focus",
    )

# --- Original models from env/models.py ---

class ActionRecord(BaseModel):
    """Record of one action applied to a session, with its reward."""
    # All fields are required (no defaults are declared).
    tick: int
    session_id: str
    action: int  # action index; presumably the same 0-5 encoding as FirewallAction.action - confirm
    action_name: str  # presumably the label matching `action` - confirm
    malicious: bool  # NOTE(review): looks like the session's true label - confirm
    reward: float
    components: Dict[str, float]  # named float values; presumably a breakdown of `reward` - confirm

class ResetRequest(BaseModel):
    """Reset parameters: a task difficulty and an optional random seed."""

    # Both fields carry defaults, so a request with no fields set is valid.
    task: str = Field(
        "easy",
        description="Task difficulty: easy, medium, hard",
    )
    seed: Optional[int] = Field(
        None,
        description="Random seed for reproducibility",
    )

class StepRequest(BaseModel):
    """Step request carrying one action index per session ID."""

    # default_factory gives every instance its own empty dict when the
    # field is omitted.
    actions: Dict[str, int] = Field(
        description="Map of session_id to action index",
        default_factory=dict,
    )

class StepSingleRequest(BaseModel):
    """Step request carrying a single action index for the focus session."""
    # The description documents the valid range as 0-5. The ge/le bounds make
    # pydantic reject an out-of-range index at validation time instead of
    # passing it on.
    action: int = Field(
        ...,
        ge=0,
        le=5,
        description="Action index (0-5) for the current focus session",
    )

class ToolRequest(BaseModel):
    """Tool-call request holding the call's keyword arguments."""

    # Values are untyped (Any). Omitting the field yields a fresh empty dict
    # per instance via default_factory.
    kwargs: Dict[str, Any] = Field(
        description="Arguments for the tool call",
        default_factory=dict,
    )

class StateResponse(BaseModel):
    """Environment state: episode progress, budget, session queue, and focus session."""
    # Also nested as the `state` field of StepResponse and StepSingleResponse.
    # Episode identity and progress counters.
    episode_id: int
    task: str  # presumably one of the ResetRequest.task difficulties - confirm
    step_count: int
    current_tick: int
    # Sizes of the observation vector and the action set.
    observation_dim: int
    num_actions: int
    # Running totals.
    budget_remaining: float
    total_reward: float
    # Session bookkeeping: counts alongside the corresponding ID lists.
    pending_session_count: int
    inspected_session_count: int
    pending_session_ids: List[str]
    inspected_session_ids: List[str]
    queue_length: int
    # Nullable but declared without a default: under pydantic v2 it must be
    # supplied (None allowed); under pydantic v1 it implicitly defaults to None.
    focus_session_id: Optional[str]
    focus_observation: List[float]  # presumably the focus session's feature vector - confirm

class StepResponse(BaseModel):
    """Step result: reward, done flag, environment state, and extra info."""
    # Presumably the reply to a StepRequest - confirm against the route.
    reward: float
    done: bool
    state: StateResponse
    info: Dict[str, Any]  # free-form mapping; its keys are not defined in this module

class StepSingleResponse(BaseModel):
    """Single-action step result: observation, reward, done flag, state, and extra info."""
    # Same fields as StepResponse plus `observation`.
    # Presumably the reply to a StepSingleRequest - confirm against the route.
    observation: List[float]
    reward: float
    done: bool
    state: StateResponse
    info: Dict[str, Any]  # free-form mapping; its keys are not defined in this module

class EvaluateSessionResponse(BaseModel):
    """Details of one session: features, observation vector, and inspection status."""
    session_id: str
    features: Dict[str, Any]  # named feature values; schema not defined in this module
    observation: List[float]  # presumably the vector form of `features` - confirm
    is_inspected: bool
    # Nullable but declared without a default: under pydantic v2 it must be
    # supplied (None allowed); under pydantic v1 it implicitly defaults to None.
    # NOTE(review): presumably None until the session has been inspected - confirm.
    revealed_malicious: Optional[bool]
    expires_tick: int

class NetworkStatsResponse(BaseModel):
    """Episode statistics: progress, reward, budget use, rates, and counts."""
    # All fields are required (no defaults are declared).
    # Episode identity and progress.
    episode_id: int
    task: str
    tick: int
    step_count: int
    # Reward and budget.
    total_reward: float
    budget_remaining: float
    budget_used_pct: float  # NOTE(review): scale (0-1 vs 0-100) is not visible here - confirm
    # Session totals by label.
    total_malicious: int
    total_benign: int
    # Float metrics; their formulas and ranges are computed elsewhere.
    detection_rate: float
    false_positive_rate: float
    efficiency: float
    early_detection_bonus: float
    cascade_prevention: float
    # Integer counters.
    correct_allows: int
    inspections: int
    expired_malicious: int
    expired_benign: int

class HealthResponse(BaseModel):
    """Health-check payload: a status string and a version string."""
    status: str
    version: str

class ToolsListResponse(BaseModel):
    """Response holding a list of tool strings."""
    tools: List[str]  # presumably the names of the callable tools - confirm

class TakeActionResponse(BaseModel):
    """Result of taking an action: the reward and the corresponding ActionRecord."""
    reward: float
    record: ActionRecord

class LLMChatRequest(BaseModel):
    """Chat request: a prompt plus optional API key, base URL, and model."""
    prompt: str
    # The three optional fields default to None.
    # NOTE(review): presumably None means "use the server-side configuration" - confirm.
    # NOTE(review): api_key carries a credential in the request body; make
    # sure request logging does not record it.
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model: Optional[str] = None

class LLMChatResponse(BaseModel):
    """Chat response: the content string and the model string."""
    content: str
    model: str  # presumably the model that produced `content` - confirm

class LLMConfigResponse(BaseModel):
    """LLM configuration summary: base URL, model, and an API-key presence flag."""
    base_url: str
    model: str
    has_api_key: bool  # a boolean flag only; the key itself is not a field of this model

class LLMTestResponse(BaseModel):
    """LLM test result: an ok flag, the model string, and the content string."""
    ok: bool
    model: str
    content: str  # presumably the reply text, or an error message when `ok` is false - confirm