File size: 4,693 Bytes
d416acc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
""" 
Implementation of the three main methods ->
1. reset() - called at the start of each episode
2. step(action) - the agent takes an action
3. state() - can be called at any time
"""

"""
1. Incident selection - curriculum learning approach (easy -> medium -> hard)
2. Reward factors - 5 factors (correct, wrong, resolve with/without fix, max steps)
3. Episode end conditions - resolved with fix, resolved without fix, max steps reached
4. Action space - 8 actions (including diagnostic, fix, terminal)
5. Max steps - 10 steps per episode
6. Reward - range is -20 to +20
"""

"""
extra info ->
1. stage 1 episodes -> 1-10
2. stage 2 episodes -> 11-25
3. stage 3 episodes -> 26+
"""

"""
Our 3 models ->
1. Observation - what the agent sees at each step
2. Action - what the agent can do at each step
3. EnvState - internal tracking of the environment
"""

import random
from typing import Dict, Any, Tuple, Optional, List
from pydantic import BaseModel 

from environment.incident_generator import get_random_incident, get_incident_by_type
from environment.action_space import is_valid_action
from environment.reward import calculate_reward

# Snapshot of the environment that the agent sees; APITriageEnv.state()
# builds one of these from the environment's current attributes.
class Observation(BaseModel): 
  # Number of steps taken so far in the current episode (env step_counter).
  step: int
  # Step budget of the episode (env max_steps).
  max_steps: int
  # Taken from the incident dict's "summary" entry.
  incident_summary : str
  # Taken from the incident dict's "logs" entry.
  logs: List[str]
  # Taken from the incident dict's "code" entry.
  response_code: int
  # True once the action equal to the incident's "fix_action" has been taken.
  fix_applied: bool
  # Filled from the environment's `done` flag, which is also set when the
  # step budget runs out. NOTE(review): so this is True for any finished
  # episode, not only successfully resolved ones; confirm that is intended.
  is_resolved: bool
    
# Model describing what the agent does at a step.
# NOTE(review): not constructed in the visible part of this file;
# APITriageEnv.step compares its `action` argument directly with strings.
class Action(BaseModel):
  # Name of the chosen action; presumably one of the names accepted by
  # environment.action_space.is_valid_action (confirm against that module).
  action_name: str

# Model for the environment's internal bookkeeping.
# NOTE(review): not instantiated in the visible part of this file; the field
# names mirror attributes kept on APITriageEnv, so the meanings below are
# presumed from those attributes and should be confirmed against callers.
class EnvState(BaseModel):
  # Presumably the incident dict currently being triaged.
  current_incident: Dict[str, Any]
  # Presumably the number of steps taken in the current episode.
  step_counter: int
  # Presumably whether the incident's fix action has been taken.
  fix_applied: bool
  # Presumably the reward accumulated over the current episode.
  total_reward: float
  # Presumably whether the episode has ended.
  is_resolved: bool  

