"""CloudSense RL environment — manages state, steps, and episode lifecycle."""

import copy
import json
from collections import deque
from pathlib import Path

from env.models import ActionType, CloudObservation, StepResult
from env.reward import compute_reward
from env.tasks import TASKS
from env.graders.grader import grade_task

DATA_DIR = Path(__file__).parent / "data"
PRICING_FILE = DATA_DIR / "aws_pricing.json"

# Schedule uptime savings constants
# Weekday-only: run Mon-Fri (5 of 7 days), pay for 5/7 of the month
WEEKDAY_ONLY_FACTOR = 5 / 7  # 0.714 — cost multiplier
# Business hours: run 10hrs/day instead of 24
BUSINESS_HOURS_FACTOR = 10 / 24  # 0.417 — cost multiplier
# Combined: Mon-Fri, 10hrs/day
WEEKDAY_BUSINESS_HOURS = (5 / 7) * (10 / 24)  # 0.298 — cost multiplier

# Action cost multipliers for non-rightsize actions
LIFECYCLE_POLICY_FACTOR = 0.30  # S3: transition to IA saves ~70%
AUTOSCALING_FACTOR = 0.80  # ~20% savings from dynamic scaling
RESERVATION_FACTOR = 0.70  # RI typically saves ~30-40% on on-demand

# Cost multipliers applied by change_storage_class, keyed by target class.
# Targets that are not listed (including "STANDARD") leave the cost unchanged.
STORAGE_CLASS_FACTORS = {
    "GLACIER_DEEP_ARCHIVE": 0.04,
    "GLACIER_INSTANT": 0.17,
    "INFREQUENT_ACCESS": 0.54,
}

with open(PRICING_FILE, encoding="utf-8") as f:
    PRICING = json.load(f)


