# API_DEBUG_SOLVER / environment / api_triage_env.py
# Kavya988's picture
# Upload 29 files
# d416acc verified
"""
Implementation of the three main methods ->
1. reset() - called at the start of each episode
2. step(action) - the agent takes an action
3. state() - can be called at any time
"""
"""
1. Incident selection - curriculum learning approach (easy -> medium -> hard)
2. Reward factors - 5 factors (correct, wrong, resolve with/without fix, max steps)
3. Episode end conditions - resolved with fix, resolved without fix, max steps reached
4. Action space - 8 actions (including diagnostic, fix, terminal)
5. Max steps - 10 steps per episode
6. Reward - range is -20 to +20
"""
"""
Extra info ->
1. Stage 1 episodes -> 1-10
2. Stage 2 episodes -> 11-25
3. Stage 3 episodes -> 26+
"""
"""
Our 3 models ->
1. Observation - what the agent sees at each step
2. Action - what the agent can do at each step
3. EnvState - internal tracking of the environment
"""
import random
from typing import Dict, Any, Tuple, Optional, List
from pydantic import BaseModel
from environment.incident_generator import get_random_incident, get_incident_by_type
from environment.action_space import is_valid_action
from environment.reward import calculate_reward
class Observation(BaseModel):
    # Agent-facing snapshot of an episode; APITriageEnv.state() builds one
    # of these from the environment's current attributes.

    # Steps taken so far in the current episode (the env's step_counter).
    step: int
    # Step budget the environment was configured with.
    max_steps: int
    # Copied from the current incident's "summary" entry.
    incident_summary: str
    # Copied from the current incident's "logs" entry.
    logs: List[str]
    # Copied from the current incident's "code" entry.
    response_code: int
    # True once the incident's fix action has been taken in this episode.
    fix_applied: bool
    # Mirrors the environment's `done` flag, so it is also True when the
    # episode ended for another reason (e.g. the step budget ran out).
    is_resolved: bool
class Action(BaseModel):
    # Wrapper model for a single agent action.
    # NOTE(review): APITriageEnv.step() compares its argument directly with
    # plain strings, so it appears to expect the bare name rather than an
    # Action instance - confirm against callers.

    # Name of the action the agent wants to take.
    action_name: str
class EnvState(BaseModel):
    # Internal bookkeeping record for the environment. The fields parallel
    # attributes kept on APITriageEnv.
    # NOTE(review): nothing in this file constructs an EnvState - confirm
    # where (or whether) it is used.

    # Incident dictionary for the running episode.
    current_incident: Dict[str, Any]
    # Number of steps taken so far.
    step_counter: int
    # Whether the fix action has been applied.
    fix_applied: bool
    # Reward accumulated over the episode.
    total_reward: float
    # Whether the episode has been resolved.
    is_resolved: bool
