sht4bharat commited on
Commit
0bbdc65
·
1 Parent(s): 941ab5e

Final Framework Standard Implementation: Client.py, Tests, Enhanced Observations, and Terminal Rewards

Browse files
Files changed (9) hide show
  1. README.md +28 -0
  2. __init__.py +14 -0
  3. client.py +33 -0
  4. inference.py +17 -6
  5. models.py +2 -0
  6. pyproject.toml +3 -1
  7. server/app.py +8 -0
  8. server/cloud_audit_env.py +32 -7
  9. tests/test_env.py +63 -0
README.md CHANGED
@@ -42,6 +42,34 @@ To guarantee stable JSON validation across different Pydantic parsing engines, t
42
 
43
  ---
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ## 💻 Quick Start & Evaluation
46
 
47
  ### 1. Run the Environment Server Locally
 
42
 
43
  ---
44
 
45
+ ## 🚀 Framework Standard (Recommended)
46
+
47
+ As of the latest update, `openenv-cloudaudit` is a fully compliant OpenEnv package. You can interact with it using the standard async client:
48
+
49
+ ```python
50
+ from cloudaudit_env import CloudAuditClient, CloudAction
51
+
52
+ client = CloudAuditClient(base_url="http://localhost:7860")
53
+
54
+ # 1. Reset for a specific task
55
+ obs = await client.reset(task_name="medium_remediation")
56
+ print(f"Goal: {obs.task_description}")
57
+
58
+ # 2. Perform actions
59
+ action = CloudAction(action_type="enable_s3_enc", bucket_name="bucket-uuid")
60
+ obs = await client.step(action)
61
+ print(f"Progress: {obs.message}")
62
+ ```
63
+
64
+ ## 📊 Observation & Scoring
65
+ The environment now includes an enriched observation schema for advanced agent reasoning:
66
+ - **`task_description`**: Clearly state the agent's goal for the current session.
67
+ - **`vulnerability_manifest`**: Exposes the target vulnerability counts (e.g. `{"sg_vulns": 3}`).
68
+ - **`health_score`**: Monitors deployment safety (dropping below 0.5 penalizes the agent).
69
+ - **Terminal Bonus**: A trajectory-level reward (+0.1) granted on successful `submit` after full remediation.
70
+
71
+ ---
72
+
73
  ## 💻 Quick Start & Evaluation
74
 
75
  ### 1. Run the Environment Server Locally