class CloudSenseEnv:
    """Episode-based cost-optimisation environment over a simulated cloud account.

    Typical lifecycle: ``reset(task_id)`` -> repeated ``step(action)`` until the
    returned result has ``done=True`` -> optionally ``close()``.

    Resources are plain dicts; this class reads the keys ``resource_id``,
    ``resource_type``, ``monthly_cost``, ``current_config`` and, optionally,
    ``dependencies``, ``subnet``, ``is_critical`` and ``environment``.
    """

    def __init__(self):
        self.task = None
        self.task_id = None
        self.current_resources: list[dict] = []
        self.original_resources: list[dict] = []
        self.actions_history: list[dict] = []
        self.original_cost: float = 0.0
        self.current_cost: float = 0.0
        self.step_count: int = 0
        self.done: bool = False
        self.last_blast_radius: dict | None = None
        self.last_reward: float = 0.0
        self.last_action_error: str | None = None

    def reset(self, task_id: str) -> CloudObservation:
        """Reset environment for a new episode.

        Args:
            task_id: Key into ``TASKS`` selecting the scenario to load.

        Returns:
            The initial observation for the freshly loaded account.

        Raises:
            ValueError: If ``task_id`` is not a key of ``TASKS``.
        """
        task_cls = TASKS.get(task_id)
        if task_cls is None:
            raise ValueError(f"Unknown task: {task_id}. Available: {list(TASKS.keys())}")

        self.task = task_cls()
        self.task_id = task_id

        raw = self.task.load_account_raw()
        # Two independent copies: one is mutated by actions, one stays pristine.
        self.current_resources = copy.deepcopy(raw)
        self.original_resources = copy.deepcopy(raw)
        self.actions_history = []
        self.original_cost = sum(r["monthly_cost"] for r in self.current_resources)
        self.current_cost = self.original_cost
        self.step_count = 0
        self.done = False
        self.last_blast_radius = None
        self.last_reward = 0.0
        self.last_action_error = None
        return self._make_observation()

    def step(self, action: dict) -> StepResult:
        """Execute one action and return result.

        Every call consumes one step, including calls whose action is rejected
        by validation. Rejected actions yield a reward of 0.0, leave the
        resources untouched, and set ``last_action_error``.

        Args:
            action: Dict with ``action_type`` and ``resource_id``; some action
                types also read ``new_config``.

        Returns:
            A ``StepResult``. When the step ends the episode, ``info`` also
            carries ``task_score`` from the grader.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode is
                done.
        """
        if self.task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if self.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        self.step_count += 1
        self.last_action_error = None

        action_type = action.get("action_type", "")
        resource_id = action.get("resource_id", "")

        # Validate action_type
        valid_types = {a.value for a in ActionType}
        if action_type not in valid_types:
            return self._reject_action(f"Invalid action_type: {action_type}")

        # Find the resource
        resource = next(
            (r for r in self.current_resources if r["resource_id"] == resource_id),
            None,
        )
        if resource is None:
            return self._reject_action(f"Resource not found: {resource_id}")

        # Validate action_type is applicable to resource type
        error = self._validate_action_for_resource(action_type, resource, action)
        if error:
            return self._reject_action(error)

        # Compute blast radius BEFORE applying action
        blast_radius = self._compute_blast_radius(action, resource)

        # Apply action and compute new cost
        cost_before = self.current_cost
        self._apply_action(action, resource)
        cost_after = sum(r["monthly_cost"] for r in self.current_resources)
        self.current_cost = cost_after

        # Compute reward — raw for grader, api for response
        # Pass current action's blast radius so reward reflects THIS action's impact
        raw_reward, api_reward = compute_reward(
            action=action,
            resource=resource,
            task=self.task,
            account_state=self.current_resources,
            actions_history=self.actions_history,
            original_cost=self.original_cost,
            current_cost=cost_before,
            cost_after_action=cost_after,
            blast_radius=blast_radius,
        )

        self.actions_history.append({**action, "_raw_reward": raw_reward})
        self.last_blast_radius = blast_radius
        self.last_reward = api_reward

        # Check done condition
        self.done = self._is_done()

        info = {"blast_radius": blast_radius}
        # If done, compute final grader score
        if self.done:
            info["task_score"] = self._grade_episode()

        return self._make_step_result(api_reward, info)

    def _reject_action(self, message: str) -> StepResult:
        """Record a validation error and return a zero-reward step result.

        The done condition is re-evaluated here as well: a rejected action has
        already consumed a step, so the ``max_steps`` budget must still be able
        to end the episode (and produce ``task_score``) on this path.
        """
        self.last_action_error = message
        self.done = self._is_done()
        info: dict = {}
        if self.done:
            info["task_score"] = self._grade_episode()
        return self._make_step_result(0.0, info)

    def _grade_episode(self):
        """Return the grader's score for the current task, history and state."""
        return grade_task(
            self.task_id,
            self.actions_history,
            self.state(),
            {},
        )

    def state(self) -> dict:
        """Return full current environment state.

        Resources and the action history are deep-copied so callers cannot
        mutate the environment through the returned dict.
        """
        return {
            "task_id": self.task_id,
            "resources": copy.deepcopy(self.current_resources),
            "original_cost": self.original_cost,
            "current_cost": self.current_cost,
            "step_count": self.step_count,
            "max_steps": self.task.max_steps if self.task else 0,
            "done": self.done,
            "actions_history": copy.deepcopy(self.actions_history),
        }

    def close(self):
        """Clean up environment state."""
        self.task = None
        self.task_id = None
        self.current_resources = []
        self.original_resources = []
        self.actions_history = []
        self.original_cost = 0.0
        self.current_cost = 0.0
        self.step_count = 0
        self.done = False
        self.last_blast_radius = None
        self.last_reward = 0.0
        self.last_action_error = None

    def _make_observation(self) -> CloudObservation:
        """Build an observation from current state."""
        action_savings = self.task.get_action_savings() if self.task else {}
        total_possible_savings = sum(action_savings.values())

        return CloudObservation(
            task_id=self.task_id or "",
            goal=self.task.description if self.task else "",
            account_id=f"account-{self.task_id}",
            resources=copy.deepcopy(self.current_resources),
            monthly_cost_current=self.current_cost,
            # Measured from the ORIGINAL bill: subtracting the full savings
            # potential from the already-reduced current cost would count the
            # realised savings twice and drift as the episode progresses.
            monthly_cost_optimized=self.original_cost - total_possible_savings,
            total_possible_savings=total_possible_savings,
            actions_taken=[dict(a) for a in self.actions_history],
            warnings=[],
            step_number=self.step_count,
            max_steps=self.task.max_steps if self.task else 0,
            last_reward=self.last_reward,
            last_action_error=self.last_action_error,
            info={"blast_radius": self.last_blast_radius} if self.last_blast_radius else {},
        )

    def _make_step_result(self, reward: float, info: dict | None = None) -> StepResult:
        """Build a StepResult.

        ``info`` always ends up with a ``blast_radius`` entry: the caller's
        value if present, otherwise the last recorded one, otherwise an empty
        placeholder.
        """
        obs = self._make_observation()
        if info is None:
            info = {}
        if self.last_blast_radius and "blast_radius" not in info:
            info["blast_radius"] = self.last_blast_radius
        if "blast_radius" not in info:
            info["blast_radius"] = {"affected_resources": [], "risk_level": "none", "explanation": ""}
        return StepResult(
            observation=obs,
            reward=reward,
            done=self.done,
            info=info,
        )

    def _validate_action_for_resource(
        self, action_type: str, resource: dict, action: dict | None = None
    ) -> str | None:
        """Return error message if action is invalid for this resource type, else None."""
        rtype = resource.get("resource_type", "")

        if action_type == ActionType.add_lifecycle_policy.value and rtype != "s3":
            return f"add_lifecycle_policy only applies to S3, not {rtype}"
        if action_type == ActionType.change_storage_class.value and rtype != "s3":
            return f"change_storage_class only applies to S3, not {rtype}"
        if action_type == ActionType.enable_autoscaling.value and rtype not in ("ec2", "kubernetes"):
            return f"enable_autoscaling only applies to EC2/K8s, not {rtype}"
        if action_type == ActionType.rightsize_resource.value:
            if rtype in ("s3", "ebs", "eip", "nat_gateway", "load_balancer"):
                return f"rightsize_resource does not apply to {rtype}"
            if action and not action.get("new_config"):
                return "rightsize_resource requires new_config with target configuration"
        if action_type == ActionType.purchase_reservation.value:
            eligible = ("ec2", "rds", "elasticsearch", "kubernetes")
            if rtype not in eligible:
                return f"purchase_reservation only applies to {', '.join(eligible)}, not {rtype}"
        return None

    def _apply_action(self, action: dict, resource: dict):
        """Simulate the cost impact of an action on a resource.

        Mutates ``resource`` in place (``monthly_cost`` and ``current_config``).
        Multiplier-based actions scale the resource's *current* cost.
        """
        action_type = action.get("action_type", "")
        new_config = action.get("new_config") or {}

        if action_type == ActionType.terminate_resource.value:
            resource["monthly_cost"] = 0.0
            resource["current_config"]["terminated"] = True

        elif action_type == ActionType.rightsize_resource.value:
            if new_config:
                new_cost = self._compute_new_cost(resource, new_config)
                if new_cost is not None:
                    resource["monthly_cost"] = new_cost
                # The requested configuration is recorded even when no price
                # could be computed for it.
                resource["current_config"].update(new_config)

        elif action_type == ActionType.add_lifecycle_policy.value:
            resource["monthly_cost"] = round(resource["monthly_cost"] * LIFECYCLE_POLICY_FACTOR, 2)
            resource["current_config"]["lifecycle_policy"] = "transition_to_ia_30d"

        elif action_type == ActionType.change_storage_class.value:
            target = new_config.get("storage_class", "GLACIER_DEEP_ARCHIVE")
            factor = STORAGE_CLASS_FACTORS.get(target)
            if factor is not None:
                resource["monthly_cost"] = round(resource["monthly_cost"] * factor, 2)
            resource["current_config"]["storage_class"] = target

        elif action_type == ActionType.schedule_uptime.value:
            pattern = new_config.get("schedule", "weekday_only")
            if pattern == "business_hours":
                # Mon-Fri 10hrs/day: cost = original * WEEKDAY_BUSINESS_HOURS
                resource["monthly_cost"] = round(resource["monthly_cost"] * WEEKDAY_BUSINESS_HOURS, 2)
            else:
                # Weekday-only: cost = original * WEEKDAY_ONLY_FACTOR (pay for 5/7 of month)
                resource["monthly_cost"] = round(resource["monthly_cost"] * WEEKDAY_ONLY_FACTOR, 2)

        elif action_type == ActionType.enable_autoscaling.value:
            resource["monthly_cost"] = round(resource["monthly_cost"] * AUTOSCALING_FACTOR, 2)
            resource["current_config"]["autoscaling"] = True

        elif action_type == ActionType.purchase_reservation.value:
            resource["monthly_cost"] = round(resource["monthly_cost"] * RESERVATION_FACTOR, 2)
            resource["current_config"]["reservation_status"] = "reserved"

        # skip_resource and request_more_info don't change cost

    def _compute_new_cost(self, resource: dict, new_config: dict) -> float | None:
        """Compute the new monthly cost after rightsizing.

        Settings omitted from ``new_config`` (RDS ``storage_gb``, node counts,
        Kubernetes ``node_type``) fall back to the resource's current
        configuration, so changing only the instance type does not silently
        re-price a multi-node cluster as a single node.

        Returns:
            The rounded monthly cost, or ``None`` when the resource type or the
            requested size has no entry in ``PRICING``.
        """
        rtype = resource.get("resource_type", "")
        current_config = resource.get("current_config") or {}
        new_instance = new_config.get("instance_type", "")

        if rtype == "ec2" and new_instance:
            cost = PRICING.get("ec2", {}).get(new_instance)
            if cost is not None:
                return round(cost, 2)

        elif rtype == "rds" and new_instance:
            instance_cost = PRICING.get("rds", {}).get(new_instance)
            if instance_cost is not None:
                storage_gb = new_config.get("storage_gb", current_config.get("storage_gb", 0))
                storage_cost = storage_gb * PRICING["ebs"]["gp2_per_gb"]
                return round(instance_cost + storage_cost, 2)

        elif rtype == "elasticsearch" and new_instance:
            node_cost = PRICING.get("elasticsearch", {}).get(new_instance)
            if node_cost is not None:
                node_count = new_config.get("node_count", current_config.get("node_count", 1))
                return round(node_cost * node_count, 2)

        elif rtype == "kubernetes":
            node_type = new_config.get("node_type", current_config.get("node_type", "")) or ""
            node_count = new_config.get("node_count", current_config.get("node_count", 1))
            # Pricing keys look like "per_node_<type with dots replaced by underscores>".
            key = f"per_node_{node_type.replace('.', '_')}"
            per_node = PRICING.get("kubernetes", {}).get(key)
            if per_node is not None:
                base = PRICING["kubernetes"]["cluster_base"]
                return round(base + per_node * node_count, 2)

        return None

    def _collect_dependents(self, resource_id: str) -> list[str]:
        """Return ids of all resources that transitively depend on ``resource_id``.

        Breadth-first walk over the ``dependencies`` lists, in discovery order.
        The starting resource is marked as seen up front so a dependency cycle
        can never report it as affected by its own change.
        """
        seen = {resource_id}
        dependents: list[str] = []
        queue = deque([resource_id])
        while queue:
            current_id = queue.popleft()
            for r in self.current_resources:
                rid = r["resource_id"]
                if rid not in seen and current_id in (r.get("dependencies") or []):
                    seen.add(rid)
                    dependents.append(rid)
                    queue.append(rid)  # follow the chain
        return dependents

    def _compute_blast_radius(self, action: dict, resource: dict) -> dict:
        """Compute cascading impact of an action using transitive BFS.

        BFS through dependency graph: if ELB → EC2 → RDS, terminating ELB shows
        BOTH EC2 AND RDS as affected (not just direct dependents).

        Returns:
            Dict with ``affected_resources`` (list of ids), ``risk_level``
            (one of "none", "low", "medium", "high", "critical") and
            ``explanation``.
        """
        action_type = action.get("action_type", "")
        resource_id = action.get("resource_id", "")

        # Skip/info actions have no blast radius
        if action_type in (ActionType.skip_resource.value, ActionType.request_more_info.value):
            return {"affected_resources": [], "risk_level": "none", "explanation": ""}

        affected: list[str] = []
        risk_level = "none"
        explanation = ""

        if action_type == ActionType.terminate_resource.value:
            # BFS through dependency graph — find ALL transitively affected resources
            affected = self._collect_dependents(resource_id)
            affected_ids = set(affected)

            # Check subnet-based impact (NAT Gateway)
            if resource.get("resource_type") == "nat_gateway":
                subnet = resource.get("subnet")
                if subnet:
                    for r in self.current_resources:
                        rid = r["resource_id"]
                        if r.get("subnet") == subnet and rid != resource_id and rid not in affected_ids:
                            affected_ids.add(rid)
                            affected.append(rid)

            # Determine risk level based on affected count and criticality
            impacted = [r for r in self.current_resources if r["resource_id"] in affected_ids]
            has_critical = any(r.get("is_critical") for r in impacted)
            has_prod = any(r.get("environment") == "prod" for r in impacted)

            if not affected:
                risk_level = "none"
                explanation = "No dependent resources affected."
            elif resource.get("resource_type") == "nat_gateway" and has_critical:
                risk_level = "critical"
                explanation = (
                    f"Terminating this NAT Gateway will cut internet access for "
                    f"{len(affected)} resources in subnet {resource.get('subnet')}."
                )
            elif has_prod:
                risk_level = "critical"
                explanation = f"Terminating this resource affects {len(affected)} resources including production."
            elif has_critical:
                risk_level = "high"
                explanation = f"Terminating this resource affects {len(affected)} resources including critical ones."
            elif resource.get("resource_type") == "load_balancer":
                risk_level = "high"
                explanation = (
                    f"Terminating this ELB will disconnect {len(affected)} "
                    f"instances that route through it."
                )
            elif len(affected) > 2:
                risk_level = "high"
                explanation = f"Terminating this resource affects {len(affected)} dependent resources."
            else:
                risk_level = "medium"
                explanation = f"Terminating this resource affects {len(affected)} dependent resource(s)."

        elif action_type == ActionType.rightsize_resource.value:
            # BFS for rightsize too — affects dependents but one tier lower risk
            affected = self._collect_dependents(resource_id)

            if affected:
                affected_ids = set(affected)
                has_critical = any(
                    r.get("is_critical")
                    for r in self.current_resources
                    if r["resource_id"] in affected_ids
                )
                # Rightsizing keeps service running — reduce risk one tier
                if has_critical:
                    risk_level = "high"
                    explanation = f"Rightsizing affects {len(affected)} dependent resource(s) including critical ones."
                elif len(affected) > 2:
                    risk_level = "medium"
                    explanation = f"Rightsizing affects {len(affected)} dependent resource(s)."
                else:
                    risk_level = "medium" if resource.get("resource_type") == "rds" else "low"
                    explanation = f"Rightsizing affects {len(affected)} dependent resource(s)."

        elif action_type in (ActionType.add_lifecycle_policy.value, ActionType.change_storage_class.value):
            # Minimal blast radius for storage changes
            risk_level = "low" if resource.get("is_critical") else "none"
            if risk_level == "low":
                explanation = "Storage class change on a critical resource — verify access patterns."

        return {
            "affected_resources": affected,
            "risk_level": risk_level,
            "explanation": explanation,
        }

    def _is_done(self) -> bool:
        """Check if episode should end.

        The episode ends when no task is loaded, the step budget is exhausted,
        every resource has at least one recorded action, or more than 95% of
        the task's possible savings has been realised.
        """
        if self.task is None:
            return True
        if self.step_count >= self.task.max_steps:
            return True

        # All resources have been actioned
        actioned_ids = {a.get("resource_id") for a in self.actions_history}
        all_ids = {r["resource_id"] for r in self.current_resources}
        if all_ids and actioned_ids >= all_ids:
            return True

        # >95% of possible savings achieved
        action_savings = self.task.get_action_savings()
        total_possible = sum(action_savings.values())
        actual_savings = self.original_cost - self.current_cost
        if total_possible > 0 and actual_savings / total_possible > 0.95:
            return True

        return False