File size: 7,425 Bytes
f1a1961
 
 
 
 
 
 
 
 
 
 
 
 
 
47ab3b8
f1a1961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47ab3b8
f1a1961
 
 
1c620a3
 
47ab3b8
1c620a3
 
 
 
 
 
 
 
 
 
 
 
f1a1961
1c620a3
 
 
 
 
 
 
 
 
 
 
 
f1a1961
1c620a3
 
 
 
 
 
 
 
 
 
 
 
 
47ab3b8
1c620a3
 
 
 
 
 
30134ef
1c620a3
f1a1961
1c620a3
 
 
 
 
 
 
f1a1961
1c620a3
 
 
 
 
 
 
 
47ab3b8
1c620a3
 
 
 
 
 
 
 
47ab3b8
f1a1961
1c620a3
f1a1961
1c620a3
 
 
 
f1a1961
47ab3b8
1c620a3
 
 
 
 
 
 
 
47ab3b8
f1a1961
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import uuid
import datetime
from typing import Optional, Tuple, Dict, Any, List
from .models import CloudAction, CloudObservation, CloudState, CloudActionType

class CloudAuditEnv:
    """Mock cloud-security audit environment with three graded tasks.

    The task is chosen with ``reset(task_id=...)``:

    * ``"easy"``   -- SUBMIT the ID of the public S3 bucket tagged ``env=prod``.
    * ``"medium"`` -- MODIFY EC2 instance ``i-0abcdef1234567890`` so that none
      of its security-group rules opens port 3389 to ``0.0.0.0/0``.
    * ``"hard"``   -- SUBMIT the source IP of the ``DeleteStorage`` entry in
      the ``auth-logs`` log.

    Every step earns ``STEP_REWARD``; the step that solves the task earns
    ``SUCCESS_REWARD`` instead. The cumulative ``score`` starts at 0.01 and is
    capped at ``MAX_SCORE``.
    """

    MAX_STEPS = 20            # episode is truncated once step_count reaches this
    STEP_REWARD = 0.005       # reward for any step that does not solve the task
    SUCCESS_REWARD = 0.85     # reward for the step that solves the task
    NEUTRAL_REWARD = 0.01     # reward reported by reset, errors, and post-completion steps
    MAX_SCORE = 0.99          # upper bound on the cumulative score
    MEDIUM_TARGET_INSTANCE = "i-0abcdef1234567890"

    def __init__(self):
        self.task_id = "easy"
        self._initialize_state()

    def _initialize_state(self):
        """Start a fresh episode: new id, zeroed counters, pristine mock infrastructure."""
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self.is_completed = False
        self.score = 0.01

        # Mock Infrastructure. The graded answers live here:
        #   easy   -> "prod-data-001" (public, env=prod)
        #   medium -> port 3389 open to 0.0.0.0/0 on sg-01 of i-0abcdef1234567890
        #   hard   -> the DeleteStorage entry from 192.168.1.50 in auth-logs
        self.resources = {
            "s3": [
                {"id": "prod-data-001", "region": "us-east-1", "public": True, "tags": {"env": "prod"}},
                {"id": "prod-logs-002", "region": "us-east-1", "public": False, "tags": {"env": "prod"}},
                {"id": "dev-test-01", "region": "us-west-2", "public": True, "tags": {"env": "dev"}},
            ],
            "ec2": [
                {"id": "i-0abcdef1234567890", "type": "t2.micro", "state": "running", "tags": {"env": "dev"},
                 "security_groups": [{"id": "sg-01", "rules": [{"port": 22, "cidr": "0.0.0.0/0"}, {"port": 3389, "cidr": "0.0.0.0/0"}]}]},
                {"id": "i-0987654321fedcba0", "type": "m5.large", "state": "running", "tags": {"env": "prod"},
                 "security_groups": [{"id": "sg-02", "rules": [{"port": 443, "cidr": "0.0.0.0/0"}]}]},
            ],
            "logs": {
                "auth-logs": [
                    {"timestamp": "2026-04-05T10:00:00Z", "user": "admin", "action": "Login", "ip": "1.1.1.1"},
                    {"timestamp": "2026-04-05T10:15:00Z", "user": "iam-role-01", "action": "DeleteStorage", "ip": "192.168.1.50"},
                    {"timestamp": "2026-04-05T10:30:00Z", "user": "user-02", "action": "ListBuckets", "ip": "2.2.2.2"},
                ]
            }
        }

    def reset(self, task_id: str = "easy") -> CloudObservation:
        """Required by openenv-core 0.1.1: takes task_id, returns JUST the observation."""
        self.task_id = task_id
        self._initialize_state()
        return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=self.NEUTRAL_REWARD, done=False)

    def step(self, action: CloudAction) -> CloudObservation:
        """Required by openenv-core 0.1.1: takes action, returns JUST the observation with reward/done fields.

        Once the task has been solved, further steps are rejected (``done=True``)
        until ``reset()`` is called, so the success reward cannot be collected twice.
        Any unexpected exception is logged to stderr and reported as an observation
        with ``done=True`` rather than propagated.
        """
        if self.is_completed:
            return CloudObservation(
                status="Episode already completed. Call reset() to start a new episode.",
                reward=self.NEUTRAL_REWARD,
                done=True,
            )
        try:
            self.step_count += 1
            truncated = self.step_count >= self.MAX_STEPS

            obs = CloudObservation()
            solved = self._apply_action(action, obs)
            if solved:
                self.is_completed = True
            reward = self.SUCCESS_REWARD if solved else self.STEP_REWARD

            self.score = min(self.MAX_SCORE, self.score + reward)
            obs.reward = reward
            obs.done = solved or truncated
            return obs
        except Exception as e:
            import sys
            import traceback
            print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=self.NEUTRAL_REWARD, done=True)

    def _apply_action(self, action: CloudAction, obs: CloudObservation) -> bool:
        """Execute *action*, filling in *obs*; return True if it solved the current task."""
        # An if/elif chain (not a dict keyed by enum members) so that plain
        # string action values still compare equal to the enum members.
        kind = action.action
        if kind == CloudActionType.LIST:
            self._handle_list(action.resource_type, obs)
        elif kind == CloudActionType.DESCRIBE:
            self._handle_describe(action.resource_id, obs)
        elif kind == CloudActionType.MODIFY:
            return self._handle_modify(action.resource_id, action.patch, obs)
        elif kind == CloudActionType.LOGS:
            self._handle_logs(action.resource_id, obs)
        elif kind == CloudActionType.SUBMIT:
            return self._handle_submit(action.answer, obs)
        return False

    def _find_resource(self, res_id: Optional[str]) -> Optional[Dict[str, Any]]:
        """Return the S3 bucket or EC2 instance whose ``id`` equals *res_id*, or None."""
        for r_type in ("s3", "ec2"):
            for res in self.resources[r_type]:
                if res["id"] == res_id:
                    return res
        return None

    def _handle_list(self, r_type: Optional[str], obs: CloudObservation) -> None:
        """Put every resource of type *r_type* into ``obs.resources``."""
        if r_type in self.resources:
            obs.resources = self.resources[r_type]
            obs.status = f"Listed {len(obs.resources)} {r_type} resources."
        else:
            obs.status = f"Unknown resource type: {r_type}"

    def _handle_describe(self, res_id: Optional[str], obs: CloudObservation) -> None:
        """Put the S3/EC2 resource with id *res_id* into ``obs.details``."""
        res = self._find_resource(res_id)
        if res is None:
            obs.status = f"Resource not found: {res_id}"
            return
        obs.details = res
        obs.status = f"Described resource {res_id}"

    def _handle_modify(self, res_id: Optional[str], patch: Optional[Dict[str, Any]], obs: CloudObservation) -> bool:
        """Replace the security-group rules of the medium-task EC2 instance.

        Only the ``"medium"`` task allows modification, and only of
        ``MEDIUM_TARGET_INSTANCE``. ``patch["rules"]`` (a list of dicts with
        ``port``/``cidr`` keys) replaces the rules of every security group on
        the instance. Returns True once no rule opens port 3389 to 0.0.0.0/0.
        """
        if self.task_id != "medium":
            obs.status = "Action not permitted or invalid resource."
            return False
        if res_id != self.MEDIUM_TARGET_INSTANCE:
            obs.status = f"Invalid resource ID '{res_id}'. Use the EC2 instance ID 'i-0abcdef1234567890', not the security group ID."
            return False

        instance = self._find_resource(res_id)
        if instance is None:
            obs.status = f"Resource not found: {res_id}"
            return False
        security_groups = instance["security_groups"]

        if patch and "rules" in patch:
            new_rules = patch["rules"]
            # Reject malformed input with feedback instead of crashing into the
            # internal-error path, which would end the episode.
            if not isinstance(new_rules, list) or not all(isinstance(r, dict) for r in new_rules):
                obs.status = "Invalid patch: 'rules' must be a list of {'port': ..., 'cidr': ...} objects."
                return False
            for sg in security_groups:
                # Each group gets its own copy so groups never share (or alias
                # the caller's) rule list.
                sg["rules"] = [dict(rule) for rule in new_rules]
            obs.status = f"Updated security groups for EC2 instance {res_id}"
        else:
            obs.status = f"No 'rules' in patch; security groups for EC2 instance {res_id} unchanged."

        rdp_open = any(
            rule.get("port") == 3389 and rule.get("cidr") == "0.0.0.0/0"
            for sg in security_groups
            for rule in sg["rules"]
        )
        if rdp_open:
            obs.info = "Port 3389 is still open. Remove it by omitting it from the rules list."
            return False
        obs.info = "Success! Port 3389 removed. Task completed."
        return True

    def _handle_logs(self, log_name: Optional[str], obs: CloudObservation) -> None:
        """Put the entries of log *log_name* into ``obs.logs``."""
        if log_name in self.resources["logs"]:
            obs.logs = self.resources["logs"][log_name]
            obs.status = f"Fetched logs for {log_name}"
        else:
            obs.status = f"Logs not found: {log_name}"

    def _handle_submit(self, answer: Optional[str], obs: CloudObservation) -> bool:
        """Grade a submitted *answer* for the easy/hard tasks; return True if correct."""
        if self.task_id == "easy":
            # Expecting a comma-separated list of the public prod S3 bucket IDs.
            if not answer:
                obs.info = "No answer submitted. Provide the public prod S3 bucket ID."
                return False
            # Drop empty entries so a trailing comma does not fail a correct answer.
            answers = [a.strip() for a in answer.split(",") if a.strip()]
            if set(answers) == {"prod-data-001"}:
                obs.info = "Correct! Task completed."
                return True
            obs.info = f"Incorrect. Expected the public prod S3 bucket ID. Got: {answers}"
            return False

        if self.task_id == "hard":
            # Expecting the source IP of the DeleteStorage entry in auth-logs.
            if answer and answer.strip() == "192.168.1.50":
                obs.info = "Correct! Rogue IP identified. Task completed."
                return True
            obs.info = f"Wrong IP address. Got: {answer}. Check the auth-logs for the DeleteStorage action."
            return False

        if self.task_id == "medium":
            obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
        else:
            obs.info = f"Unknown task_id: {self.task_id}"
        return False

    def state(self) -> CloudState:
        """Snapshot of the episode's bookkeeping (id, steps, task, completion, score)."""
        return CloudState(
            episode_id=self.episode_id,
            step_count=self.step_count,
            task_id=self.task_id,
            is_completed=self.is_completed,
            score=self.score
        )