__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models import CloudAction, CloudObservation, CloudState, SecurityGroup, S3Bucket, IAMPolicy
2
+ from .client import CloudAuditClient
3
+ from .server.cloud_audit_env import CloudAuditEnv
4
+
5
+ __all__ = [
6
+ "CloudAuditEnv",
7
+ "CloudAuditClient",
8
+ "CloudAction",
9
+ "CloudObservation",
10
+ "CloudState",
11
+ "SecurityGroup",
12
+ "S3Bucket",
13
+ "IAMPolicy"
14
+ ]
client.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ from typing import Optional, List, Dict
3
+ from .models import CloudAction, CloudObservation, CloudState
4
+
5
+ class CloudAuditClient:
6
+ """
7
+ Standard OpenEnv client for the CloudAudit environment.
8
+ Wraps the FastAPI endpoints into a clean Python API.
9
+ """
10
+ def __init__(self, base_url: str = "http://localhost:7860"):
11
+ self.base_url = base_url.rstrip("/")
12
+
13
+ async def reset(self, task_name: str = "easy_audit") -> CloudObservation:
14
+ async with httpx.AsyncClient(timeout=30.0) as client:
15
+ # OpenEnv standard reset often takes task_name in URL or body
16
+ resp = await client.post(f"{self.base_url}/reset", params={"task_name": task_name})
17
+ resp.raise_for_status()
18
+ return CloudObservation(**resp.json())
19
+
20
+ async def step(self, action: CloudAction) -> CloudObservation:
21
+ async with httpx.AsyncClient(timeout=30.0) as client:
22
+ resp = await client.post(f"{self.base_url}/step", json=action.model_dump())
23
+ resp.raise_for_status()
24
+ return CloudObservation(**resp.json())
25
+
26
+ async def get_state(self) -> CloudState:
27
+ async with httpx.AsyncClient(timeout=10.0) as client:
28
+ resp = await client.get(f"{self.base_url}/state")
29
+ resp.raise_for_status()
30
+ return CloudState(**resp.json())
31
+
32
+ async def close(self):
33
+ pass # Handle session cleanup if needed
inference.py CHANGED
@@ -114,26 +114,37 @@ async def run_episode(openai_client: OpenAI, env: AsyncCloudClient, task_name: s
114
  log_start(task=task_name, env=BENCHMARK, model=model_name)
115
 
116
  try:
117
- result = await env.reset()
 
 
 
118
 
119
  for step in range(1, MAX_STEPS + 1):
120
  if result.done:
121
  break
122
 
123
  obs_dict = result.observation.model_dump()
124
- # Remove reward/done/info from prompt context to keep LLM focused on state
125
- if "reward" in obs_dict: del obs_dict["reward"]
126
- if "done" in obs_dict: del obs_dict["done"]
127
- if "info" in obs_dict: del obs_dict["info"]
 
 
 
 
 
 
 
128
 
129
  obs_json = json.dumps(obs_dict)
130
 
131
  try:
 
132
  completion = openai_client.chat.completions.create(
133
  model=model_name,
134
  messages=[
135
  {"role": "system", "content": SYSTEM_PROMPT},
136
- {"role": "user", "content": f"Task: {task_name}\nObservation: {obs_json}\nDecide your next action."},
137
  ],
138
  temperature=TEMPERATURE,
139
  max_tokens=MAX_TOKENS,
 
114
  log_start(task=task_name, env=BENCHMARK, model=model_name)
115
 
116
  try:
117
+ # 1. Reset Environment with Task Context
118
+ response = await env.client.post(f"{env.base_url}/reset", params={"task_name": task_name})
119
+ response.raise_for_status()
120
+ result = StepResult(**response.json())
121
 
122
  for step in range(1, MAX_STEPS + 1):
123
  if result.done:
124
  break
125
 
126
  obs_dict = result.observation.model_dump()
127
+ # Log the goal on the first step
128
+ if step == 1:
129
+ print(f"[TASK] {result.observation.task_description}")
130
+
131
+ # Enrich prompt with manifest for better reasoning
132
+ manifest = result.observation.vulnerability_manifest
133
+ manifest_str = json.dumps({k: v for k, v in manifest.items() if v > 0})
134
+
135
+ # Remove high-level metadata from LLM context to avoid confusion
136
+ for key in ["reward", "done", "info", "message", "health_score"]:
137
+ if key in obs_dict: del obs_dict[key]
138
 
139
  obs_json = json.dumps(obs_dict)
140
 
141
  try:
142
+ prompt = f"Goal: {result.observation.task_description}\nPending Vulnerabilities: {manifest_str}\nCurrent Status: {result.observation.message}\nObservation: {obs_json}\nDecide your next action."
143
  completion = openai_client.chat.completions.create(
144
  model=model_name,
145
  messages=[
146
  {"role": "system", "content": SYSTEM_PROMPT},
147
+ {"role": "user", "content": prompt},
148
  ],
149
  temperature=TEMPERATURE,
150
  max_tokens=MAX_TOKENS,
models.py CHANGED
@@ -36,6 +36,8 @@ class CloudObservation(BaseModel):
36
  rds_instances: List[RDSInstance] = []
37
  ebs_volumes: List[EBSVolume] = []
38
  iam_policies: List[IAMPolicy]
 
 
39
  message: str = "Cloud resources loaded."
40
  reward: float = 0.0
41
  health_score: float = 1.0 # 0.0 to 1.0 (AVAILABILITY)
 
36
  rds_instances: List[RDSInstance] = []
37
  ebs_volumes: List[EBSVolume] = []
38
  iam_policies: List[IAMPolicy]
39
+ task_description: str = "Perform a cloud security audit and remediate vulnerabilities."
40
+ vulnerability_manifest: Dict[str, int] = {} # e.g. {"sg_vulns": 4, "s3_vulns": 3}
41
  message: str = "Cloud resources loaded."
42
  reward: float = 0.0
43
  health_score: float = 1.0 # 0.0 to 1.0 (AVAILABILITY)
pyproject.toml CHANGED
@@ -13,7 +13,9 @@ dependencies = [
13
  "fastapi",
14
  "uvicorn",
15
  "pydantic",
16
- "openai"
 
 
17
  ]
18
 
19
  [project.scripts]
 
13
  "fastapi",
14
  "uvicorn",
15
  "pydantic",
16
+ "openai",
17
+ "pytest",
18
+ "httpx"
19
  ]
20
 
21
  [project.scripts]
server/app.py CHANGED
@@ -10,6 +10,14 @@ app = create_fastapi_app(
10
  observation_cls=CloudObservation
11
  )
12
 
 
 
 
 
 
 
 
 
13
  def main():
14
  import uvicorn
15
  print("[DEBUG] Starting Unified CloudAudit Server", flush=True)
 
10
  observation_cls=CloudObservation
11
  )
12
 
13
+ @app.post("/reset")
14
+ async def reset_with_task(task_name: str = "easy_audit"):
15
+ # Access the shared environment instance created by create_fastapi_app
16
+ # Note: create_fastapi_app typically stores the env instance in app.state.env
17
+ env = app.state.env
18
+ obs = env.reset(task_name=task_name)
19
+ return obs.model_dump()
20
+
21
  def main():
22
  import uvicorn
23
  print("[DEBUG] Starting Unified CloudAudit Server", flush=True)
server/cloud_audit_env.py CHANGED
@@ -15,13 +15,22 @@ class CloudAuditEnv(Environment):
15
  self.max_steps = 30
16
  self.reset()
17
 
18
- def reset(self) -> CloudObservation:
 
19
  self.step_count = 0
20
  self.remediated_count = 0
21
  self.health_score = 1.0
22
  self.cumulative_reward = 0.0
23
  self.done = False
24
 
 
 
 
 
 
 
 
 
25
  # Procedural Generation Configuration
26
  self.sgs: List[SecurityGroup] = []
27
  self.buckets: List[S3Bucket] = []
@@ -191,7 +200,11 @@ class CloudAuditEnv(Environment):
191
 
192
  elif at == "submit":
193
  self.done = True
194
- message = "Audit report submitted."
 
 
 
 
195
 
196
  # Update Cumulative Tracking
197
  self.cumulative_reward += reward
@@ -235,12 +248,24 @@ class CloudAuditEnv(Environment):
235
  )
