Arijit-07 committed on
Commit
ead0ec3
·
verified ·
1 Parent(s): 9a2f12c

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +71 -3
train_model.py CHANGED
@@ -100,6 +100,42 @@ def env_state():
100
  r.raise_for_status()
101
  return r.json()
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  health = requests.get(f'{BASE_URL}/health', timeout=15).json()
104
  print(f'✅ Environment: {health}')
105
  test_obs = env_reset('easy', seed=0)
@@ -111,9 +147,40 @@ print('Config loaded:')
111
  for k, v in CONFIG.items():
112
  print(f' {k}: {v}')
113
 
114
- SYSTEM_PROMPT = """You are an autonomous DevOps incident response agent.
115
- Return exactly one JSON object describing the next action to take.
116
- The JSON object must include an "action_type" field."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  def observation_to_prompt(obs, task_id):
119
  return (
@@ -232,6 +299,7 @@ def run_episode(task_id, seed=None, verbose=False):
232
  if done: break
233
  action, _ = generate_action(obs, task_id)
234
  if verbose: print(f' Step {step+1}: {action}')
 
235
  result = env_step(action)
236
  total_reward += result.get('reward', 0.0)
237
  obs = result.get('observation', obs)
 
100
  r.raise_for_status()
101
  return r.json()
102
 
103
# Every action_type the environment accepts; anything else is replaced by "noop".
VALID_ACTIONS = {
    "diagnose", "read_logs", "read_metrics", "read_runbook",
    "search_logs", "restart_service", "rollback", "scale_up",
    "alert_oncall", "acknowledge", "noop", "block_ip_range",
    "create_index", "failover",
}

# Common misspellings of an action_type mapped to the accepted name.
_ACTION_ALIASES = {"read_service_logs": "read_logs"}

# Optional fields copied through unchanged when the model supplied them.
_PASSTHROUGH_KEYS = (
    "root_cause", "runbook", "version", "reason",
    "query", "ip_range", "table", "column", "target_region",
)


def sanitize_action(action):
    """Normalize a model-generated action into a dict the environment accepts.

    Args:
        action: The parsed model output. Expected to be a dict with an
            "action_type" key, but any value is tolerated.

    Returns:
        A new dict containing "action_type" (always a member of
        VALID_ACTIONS), "service" when a truthy "service" or "service_name"
        was given, and any of the known optional fields present in the
        input. Unknown fields are dropped. If the input is not a dict, or
        its action_type is missing, not a string, or not recognized, the
        result is ``{"action_type": "noop"}``.
    """
    if not isinstance(action, dict):
        return {"action_type": "noop"}

    raw_type = action.get("action_type")
    # The model may emit null, a number, or a list here; calling .lower() on
    # those would raise AttributeError, so treat them as an unknown action.
    if not isinstance(raw_type, str):
        return {"action_type": "noop"}

    action_type = raw_type.strip().lower()
    # Fix common mistakes in the action name.
    action_type = _ACTION_ALIASES.get(action_type, action_type)

    if action_type not in VALID_ACTIONS:
        return {"action_type": "noop"}

    clean = {"action_type": action_type}

    # Fix parameter names: accept "service_name" as a fallback for "service".
    service = action.get("service") or action.get("service_name")
    if service:
        clean["service"] = service

    # Copy only the known optional fields; anything else is discarded.
    for key in _PASSTHROUGH_KEYS:
        if key in action:
            clean[key] = action[key]

    return clean
138
+
139
  health = requests.get(f'{BASE_URL}/health', timeout=15).json()
140
  print(f'✅ Environment: {health}')
141
  test_obs = env_reset('easy', seed=0)
 
147
  for k, v in CONFIG.items():
148
  print(f' {k}: {v}')
149
 
150
+ SYSTEM_PROMPT = """
151
+ You are an autonomous DevOps agent.
152
+
153
+ You MUST return ONLY valid JSON.
154
+
155
+ STRICT RULES:
156
+ - action_type MUST be one of:
157
+ diagnose, read_logs, read_metrics, read_runbook, search_logs,
158
+ restart_service, rollback, scale_up, alert_oncall, acknowledge,
159
+ noop, block_ip_range, create_index, failover
160
+
161
+ - Use EXACT parameter names:
162
+ service (NOT service_name)
163
+ root_cause
164
+ runbook
165
+ version
166
+ reason
167
+ query
168
+ ip_range
169
+ table
170
+ column
171
+ target_region
172
+
173
+ - DO NOT invent new fields
174
+ - DO NOT change names
175
+ - DO NOT use service_name
176
+ - Always output valid JSON only
177
+
178
+ Example:
179
+ {
180
+ "action_type": "read_logs",
181
+ "service": "order-service"
182
+ }
183
+ """
184
 
185
  def observation_to_prompt(obs, task_id):
186
  return (
 
299
  if done: break
300
  action, _ = generate_action(obs, task_id)
301
  if verbose: print(f' Step {step+1}: {action}')
302
+ action = sanitize_action(action)
303
  result = env_step(action)
304
  total_reward += result.get('reward', 0.0)
305
  obs = result.get('observation', obs)