import json
import random
import uuid
from typing import List, Tuple, Dict

from openenv.core.env_server import Environment

from models import (
    CloudAction,
    CloudObservation,
    CloudState,
    SecurityGroup,
    SecurityGroupRule,
    S3Bucket,
    IAMPolicy,
    RDSInstance,
    EBSVolume,
)

# Ingress ports that count as a vulnerability when opened to _PUBLIC_CIDR
# (the task text refers to them as SSH/RDP).
_ADMIN_PORTS = (22, 3389)
_PUBLIC_CIDR = "0.0.0.0/0"


class CloudAuditEnv(Environment):
    """Procedurally generated cloud-security audit environment.

    Every ``reset`` builds a random inventory of security groups, S3 buckets,
    RDS instances, EBS volumes and IAM policies, some of which are seeded
    with a vulnerability.  ``step`` applies one remediation action, returns
    the per-step reward and tracks a ``health_score`` that drops when an
    action removes something the inventory marks as essential.
    """

    def __init__(self):
        super().__init__()
        self.max_steps = 30
        self.reset()

    def reset(self, task_name: str = "easy_audit") -> CloudObservation:
        """Start a fresh episode and return its first observation.

        Args:
            task_name: Key selecting the task description.  Unknown names
                fall back to a generic description; the generated inventory
                does not depend on the task.

        Returns:
            Observation of the newly generated inventory with reward 0.0.
        """
        self.task_name = task_name
        self.step_count = 0
        self.remediated_count = 0
        self.health_score = 1.0
        self.cumulative_reward = 0.0
        self.done = False

        # Task Descriptions
        descriptions = {
            "easy_audit": "Find and fix all Security Group rules that allow SSH/RDP access from the public internet (0.0.0.0/0).",
            "medium_remediation": "Ensure all S3 buckets, RDS instances, and EBS volumes are encrypted at rest.",
            "hard_iam_refactor": "Refactor IAM policies to remove wildcards ('*') while preserving required service permissions.",
        }
        self.task_description = descriptions.get(
            task_name, "Perform a comprehensive cloud security audit."
        )

        # Procedural Generation Configuration
        self.sgs: List[SecurityGroup] = []
        self.buckets: List[S3Bucket] = []
        self.rds: List[RDSInstance] = []
        self.ebs: List[EBSVolume] = []
        self.policies: List[IAMPolicy] = []
        self.vulnerability_manifest = {
            "sg_vulns": 0,
            "s3_vulns": 0,
            "rds_vulns": 0,
            "ebs_vulns": 0,
            "iam_vulns": 0,
        }
        # sg_id -> list of (port, cidr) whose removal costs health
        self.essential_rules: Dict[str, List[Tuple[int, str]]] = {}
        # policy_id -> substring that must survive an update (e.g. s3:GetObject)
        self.required_iam_perms: Dict[str, str] = {}
        # Ids of policies counted in the manifest and not yet rewarded as fixed.
        self._open_iam_vulns = set()

        self._generate_procedural_assets()
        self.initial_vulns = sum(self.vulnerability_manifest.values())

        resource_total = (
            len(self.sgs)
            + len(self.buckets)
            + len(self.rds)
            + len(self.ebs)
            + len(self.policies)
        )
        return self._get_observation(
            f"Environment reset. Procedural audit pending with {resource_total} resources.",
            reward=0.0,
            done=False,
        )

    def _generate_procedural_assets(self):
        """Populate the inventory lists and the vulnerability manifest.

        The order of ``random``/``uuid`` draws is kept stable so that a
        seeded ``random`` module keeps producing the same inventories.
        """
        # 1. Security Groups (5-7)
        num_sgs = random.randint(5, 7)
        for i in range(num_sgs):
            sg_id = f"sg-{uuid.uuid4().hex[:6]}"
            rules = []

            # Legitimate rule (Essential)
            ess_port = random.choice([443, 8080, 5432])
            rules.append(SecurityGroupRule(port=ess_port, cidr="10.0.0.0/8"))
            self.essential_rules[sg_id] = [(ess_port, "10.0.0.0/8")]

            # Vulnerable rule (sometimes)
            if random.random() > 0.3:
                vuln_port = random.choice([22, 3389])
                rules.append(SecurityGroupRule(port=vuln_port, cidr="0.0.0.0/0"))
                self.vulnerability_manifest["sg_vulns"] += 1

            self.sgs.append(
                SecurityGroup(id=sg_id, name=f"group-{i}", ingress_rules=rules)
            )

        # 2. S3 Buckets (4-6)
        num_buckets = random.randint(4, 6)
        for i in range(num_buckets):
            name = f"bucket-{uuid.uuid4().hex[:8]}"
            encrypted = random.choice([True, False])
            if not encrypted:
                self.vulnerability_manifest["s3_vulns"] += 1
            self.buckets.append(S3Bucket(name=name, encrypted=encrypted))

        # 3. RDS Instances (2-3)
        num_rds = random.randint(2, 3)
        for i in range(num_rds):
            rid = f"db-{uuid.uuid4().hex[:4]}"
            enc = random.choice([True, False])
            if not enc:
                self.vulnerability_manifest["rds_vulns"] += 1
            self.rds.append(RDSInstance(id=rid, engine="postgres", encrypted=enc))

        # 4. EBS Volumes (2-3)
        num_ebs = random.randint(2, 3)
        for i in range(num_ebs):
            eid = f"vol-{uuid.uuid4().hex[:4]}"
            enc = random.choice([True, False])
            if not enc:
                self.vulnerability_manifest["ebs_vulns"] += 1
            self.ebs.append(EBSVolume(id=eid, encrypted=enc))

        # 5. IAM Policies (3)
        for i in range(3):
            pid = f"p-{uuid.uuid4().hex[:4]}"
            required = random.choice(
                ["s3:GetObject", "ec2:DescribeInstances", "iam:GetUser"]
            )
            self.required_iam_perms[pid] = required

            is_vuln = random.choice([True, False])
            if is_vuln:
                doc = {
                    "Version": "2012-10-17",
                    "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}],
                }
                self.vulnerability_manifest["iam_vulns"] += 1
                self._open_iam_vulns.add(pid)
            else:
                doc = {
                    "Version": "2012-10-17",
                    "Statement": [
                        {"Effect": "Allow", "Action": required, "Resource": "*"}
                    ],
                }

            self.policies.append(
                IAMPolicy(id=pid, name=f"Policy-{i}", document=json.dumps(doc))
            )

    def step(self, action: CloudAction) -> CloudObservation:
        """Apply one action and return the resulting observation.

        Args:
            action: The action to apply; ``action.action_type`` selects the
                handler and the remaining fields carry its arguments.

        Returns:
            Observation whose ``reward`` is the reward earned by this step
            (``-1.0`` when the health score has dropped to zero or below).
            Once the episode is done, further calls change nothing and
            return a zero-reward observation with ``done=True``.
        """
        if self.done:
            # A finished episode must not keep paying out rewards.
            return self._get_observation(
                "Episode already finished; call reset() to start a new audit.",
                reward=0.0,
                done=True,
            )

        self.step_count += 1

        # Fixing every seeded vulnerability is worth 0.8 in total.
        fix_reward_increment = 0.8 / max(1, self.initial_vulns)

        handlers = {
            "audit": self._handle_audit,
            "fix_sg": self._handle_fix_sg,
            "remediate_all_in_sg": self._handle_remediate_all_in_sg,
            "enable_s3_enc": self._handle_enable_s3_enc,
            "enable_rds_enc": self._handle_enable_rds_enc,
            "enable_ebs_enc": self._handle_enable_ebs_enc,
            "update_iam": self._handle_update_iam,
            "submit": self._handle_submit,
        }
        handler = handlers.get(action.action_type)
        if handler is None:
            reward, message = 0.0, f"Unknown action type '{action.action_type}'."
        else:
            reward, message = handler(action, fix_reward_increment)

        if self.step_count >= self.max_steps:
            self.done = True
            # Appended so the action's own message is not lost.
            message = f"{message} Max steps reached.".strip()

        # Update Cumulative Tracking
        self.cumulative_reward += reward

        step_reward = reward
        if self.health_score <= 0:
            self.done = True
            step_reward = -1.0  # Force total towards the bottom
            message = "CRITICAL FAILURE: Production environment is offline. Mission failed."

        return self._get_observation(message, reward=step_reward, done=self.done)

    def _penalize_health(self, amount: float) -> None:
        """Subtract ``amount`` from the health score, rounded to 2 decimals.

        Rounding keeps repeated 0.2/0.3 deductions from leaving a tiny
        positive float residual that would dodge the ``<= 0`` failure check.
        """
        self.health_score = round(self.health_score - amount, 2)

    def _handle_audit(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Grant the small fixed reward for requesting an audit log."""
        return 0.01, "Audit log generated."

    def _handle_fix_sg(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """Remove the ``(port, cidr_to_remove)`` rule from one security group."""
        if not (action.sg_id and action.port is not None and action.cidr_to_remove):
            return 0.0, ""
        sg = next((g for g in self.sgs if g.id == action.sg_id), None)
        if sg is None:
            return 0.0, ""

        target = (action.port, action.cidr_to_remove)
        kept = [r for r in sg.ingress_rules if (r.port, r.cidr) != target]
        if len(kept) == len(sg.ingress_rules):
            # Nothing matched, so nothing was removed and nothing is penalised.
            return 0.0, ""

        message = ""
        # Check for availability penalty (removing essential rule)
        if target in self.essential_rules.get(sg.id, []):
            self._penalize_health(0.2)
            message = f"CRITICAL: Removed essential rule on {sg.id}! Availability decreased."

        vulns_before = self._check_sg_vulns(sg)
        sg.ingress_rules = kept
        vulns_after = self._check_sg_vulns(sg)

        reward = 0.0
        if vulns_after < vulns_before:
            reward = increment
            self.remediated_count += 1
            message = message or f"Fixed rule on {sg.id}."
        return reward, message

    def _handle_remediate_all_in_sg(
        self, action: CloudAction, increment: float
    ) -> Tuple[float, str]:
        """Remove every publicly exposed admin-port rule from one group."""
        if not action.sg_id:
            return 0.0, ""
        sg = next((g for g in self.sgs if g.id == action.sg_id), None)
        if sg is None:
            return 0.0, ""

        vulns = self._check_sg_vulns(sg)
        if vulns <= 0:
            return 0.0, ""
        sg.ingress_rules = [
            r
            for r in sg.ingress_rules
            if not (r.port in _ADMIN_PORTS and r.cidr == _PUBLIC_CIDR)
        ]
        self.remediated_count += vulns
        return increment * vulns, f"Batch remediated {vulns} issues in {sg.id}."

    def _encrypt_resource(
        self, resources, key_attr: str, wanted_key, increment: float, label: str
    ) -> Tuple[float, str]:
        """Encrypt the first unencrypted resource whose ``key_attr`` matches."""
        for resource in resources:
            key = getattr(resource, key_attr)
            if key == wanted_key and not resource.encrypted:
                resource.encrypted = True
                self.remediated_count += 1
                return increment, f"Encrypted {label} {key}."
        return 0.0, ""

    def _handle_enable_s3_enc(
        self, action: CloudAction, increment: float
    ) -> Tuple[float, str]:
        """Enable encryption on the bucket named ``action.bucket_name``."""
        return self._encrypt_resource(
            self.buckets, "name", action.bucket_name, increment, "bucket"
        )

    def _handle_enable_rds_enc(
        self, action: CloudAction, increment: float
    ) -> Tuple[float, str]:
        """Enable encryption on the RDS instance ``action.rds_id``."""
        return self._encrypt_resource(self.rds, "id", action.rds_id, increment, "RDS")

    def _handle_enable_ebs_enc(
        self, action: CloudAction, increment: float
    ) -> Tuple[float, str]:
        """Enable encryption on the EBS volume ``action.ebs_id``."""
        return self._encrypt_resource(self.ebs, "id", action.ebs_id, increment, "EBS")

    def _handle_update_iam(
        self, action: CloudAction, increment: float
    ) -> Tuple[float, str]:
        """Replace a policy document and reward removal of its wildcards.

        Only policies that were counted in the vulnerability manifest can
        earn the fix reward, and each of them at most once, so the IAM
        rewards can never exceed what the manifest announced.
        """
        policy = next((p for p in self.policies if p.id == action.policy_id), None)
        if policy is None:
            return 0.0, ""

        new_document = action.new_document or ""
        message = ""

        # Check for availability penalty (missing required perm)
        required = self.required_iam_perms.get(policy.id)
        if required and required not in new_document:
            self._penalize_health(0.3)
            message = f"CRITICAL: IAM update broke required service access for {policy.id}!"

        policy.document = new_document

        reward = 0.0
        if policy.id in self._open_iam_vulns and "*" not in new_document:
            self._open_iam_vulns.discard(policy.id)
            reward = increment
            self.remediated_count += 1
            message = message or f"Refactored policy {policy.id} to least privilege."
        return reward, message

    def _handle_submit(self, action: CloudAction, increment: float) -> Tuple[float, str]:
        """End the episode, paying a bonus when everything was remediated."""
        self.done = True
        if self.remediated_count >= self.initial_vulns and self.health_score > 0.5:
            # Terminal bonus
            return 0.1, "Audit report submitted. Perfect remediation achieved!"
        return 0.0, "Audit report submitted."

    def _check_sg_vulns(self, sg: SecurityGroup) -> int:
        """Count ingress rules exposing an admin port to the public CIDR."""
        return sum(
            1
            for r in sg.ingress_rules
            if r.port in _ADMIN_PORTS and r.cidr == _PUBLIC_CIDR
        )

    @property
    def state(self) -> CloudState:
        """Full internal state, including the hidden required IAM permissions."""
        return CloudState(
            # NOTE(review): fixed label rather than self.task_name - confirm
            # whether consumers expect the active task key here.
            task_name="Production Cloud Security Audit",
            step_count=self.step_count,
            max_steps=self.max_steps,
            remediated_count=self.remediated_count,
            security_groups=self.sgs,
            s3_buckets=self.buckets,
            rds_instances=self.rds,
            ebs_volumes=self.ebs,
            iam_policies=self.policies,
            vulnerability_manifest=self.vulnerability_manifest,
            required_iam_perms=self.required_iam_perms,
            health_score=self.health_score,
        )

    def _get_observation(
        self, message: str, reward: float = 0.0, done: bool = False
    ) -> CloudObservation:
        """Build an observation of the whole inventory.

        All resources are shown regardless of the task; the manifest (passed
        as a copy) tells the agent which categories contain vulnerabilities.
        """
        return CloudObservation(
            security_groups=self.sgs,
            s3_buckets=self.buckets,
            rds_instances=self.rds,
            ebs_volumes=self.ebs,
            iam_policies=self.policies,
            task_description=self.task_description,
            vulnerability_manifest=self.vulnerability_manifest.copy(),
            message=message,
            reward=reward,
            health_score=self.health_score,
            done=done,
            info={},
        )