Spaces:
Sleeping
Sleeping
| """ | |
| The three main environment methods: | |
| 1. reset() - called at the start of each episode | |
| 2. step(action) - called each time the agent takes an action | |
| 3. state() - returns the current observation; can be called at any time after reset() | |
| """ | |
| """ | |
| 1. For incident selection - curriculum learning approach (easy -> medium -> hard) | |
| 2. For reward factors - 5 factors (correct, wrong, resolve with fix, resolve without fix, max steps) | |
| 3. For episode end conditions - resolved with fix, resolved without fix, max steps reached | |
| 4. For action space - 8 actions (including diagnostic, fix, and terminal actions) | |
| 5. For max steps - 10 steps per episode | |
| 6. For reward - range is -20 to +20 | |
| """ | |
| """ | |
| Curriculum stages -> | |
| 1. stage 1 -> episodes 1-10 | |
| 2. stage 2 -> episodes 11-25 | |
| 3. stage 3 -> episodes 26+ | |
| """ | |
| """ | |
| The 3 data models -> | |
| 1. observation- what agent sees at each step | |
| 2. action - what agent can do at each step | |
| 3. EnvState - internal tracking of the environment | |
| """ | |
| import random | |
| from typing import Dict, Any, Tuple, Optional, List | |
| from pydantic import BaseModel | |
| from environment.incident_generator import get_random_incident, get_incident_by_type | |
| from environment.action_space import is_valid_action | |
| from environment.reward import calculate_reward | |
class Observation(BaseModel):
    """What the agent sees after reset() or step().

    Field order is kept stable because it determines the order of the
    serialized output.
    """

    # Episode progress: steps taken so far, and the per-episode step limit.
    step: int
    max_steps: int

    # Details taken from the current incident.
    incident_summary: str
    logs: List[str]
    response_code: int

    # Episode flags.
    fix_applied: bool
    # NOTE: APITriageEnv.state() fills this from its `done` flag, so it is
    # also True when an episode ended by reaching max_steps.
    is_resolved: bool
class Action(BaseModel):
    """A single agent action, identified by name (for example "resolve")."""

    # Presumably one of the names accepted by
    # environment.action_space.is_valid_action -- verify against that module.
    action_name: str
class EnvState(BaseModel):
    """Internal bookkeeping snapshot of one environment episode.

    NOTE(review): not constructed anywhere in the visible code; APITriageEnv
    keeps these values as plain attributes instead.
    """

    # Incident dict; schema presumably defined by environment.incident_generator.
    current_incident: Dict[str, Any]
    # Steps taken in the current episode.
    step_counter: int
    # Whether the incident's fix action has been taken.
    fix_applied: bool
    # Sum of rewards collected in the current episode.
    total_reward: float
    # Whether the episode has ended.
    is_resolved: bool