class APITriageEnv:
    """Episodic environment in which an agent triages a simulated API incident.

    Usage: call ``reset()`` to start an episode, then ``step(action)``
    repeatedly until it reports ``done``. ``state()`` can be called at any
    point after the first ``reset()`` to get the current observation.

    An episode ends when the agent takes the ``"resolve"`` action or when
    ``max_steps`` steps have been taken, whichever comes first.
    """

    def __init__(self, max_steps: int = 10):
        """Create an environment with no active episode.

        Args:
            max_steps: Maximum number of steps allowed per episode.
        """
        self.max_steps = max_steps
        self.step_counter = 0
        self.done = False
        self.incident = None  # set by reset(); None means "no episode yet"
        self.fix_applied = False
        self.total_reward = 0.0
        self.total_episodes = 0  # drives the curriculum stage in reset()

    def reset(self) -> "Observation":
        """Start a new episode and return its first observation.

        Clears the per-episode counters, bumps the episode count and picks
        an incident according to the curriculum stage:

        * episodes 1-10: ``auth_error`` or ``missing_fields``
        * episodes 11-25: ``rate_limit``, ``timeout`` or ``wrong_endpoint``
        * episodes 26+: ``server_error``
        """
        self.step_counter = 0
        self.done = False
        self.fix_applied = False
        self.total_reward = 0.0
        self.total_episodes += 1

        # Curriculum learning: the difficulty grows with the episode count.
        if self.total_episodes <= 10:
            # Stage 1 -> easy incidents.
            incident_type = random.choice(["auth_error", "missing_fields"])
        elif self.total_episodes <= 25:
            # Stage 2 -> medium incidents.
            incident_type = random.choice(
                ["rate_limit", "timeout", "wrong_endpoint"]
            )
        else:
            # Stage 3 -> hard incidents.
            incident_type = "server_error"
        self.incident = get_incident_by_type(incident_type)

        return self.state()

    def state(self) -> "Observation":
        """Return what the agent sees at the current step.

        Raises:
            RuntimeError: If no episode has been started with ``reset()``.
        """
        if self.incident is None:
            # Fail with a clear message instead of "'NoneType' object is
            # not subscriptable" from the lookups below.
            raise RuntimeError("reset() must be called before state()")
        return Observation(
            step=self.step_counter,
            max_steps=self.max_steps,
            incident_summary=self.incident["summary"],
            logs=self.incident["logs"],
            response_code=self.incident["code"],
            fix_applied=self.fix_applied,
            is_resolved=self.done,
        )

    def step(self, action) -> Tuple["Observation", float, bool, Dict[str, Any]]:
        """Apply one agent action and return ``(state, reward, done, info)``.

        Args:
            action: The action to take. It is compared against the
                incident's ``"fix_action"`` entry and against ``"resolve"``.

        Returns:
            A tuple of the new observation, the reward for this step, the
            episode-finished flag and an info dictionary. ``info`` carries
            an ``"error"`` key for rejected calls and a ``"resolution"`` key
            on the step that ends the episode.

        Raises:
            RuntimeError: If no episode has been started with ``reset()``.
        """
        if self.incident is None:
            raise RuntimeError("reset() must be called before step()")

        # 1. The episode is already over: nothing changes, nothing is earned.
        if self.done:
            return self.state(), 0.0, True, {"error": "episode is already finished "}

        # 2. Every attempted action, valid or not, uses up one step.
        self.step_counter += 1

        # 3. Invalid actions cost a fixed penalty. The penalty is part of the
        #    episode total, and the step budget still applies; otherwise an
        #    agent that only sends invalid actions would never finish.
        if not is_valid_action(action):
            reward = -2.0
            self.total_reward += reward
            info = {"error": "the action is not valid"}
            if self.step_counter >= self.max_steps:
                self.done = True
                info["resolution"] = "failure - max steps reached"
            return self.state(), reward, self.done, info

        # 4. Reward is computed with the fix status from *before* this action.
        reward = calculate_reward(
            action,
            self.incident,
            self.fix_applied,
            self.step_counter,
            self.max_steps,
        )

        # 5. Record the fix if this action is the incident's fix action.
        if action == self.incident["fix_action"]:
            self.fix_applied = True

        # 6. Update the running total.
        self.total_reward += reward

        # 7. Info shared by every valid step.
        info = {
            "step": self.step_counter,
            "incident_type": self.incident["type"],
            "fix_applied": self.fix_applied,
            "total_reward": self.total_reward,
        }

        # 8./9. Termination. An explicit resolve takes precedence, so that a
        #       resolve on the very last allowed step is not relabelled as a
        #       max-steps failure.
        if action == "resolve":
            self.done = True
            info["resolution"] = (
                "success" if self.fix_applied else "failure - resolved without fix"
            )
        elif self.step_counter >= self.max_steps:
            self.done = True
            info["resolution"] = "failure - max steps reached"

        # 10. Single exit point for valid actions.
        return self.state(), reward, self.done, info