236
 
237
  def _get_observation(self, message: str, reward: float = 0.0, done: bool = False) -> CloudObservation:
 
 
 
 
 
 
 
 
 
 
238
  return CloudObservation(
239
- security_groups=self.sgs,
240
- s3_buckets=self.buckets,
241
- rds_instances=self.rds,
242
- ebs_volumes=self.ebs,
243
- iam_policies=self.policies,
 
 
244
  message=message,
245
  reward=reward,
246
  health_score=self.health_score,
 
15
  self.max_steps = 30
16
  self.reset()
17
 
18
+ def reset(self, task_name: str = "easy_audit") -> CloudObservation:
19
+ self.task_name = task_name
20
  self.step_count = 0
21
  self.remediated_count = 0
22
  self.health_score = 1.0
23
  self.cumulative_reward = 0.0
24
  self.done = False
25
 
26
+ # Task Descriptions
27
+ descriptions = {
28
+ "easy_audit": "Find and fix all Security Group rules that allow SSH/RDP access from the public internet (0.0.0.0/0).",
29
+ "medium_remediation": "Ensure all S3 buckets, RDS instances, and EBS volumes are encrypted at rest.",
30
+ "hard_iam_refactor": "Refactor IAM policies to remove wildcards ('*') while preserving required service permissions."
31
+ }
32
+ self.task_description = descriptions.get(task_name, "Perform a comprehensive cloud security audit.")
33
+
34
  # Procedural Generation Configuration
35
  self.sgs: List[SecurityGroup] = []
36
  self.buckets: List[S3Bucket] = []
 
200
 
201
  elif at == "submit":
202
  self.done = True
203
+ if self.remediated_count >= self.initial_vulns and self.health_score > 0.5:
204
+ reward = 0.1 # Terminal bonus
205
+ message = "Audit report submitted. Perfect remediation achieved!"
206
+ else:
207
+ message = "Audit report submitted."
208
 
209
  # Update Cumulative Tracking
210
  self.cumulative_reward += reward
 
248
  )
