Spaces:
Sleeping
Sleeping
Final Framework Standard Implementation: Client.py, Tests, Enhanced Observations, and Terminal Rewards
import json
import random
import uuid
from typing import Dict, List, Optional, Tuple

from openenv.core.env_server import Environment

from models import (
    CloudAction, CloudObservation, CloudState,
    SecurityGroup, SecurityGroupRule, S3Bucket, IAMPolicy,
    RDSInstance, EBSVolume
)
class CloudAuditEnv(Environment):
    """Procedurally generated cloud-security audit environment.

    Every episode builds a random inventory of security groups, S3 buckets,
    RDS instances, EBS volumes and IAM policies, a random subset of which is
    misconfigured. The agent is rewarded for remediating misconfigurations and
    loses ``health_score`` when a change removes something the simulated
    workload depends on (an essential ingress rule or a required IAM
    permission). The episode ends on ``submit``, after ``max_steps`` steps,
    or when ``health_score`` drops to zero.
    """

    # Ingress rules opening these ports to PUBLIC_CIDR count as vulnerabilities.
    RISKY_PORTS = (22, 3389)
    PUBLIC_CIDR = "0.0.0.0/0"
    # Remediation reward budget, split evenly across the vulnerabilities
    # present at reset; TERMINAL_BONUS is paid on a perfect submission.
    FIX_REWARD_BUDGET = 0.8
    TERMINAL_BONUS = 0.1
    # Health lost per availability-breaking change.
    ESSENTIAL_RULE_PENALTY = 0.2
    REQUIRED_PERM_PENALTY = 0.3

    def __init__(self):
        super().__init__()
        self.max_steps = 30
        self.reset()

    def reset(self, task_name: str = "easy_audit") -> CloudObservation:
        """Start a new episode with a freshly generated inventory.

        Args:
            task_name: Selects the task description shown to the agent.
                Unknown names fall back to a generic audit description. The
                generated inventory does not depend on the task.

        Returns:
            The initial observation (reward 0.0, not done).
        """
        self.task_name = task_name
        self.step_count = 0
        self.remediated_count = 0
        self.health_score = 1.0
        self.cumulative_reward = 0.0
        self.done = False

        descriptions = {
            "easy_audit": "Find and fix all Security Group rules that allow SSH/RDP access from the public internet (0.0.0.0/0).",
            "medium_remediation": "Ensure all S3 buckets, RDS instances, and EBS volumes are encrypted at rest.",
            "hard_iam_refactor": "Refactor IAM policies to remove wildcards ('*') while preserving required service permissions.",
        }
        self.task_description = descriptions.get(task_name, "Perform a comprehensive cloud security audit.")

        self.sgs: List[SecurityGroup] = []
        self.buckets: List[S3Bucket] = []
        self.rds: List[RDSInstance] = []
        self.ebs: List[EBSVolume] = []
        self.policies: List[IAMPolicy] = []
        # Vulnerability counts per category as generated at reset.
        self.vulnerability_manifest = {"sg_vulns": 0, "s3_vulns": 0, "rds_vulns": 0, "ebs_vulns": 0, "iam_vulns": 0}
        self.essential_rules: Dict[str, List[Tuple[int, str]]] = {}  # sg_id -> list of (port, cidr)
        self.required_iam_perms: Dict[str, str] = {}  # policy_id -> action that must stay granted
        self._generate_procedural_assets()
        self.initial_vulns = sum(self.vulnerability_manifest.values())

        total = len(self.sgs) + len(self.buckets) + len(self.rds) + len(self.ebs) + len(self.policies)
        return self._get_observation(f"Environment reset. Procedural audit pending with {total} resources.", reward=0.0, done=False)

    def _generate_procedural_assets(self) -> None:
        """Populate the inventory and record which items are vulnerable.

        Also records, per security group, the essential rule that must survive
        remediation and, per IAM policy, the action it must keep granting.
        """
        # 1. Security Groups (5-7): one essential internal rule each, plus a
        #    public SSH/RDP rule about 70% of the time.
        for i in range(random.randint(5, 7)):
            sg_id = f"sg-{uuid.uuid4().hex[:6]}"
            ess_port = random.choice([443, 8080, 5432])
            rules = [SecurityGroupRule(port=ess_port, cidr="10.0.0.0/8")]
            self.essential_rules[sg_id] = [(ess_port, "10.0.0.0/8")]
            if random.random() > 0.3:
                vuln_port = random.choice(self.RISKY_PORTS)
                rules.append(SecurityGroupRule(port=vuln_port, cidr=self.PUBLIC_CIDR))
                self.vulnerability_manifest["sg_vulns"] += 1
            self.sgs.append(SecurityGroup(id=sg_id, name=f"group-{i}", ingress_rules=rules))

        # 2. S3 Buckets (4-6)
        for _ in range(random.randint(4, 6)):
            name = f"bucket-{uuid.uuid4().hex[:8]}"
            encrypted = random.choice([True, False])
            if not encrypted:
                self.vulnerability_manifest["s3_vulns"] += 1
            self.buckets.append(S3Bucket(name=name, encrypted=encrypted))

        # 3. RDS Instances (2-3)
        for _ in range(random.randint(2, 3)):
            rid = f"db-{uuid.uuid4().hex[:4]}"
            enc = random.choice([True, False])
            if not enc:
                self.vulnerability_manifest["rds_vulns"] += 1
            self.rds.append(RDSInstance(id=rid, engine="postgres", encrypted=enc))

        # 4. EBS Volumes (2-3)
        for _ in range(random.randint(2, 3)):
            eid = f"vol-{uuid.uuid4().hex[:4]}"
            enc = random.choice([True, False])
            if not enc:
                self.vulnerability_manifest["ebs_vulns"] += 1
            self.ebs.append(EBSVolume(id=eid, encrypted=enc))

        # 5. IAM Policies (3): vulnerable ones grant Action "*"; clean ones
        #    grant only the required action. Both use Resource "*".
        for i in range(3):
            pid = f"p-{uuid.uuid4().hex[:4]}"
            required = random.choice(["s3:GetObject", "ec2:DescribeInstances", "iam:GetUser"])
            self.required_iam_perms[pid] = required
            if random.choice([True, False]):
                granted = "*"
                self.vulnerability_manifest["iam_vulns"] += 1
            else:
                granted = required
            doc = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Action": granted, "Resource": "*"}]}
            self.policies.append(IAMPolicy(id=pid, name=f"Policy-{i}", document=json.dumps(doc)))

    def step(self, action: CloudAction) -> CloudObservation:
        """Apply one agent action and return the resulting observation.

        Args:
            action: The agent's action; ``action.action_type`` selects the
                handler and the remaining fields supply its arguments.

        Returns:
            An observation whose ``reward`` is this step's reward, or -1.0 on
            the step where ``health_score`` reaches zero. Once the episode is
            done, further calls change nothing and return reward 0.0.
        """
        if self.done:
            # Without this guard a finished episode kept accepting actions, so
            # repeated "submit" calls could collect the terminal bonus again.
            return self._get_observation(
                "Episode already finished. Call reset() to start a new audit.",
                reward=0.0,
                done=True,
            )

        self.step_count += 1
        increment = self.FIX_REWARD_BUDGET / max(1, self.initial_vulns)

        handlers = {
            "audit": self._handle_audit,
            "fix_sg": self._handle_fix_sg,
            "remediate_all_in_sg": self._handle_remediate_all_in_sg,
            "enable_s3_enc": self._handle_enable_s3_enc,
            "enable_rds_enc": self._handle_enable_rds_enc,
            "enable_ebs_enc": self._handle_enable_ebs_enc,
            "update_iam": self._handle_update_iam,
            "submit": self._handle_submit,
        }
        handler = handlers.get(action.action_type)
        if handler is None:
            reward, message = 0.0, f"Unknown action type '{action.action_type}'."
        else:
            reward, message = handler(action, increment)

        if self.step_count >= self.max_steps and not self.done:
            self.done = True
            # Append rather than overwrite so the action's own feedback survives.
            message = f"{message} Max steps reached.".strip()

        self.cumulative_reward += reward

        step_reward = reward
        if self.health_score <= 0:
            self.done = True
            step_reward = -1.0  # Make the catastrophic outcome unmistakable.
            message = "CRITICAL FAILURE: Production environment is offline. Mission failed."

        return self._get_observation(message, reward=step_reward, done=self.done)

    # ------------------------------------------------------------------ #
    # Action handlers: each returns (reward, message) for the step.
    # ------------------------------------------------------------------ #

    def _handle_audit(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Return a small shaping reward for requesting an audit."""
        # NOTE(review): this pays on every call, so an agent can repeat it for
        # up to max_steps * 0.01 reward — confirm that is intended.
        return 0.01, "Audit log generated."

    def _handle_fix_sg(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Remove every ingress rule matching (port, cidr_to_remove) from one group."""
        if not (action.sg_id and action.port is not None and action.cidr_to_remove):
            return 0.0, "fix_sg requires sg_id, port and cidr_to_remove."
        sg = self._find_sg(action.sg_id)
        if sg is None:
            return 0.0, f"No security group found with id '{action.sg_id}'."

        target = (action.port, action.cidr_to_remove)
        kept = [r for r in sg.ingress_rules if (r.port, r.cidr) != target]
        if len(kept) == len(sg.ingress_rules):
            return 0.0, f"No ingress rule for port {action.port} from {action.cidr_to_remove} on {sg.id}."

        old_vulns = self._check_sg_vulns(sg)
        sg.ingress_rules = kept
        fixed = old_vulns - self._check_sg_vulns(sg)

        message = ""
        # Penalize only when an essential rule was actually present and removed.
        if target in self.essential_rules.get(sg.id, []):
            self._apply_health_penalty(self.ESSENTIAL_RULE_PENALTY)
            message = f"CRITICAL: Removed essential rule on {sg.id}! Availability decreased."

        reward = 0.0
        if fixed > 0:
            # Credit every vulnerable rule removed, matching remediate_all_in_sg.
            reward = increment * fixed
            self.remediated_count += fixed
            message = message or f"Fixed rule on {sg.id}."
        return reward, message or f"Removed rule on {sg.id}."

    def _handle_remediate_all_in_sg(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Remove all public SSH/RDP rules from one security group."""
        if not action.sg_id:
            return 0.0, "remediate_all_in_sg requires sg_id."
        sg = self._find_sg(action.sg_id)
        if sg is None:
            return 0.0, f"No security group found with id '{action.sg_id}'."
        vulns = self._check_sg_vulns(sg)
        if vulns == 0:
            return 0.0, f"No public SSH/RDP rules found in {sg.id}."
        sg.ingress_rules = [r for r in sg.ingress_rules if not self._is_vulnerable_rule(r)]
        self.remediated_count += vulns
        return increment * vulns, f"Batch remediated {vulns} issues in {sg.id}."

    def _handle_enable_s3_enc(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Enable encryption on the bucket named ``action.bucket_name``."""
        return self._encrypt_resource(self.buckets, "name", action.bucket_name, "bucket", increment)

    def _handle_enable_rds_enc(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Enable encryption on the RDS instance ``action.rds_id``."""
        return self._encrypt_resource(self.rds, "id", action.rds_id, "RDS", increment)

    def _handle_enable_ebs_enc(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Enable encryption on the EBS volume ``action.ebs_id``."""
        return self._encrypt_resource(self.ebs, "id", action.ebs_id, "EBS", increment)

    def _handle_update_iam(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Replace a policy document, scoring wildcard removal and penalizing lost access."""
        policy = next((p for p in self.policies if p.id == action.policy_id), None)
        if policy is None:
            return 0.0, f"No IAM policy found with id '{action.policy_id}'."
        new_doc = action.new_document
        if not new_doc:
            # Previously a missing document silently wiped the policy to "".
            return 0.0, "update_iam requires new_document."
        granted = self._policy_actions(new_doc)
        if granted is None:
            return 0.0, f"Rejected update for {policy.id}: new_document is not a valid policy JSON."

        message = ""
        required = self.required_iam_perms.get(policy.id)
        if required and required not in granted:
            self._apply_health_penalty(self.REQUIRED_PERM_PENALTY)
            message = f"CRITICAL: IAM update broke required service access for {policy.id}!"

        was_vuln = self._is_iam_vuln(policy.document)
        policy.document = new_doc
        is_vuln = self._is_iam_vuln(new_doc)

        reward = 0.0
        if was_vuln and not is_vuln:
            reward = increment
            self.remediated_count += 1
            message = message or f"Refactored policy {policy.id} to least privilege."
        elif is_vuln and not was_vuln:
            # Reverse the credit when a wildcard is re-introduced, so a policy
            # cannot be broken and repaired repeatedly for reward.
            reward = -increment
            self.remediated_count -= 1
            message = message or f"WARNING: Policy {policy.id} now grants wildcard actions."
        return reward, message or f"Updated policy {policy.id}."

    def _handle_submit(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """End the episode, paying the terminal bonus on a clean, healthy environment."""
        self.done = True
        if self._count_open_vulns() == 0 and self.health_score > 0.5:
            return self.TERMINAL_BONUS, "Audit report submitted. Perfect remediation achieved!"
        return 0.0, "Audit report submitted."

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _encrypt_resource(self, resources: List, id_attr: str, target, label: str, increment: float) -> Tuple[float, str]:
        """Mark the first resource whose ``id_attr`` equals ``target`` as encrypted."""
        for res in resources:
            if getattr(res, id_attr) != target:
                continue
            if res.encrypted:
                return 0.0, f"{label} {target} is already encrypted."
            res.encrypted = True
            self.remediated_count += 1
            return increment, f"Encrypted {label} {target}."
        return 0.0, f"No {label} found matching '{target}'."

    def _apply_health_penalty(self, amount: float) -> None:
        """Reduce ``health_score`` by ``amount``, rounding away float drift.

        Without rounding, five successive 0.2 penalties leave about 5.6e-17
        rather than 0.0, so the ``health_score <= 0`` failure check never fires.
        """
        self.health_score = round(self.health_score - amount, 6)

    def _find_sg(self, sg_id: str) -> Optional[SecurityGroup]:
        """Return the security group with ``sg_id``, or None if absent."""
        return next((sg for sg in self.sgs if sg.id == sg_id), None)

    def _is_vulnerable_rule(self, rule: SecurityGroupRule) -> bool:
        """True for an SSH/RDP rule open to the public internet."""
        return rule.port in self.RISKY_PORTS and rule.cidr == self.PUBLIC_CIDR

    def _check_sg_vulns(self, sg: SecurityGroup) -> int:
        """Count the public SSH/RDP rules in ``sg``."""
        return sum(1 for r in sg.ingress_rules if self._is_vulnerable_rule(r))

    @staticmethod
    def _policy_actions(document: str) -> Optional[List[str]]:
        """Return the actions granted by ``Allow`` statements, or None if unparseable.

        Accepts ``Statement`` as a single object or a list, and ``Action`` as
        a string or a list of strings.
        """
        try:
            parsed = json.loads(document)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            return None
        if not isinstance(parsed, dict):
            return None
        statements = parsed.get("Statement", [])
        if isinstance(statements, dict):
            statements = [statements]
        if not isinstance(statements, list):
            return None
        actions: List[str] = []
        for stmt in statements:
            if not isinstance(stmt, dict) or stmt.get("Effect") != "Allow":
                continue
            granted = stmt.get("Action", [])
            if isinstance(granted, str):
                granted = [granted]
            if isinstance(granted, list):
                actions.extend(a for a in granted if isinstance(a, str))
        return actions

    @classmethod
    def _is_iam_vuln(cls, document: str) -> bool:
        """True if the policy grants a wildcard action (e.g. ``*`` or ``s3:*``).

        ``Resource: "*"`` is deliberately not counted: the clean generated
        policies use it and are not listed in the vulnerability manifest.
        """
        actions = cls._policy_actions(document)
        if actions is None:
            # Defensive fallback for documents that do not parse.
            return "*" in (document or "")
        return any("*" in a for a in actions)

    def _count_open_vulns(self) -> int:
        """Count vulnerabilities currently present across the whole inventory."""
        sg_open = sum(self._check_sg_vulns(sg) for sg in self.sgs)
        unencrypted = sum(1 for r in (*self.buckets, *self.rds, *self.ebs) if not r.encrypted)
        iam_open = sum(1 for p in self.policies if self._is_iam_vuln(p.document))
        return sg_open + unencrypted + iam_open

    def state(self) -> CloudState:
        """Return the full internal state, including grader-only bookkeeping."""
        return CloudState(
            # NOTE(review): reports a fixed display name rather than
            # self.task_name — confirm this is intended.
            task_name="Production Cloud Security Audit",
            step_count=self.step_count,
            max_steps=self.max_steps,
            remediated_count=self.remediated_count,
            security_groups=self.sgs,
            s3_buckets=self.buckets,
            rds_instances=self.rds,
            ebs_volumes=self.ebs,
            iam_policies=self.policies,
            vulnerability_manifest=self.vulnerability_manifest,
            required_iam_perms=self.required_iam_perms,
            health_score=self.health_score,
        )

    def _get_observation(self, message: str, reward: float = 0.0, done: bool = False) -> CloudObservation:
        """Build an observation of the full inventory.

        Every resource type is included regardless of task. The manifest is a
        copy of the per-category vulnerability counts recorded at reset.
        """
        return CloudObservation(
            security_groups=self.sgs,
            s3_buckets=self.buckets,
            rds_instances=self.rds,
            ebs_volumes=self.ebs,
            iam_policies=self.policies,
            task_description=self.task_description,
            vulnerability_manifest=self.vulnerability_manifest.copy(),
            message=message,
            reward=reward,
            health_score=self.health_score,
            done=done,
            info={},
        )