class APITriageEnv:
    """API incident-triage environment with a reset()/step()/state() interface.

    Each episode presents one incident. The incident type is chosen by a
    curriculum keyed on the number of episodes started so far:

    * episodes 1-10  -> stage 1 (auth_error, missing_fields)
    * episodes 11-25 -> stage 2 (rate_limit, timeout, wrong_endpoint)
    * episodes 26+   -> stage 3 (server_error)

    An episode ends when the agent takes a valid "resolve" action or when
    ``max_steps`` steps have been taken, whichever comes first.
    """

    # Curriculum schedule: inclusive last episode number of each early stage.
    _STAGE_1_LAST_EPISODE = 10
    _STAGE_2_LAST_EPISODE = 25
    _STAGE_1_INCIDENTS = ("auth_error", "missing_fields")
    _STAGE_2_INCIDENTS = ("rate_limit", "timeout", "wrong_endpoint")
    # Stage 3 has a single type. It is picked directly rather than through
    # random.choice so that no random number is consumed.
    _STAGE_3_INCIDENT = "server_error"

    # Reward returned (and added to total_reward) when is_valid_action rejects
    # an action.
    INVALID_ACTION_PENALTY = -2.0

    def __init__(self, max_steps: int = 10):
        """Create an environment. Call reset() before state() or step().

        Args:
            max_steps: Step limit per episode; reaching it ends the episode.
        """
        self.max_steps = max_steps
        self.step_counter = 0
        self.done = False
        # Incident dict from get_incident_by_type(); None until the first reset().
        self.incident: Optional[Dict[str, Any]] = None
        self.fix_applied = False
        self.total_reward = 0.0
        # Episodes started so far; drives the curriculum in reset().
        self.total_episodes = 0

    def reset(self) -> "Observation":
        """Start a new episode and return its first observation.

        Clears the per-episode counters, advances the episode count, and loads
        a new incident chosen by the curriculum.
        """
        self.step_counter = 0
        self.done = False
        self.fix_applied = False
        self.total_reward = 0.0
        self.total_episodes += 1
        self.incident = get_incident_by_type(self._select_incident_type())
        return self.state()

    def _select_incident_type(self) -> str:
        """Return the incident type for the current episode (curriculum learning)."""
        if self.total_episodes <= self._STAGE_1_LAST_EPISODE:
            return random.choice(self._STAGE_1_INCIDENTS)
        if self.total_episodes <= self._STAGE_2_LAST_EPISODE:
            return random.choice(self._STAGE_2_INCIDENTS)
        return self._STAGE_3_INCIDENT

    def _require_incident(self) -> Dict[str, Any]:
        """Return the current incident, or raise if reset() was never called."""
        if self.incident is None:
            raise RuntimeError("reset() must be called before state() or step()")
        return self.incident

    def state(self) -> "Observation":
        """Return what the agent sees at the current step.

        Raises:
            RuntimeError: if called before the first reset().
        """
        # Assumes the incident dict has "summary", "logs" and "code" keys
        # (schema owned by environment.incident_generator).
        incident = self._require_incident()
        return Observation(
            step=self.step_counter,
            max_steps=self.max_steps,
            incident_summary=incident["summary"],
            logs=incident["logs"],
            response_code=incident["code"],
            fix_applied=self.fix_applied,
            # NOTE: mirrors `done`, so this is also True when the episode ended
            # by reaching max_steps without a "resolve" action.
            is_resolved=self.done,
        )

    def step(self, action: Any) -> Tuple["Observation", float, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Args:
            action: An action name, or an ``Action`` model (its
                ``action_name`` is used).

        Returns:
            ``(observation, reward, done, info)``. For an active episode,
            ``info`` holds the step number, incident type, fix status and
            running total reward. It also holds ``"error"`` if the action was
            rejected and ``"resolution"`` if the episode ended on this step.
            Stepping an already-finished episode returns reward 0.0 and an
            ``"error"`` entry only.

        Raises:
            RuntimeError: if called before the first reset().
        """
        incident = self._require_incident()
        if isinstance(action, Action):
            action = action.action_name

        # A finished episode is not advanced any further.
        if self.done:
            return self.state(), 0.0, True, {"error": "episode is already finished "}

        # Every call on an active episode uses up one step, including invalid
        # actions, so invalid actions count toward max_steps.
        self.step_counter += 1

        if is_valid_action(action):
            # The reward is computed before fix_applied is updated, so it sees
            # the fix status from before this action.
            reward = calculate_reward(
                action, incident, self.fix_applied, self.step_counter, self.max_steps
            )
            if action == incident["fix_action"]:
                self.fix_applied = True
            error = None
        else:
            reward = self.INVALID_ACTION_PENALTY
            error = "the action is not valid"

        # Penalties are accumulated too, so total_reward matches the sum of
        # the rewards returned.
        self.total_reward += reward

        info: Dict[str, Any] = {
            "step": self.step_counter,
            "incident_type": incident["type"],
            "fix_applied": self.fix_applied,
            "total_reward": self.total_reward,
        }
        if error is not None:
            info["error"] = error

        # Resolution takes priority: a "resolve" on the final step is reported
        # as a resolve outcome, not as a max-steps failure.
        if error is None and action == "resolve":
            self.done = True
            info["resolution"] = "success" if self.fix_applied else "failure - resolved without fix"
        elif self.step_counter >= self.max_steps:
            self.done = True
            info["resolution"] = "failure - max steps reached"

        return self.state(), reward, self.done, info