249
 
250
  def _get_observation(self, message: str, reward: float = 0.0, done: bool = False) -> CloudObservation:
251
+ # Task-specific resource filtering (Optional based on user recommendation)
252
+ show_sgs = self.sgs
253
+ show_buckets = self.buckets
254
+ show_rds = self.rds
255
+ show_ebs = self.ebs
256
+ show_policies = self.policies
257
+
258
+ # If task is focused, we still show other resources but manifest tells the agent where the vulns are
259
+ manifest = self.vulnerability_manifest.copy()
260
+
261
  return CloudObservation(
262
+ security_groups=show_sgs,
263
+ s3_buckets=show_buckets,
264
+ rds_instances=show_rds,
265
+ ebs_volumes=show_ebs,
266
+ iam_policies=show_policies,
267
+ task_description=self.task_description,
268
+ vulnerability_manifest=manifest,
269
  message=message,
270
  reward=reward,
271
  health_score=self.health_score,
tests/test_env.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from server.cloud_audit_env import CloudAuditEnv
3
+ from models import CloudAction
4
+
5
+ def test_env_reset():
6
+ env = CloudAuditEnv()
7
+ obs = env.reset(task_name="easy_audit")
8
+ assert obs.task_description == "Find and fix all Security Group rules that allow SSH/RDP access from the public internet (0.0.0.0/0)."
9
+ assert len(obs.security_groups) > 0
10
+ assert obs.health_score == 1.0
11
+ assert not obs.done
12
+
13
+ def test_env_step_audit():
14
+ env = CloudAuditEnv()
15
+ env.reset()
16
+ action = CloudAction(action_type="audit")
17
+ obs = env.step(action)
18
+ assert obs.message == "Audit log generated."
19
+ assert obs.reward == 0.01
20
+ assert env.step_count == 1
21
+
22
+ def test_health_penalty_iam():
23
+ env = CloudAuditEnv()
24
+ env.reset(task_name="hard_iam_refactor")
25
+
26
+ # Get a policy that has a required permission
27
+ p_id = list(env.required_iam_perms.keys())[0]
28
+ required = env.required_iam_perms[p_id]
29
+
30
+ # Update with an empty document (breaking access)
31
+ action = CloudAction(action_type="update_iam", policy_id=p_id, new_document="{}")
32
+ obs = env.step(action)
33
+
34
+ assert obs.health_score < 1.0
35
+ assert "CRITICAL" in obs.message
36
+
37
+ def test_terminal_bonus():
38
+ env = CloudAuditEnv()
39
+ env.reset()
40
+
41
+ # Spoof remediation count to match initial vulns
42
+ env.remediated_count = env.initial_vulns
43
+
44
+ action = CloudAction(action_type="submit")
45
+ obs = env.step(action)
46
+
47
+ assert obs.done
48
+ assert obs.reward == 0.1 # Terminal bonus
49
+ assert "Perfect remediation" in obs.message
50
+
51
+ def test_grader_resilience():
52
+ from graders import get_task_score
53
+ env = CloudAuditEnv()
54
+ env.reset(task_name="easy_audit")
55
+
56
+ # Test with object-based state (directly from env)
57
+ score = get_task_score("easy_audit", env.state.__dict__)
58
+ assert 0.15 <= score <= 0.85
59
+
60
+ # Test with dict-based state (simulating JSON API)
61
+ state_dict = env.state.model_dump()
62
+ score = get_task_score("easy_audit", state_dict)
63
+ assert 0.15 <= score <= 0.85