class APITriageEnv:
  """Episodic API-incident triage environment with a curriculum schedule.

  Call reset() to start an episode, then step(action) until it reports
  done. state() returns the current observation at any time after the
  first reset().

  The incident type depends on how many episodes have been started:
  episodes 1-10 draw from auth_error / missing_fields, episodes 11-25 from
  rate_limit / timeout / wrong_endpoint, and every later episode uses
  server_error.

  The incident dict returned by get_incident_by_type is read through the
  keys "summary", "logs", "code", "fix_action" and "type".

  Attributes:
    max_steps: step budget per episode.
    step_counter: steps consumed in the current episode (invalid actions
      count as steps too).
    done: True once the episode has ended, either through "resolve" or by
      exhausting the step budget.
    incident: the current incident dict, or None before the first reset().
    fix_applied: True once the incident's fix action has been taken.
    total_reward: sum of the rewards returned by step() in this episode.
    total_episodes: number of reset() calls so far; drives the curriculum.
  """

  # Reward returned for an action rejected by is_valid_action.
  _INVALID_ACTION_PENALTY = -2.0

  def __init__(self, max_steps: int = 10):
    """Create an environment; reset() must be called before stepping."""
    self.max_steps = max_steps
    self.step_counter = 0
    self.done = False
    self.incident: Optional[Dict[str, Any]] = None
    self.fix_applied = False
    self.total_reward = 0.0
    self.total_episodes = 0

  def reset(self) -> "Observation":
    """Start a new episode and return its first observation.

    Clears the per-episode counters, increments total_episodes and picks a
    new incident according to the curriculum stage.
    """
    self.step_counter = 0
    self.done = False
    self.fix_applied = False
    self.total_reward = 0.0
    self.total_episodes += 1

    # Curriculum: the stage is decided by the number of episodes started.
    if self.total_episodes <= 10:
      # stage 1 -> easy incidents
      incident_type = random.choice(["auth_error", "missing_fields"])
    elif self.total_episodes <= 25:
      # stage 2 -> medium incidents
      incident_type = random.choice(["rate_limit", "timeout", "wrong_endpoint"])
    else:
      # stage 3 -> hard incidents (single type, so no random draw is made)
      incident_type = "server_error"
    self.incident = get_incident_by_type(incident_type)

    return self.state()

  def state(self) -> "Observation":
    """Return what the agent sees at the current step.

    Raises:
      RuntimeError: if no episode has been started with reset() yet.
    """
    incident = self._require_incident()
    # NOTE(review): is_resolved mirrors `done`, so it is also True when the
    # episode ended because the step budget ran out.
    return Observation(
        step=self.step_counter,
        max_steps=self.max_steps,
        incident_summary=incident["summary"],
        logs=incident["logs"],
        response_code=incident["code"],
        fix_applied=self.fix_applied,
        is_resolved=self.done
    )

  def step(self, action) -> "Tuple[Observation, float, bool, Dict[str, Any]]":
    """Apply one agent action and return (observation, reward, done, info).

    Args:
      action: the action identifier. It is passed to is_valid_action and
        compared directly with the incident's "fix_action" and with
        "resolve".

    Returns:
      A 4-tuple of the new observation, the reward for this step, whether
      the episode is finished, and an info dict. The info dict carries
      "step", "incident_type", "fix_applied" and "total_reward"; it also
      carries "error" for a rejected action and "resolution" on the step
      that ends the episode. Stepping a finished episode returns reward
      0.0 and an info dict holding only "error".

    Raises:
      RuntimeError: if no episode has been started with reset() yet.
    """
    # Fail before touching any counters when there is no active incident.
    incident = self._require_incident()

    # 1. the episode already finished: nothing changes, no reward
    if self.done:
      return self.state(), 0.0, True, {"error": "episode is already finished "}

    # 2. every call on a live episode consumes one step, valid or not
    self.step_counter += 1

    # 3. rejected action: penalise it, but still account for it in the
    #    running total and in the step budget, so that a stream of invalid
    #    actions cannot keep the episode alive past max_steps
    if not is_valid_action(action):
      reward = self._INVALID_ACTION_PENALTY
      self.total_reward += reward
      info = self._build_info(incident)
      info["error"] = "the action is not valid"
      if self.step_counter >= self.max_steps:
        self.done = True
        info["resolution"] = "failure - max steps reached"
      return self.state(), reward, self.done, info

    # 4. the reward is computed with the fix status from BEFORE this action
    reward = calculate_reward(action, incident, self.fix_applied, self.step_counter, self.max_steps)

    # 5. remember that the correct fix has been applied
    if action == incident["fix_action"]:
      self.fix_applied = True

    # 6. update the running total
    self.total_reward += reward

    # 7. info common to every outcome
    info = self._build_info(incident)

    # 8. an explicit "resolve" ends the episode and decides the outcome,
    #    even on the last allowed step; the step budget only decides the
    #    outcome when the agent did not resolve
    if action == "resolve":
      self.done = True
      info["resolution"] = "success" if self.fix_applied else "failure - resolved without fix"
    elif self.step_counter >= self.max_steps:
      self.done = True
      info["resolution"] = "failure - max steps reached"

    return self.state(), reward, self.done, info

  def _require_incident(self) -> Dict[str, Any]:
    """Return the active incident, or raise if reset() was never called."""
    if self.incident is None:
      raise RuntimeError("reset() must be called before state() or step()")
    return self.incident

  def _build_info(self, incident: Dict[str, Any]) -> Dict[str, Any]:
    """Build the per-step info dict from the current bookkeeping."""
    return {
      "step": self.step_counter,
      "incident_type": incident["type"],
      "fix_applied": self.fix_applied,
      "total_reward": self.total_reward
    }