Spaces:
Paused
Paused
Update train_model.py
Browse files- train_model.py +71 -3
train_model.py
CHANGED
|
@@ -100,6 +100,42 @@ def env_state():
|
|
| 100 |
r.raise_for_status()
|
| 101 |
return r.json()
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
health = requests.get(f'{BASE_URL}/health', timeout=15).json()
|
| 104 |
print(f'✅ Environment: {health}')
|
| 105 |
test_obs = env_reset('easy', seed=0)
|
|
@@ -111,9 +147,40 @@ print('Config loaded:')
|
|
| 111 |
for k, v in CONFIG.items():
|
| 112 |
print(f' {k}: {v}')
|
| 113 |
|
| 114 |
-
SYSTEM_PROMPT = """
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def observation_to_prompt(obs, task_id):
|
| 119 |
return (
|
|
@@ -232,6 +299,7 @@ def run_episode(task_id, seed=None, verbose=False):
|
|
| 232 |
if done: break
|
| 233 |
action, _ = generate_action(obs, task_id)
|
| 234 |
if verbose: print(f' Step {step+1}: {action}')
|
|
|
|
| 235 |
result = env_step(action)
|
| 236 |
total_reward += result.get('reward', 0.0)
|
| 237 |
obs = result.get('observation', obs)
|
|
|
|
| 100 |
r.raise_for_status()
|
| 101 |
return r.json()
|
| 102 |
|
| 103 |
+
VALID_ACTIONS = {
|
| 104 |
+
"diagnose", "read_logs", "read_metrics", "read_runbook",
|
| 105 |
+
"search_logs", "restart_service", "rollback", "scale_up",
|
| 106 |
+
"alert_oncall", "acknowledge", "noop", "block_ip_range",
|
| 107 |
+
"create_index", "failover"
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def sanitize_action(action):
|
| 111 |
+
if not isinstance(action, dict):
|
| 112 |
+
return {"action_type": "noop"}
|
| 113 |
+
|
| 114 |
+
action_type = action.get("action_type", "").lower()
|
| 115 |
+
|
| 116 |
+
# Fix common mistakes
|
| 117 |
+
if action_type == "read_service_logs":
|
| 118 |
+
action_type = "read_logs"
|
| 119 |
+
|
| 120 |
+
if action_type not in VALID_ACTIONS:
|
| 121 |
+
return {"action_type": "noop"}
|
| 122 |
+
|
| 123 |
+
# Fix parameter names
|
| 124 |
+
service = action.get("service") or action.get("service_name")
|
| 125 |
+
|
| 126 |
+
clean = {"action_type": action_type}
|
| 127 |
+
|
| 128 |
+
if service:
|
| 129 |
+
clean["service"] = service
|
| 130 |
+
|
| 131 |
+
# add other fields safely
|
| 132 |
+
for key in ["root_cause", "runbook", "version", "reason",
|
| 133 |
+
"query", "ip_range", "table", "column", "target_region"]:
|
| 134 |
+
if key in action:
|
| 135 |
+
clean[key] = action[key]
|
| 136 |
+
|
| 137 |
+
return clean
|
| 138 |
+
|
| 139 |
health = requests.get(f'{BASE_URL}/health', timeout=15).json()
|
| 140 |
print(f'✅ Environment: {health}')
|
| 141 |
test_obs = env_reset('easy', seed=0)
|
|
|
|
| 147 |
for k, v in CONFIG.items():
|
| 148 |
print(f' {k}: {v}')
|
| 149 |
|
| 150 |
+
SYSTEM_PROMPT = """
|
| 151 |
+
You are an autonomous DevOps agent.
|
| 152 |
+
|
| 153 |
+
You MUST return ONLY valid JSON.
|
| 154 |
+
|
| 155 |
+
STRICT RULES:
|
| 156 |
+
- action_type MUST be one of:
|
| 157 |
+
diagnose, read_logs, read_metrics, read_runbook, search_logs,
|
| 158 |
+
restart_service, rollback, scale_up, alert_oncall, acknowledge,
|
| 159 |
+
noop, block_ip_range, create_index, failover
|
| 160 |
+
|
| 161 |
+
- Use EXACT parameter names:
|
| 162 |
+
service (NOT service_name)
|
| 163 |
+
root_cause
|
| 164 |
+
runbook
|
| 165 |
+
version
|
| 166 |
+
reason
|
| 167 |
+
query
|
| 168 |
+
ip_range
|
| 169 |
+
table
|
| 170 |
+
column
|
| 171 |
+
target_region
|
| 172 |
+
|
| 173 |
+
- DO NOT invent new fields
|
| 174 |
+
- DO NOT change names
|
| 175 |
+
- DO NOT use service_name
|
| 176 |
+
- Always output valid JSON only
|
| 177 |
+
|
| 178 |
+
Example:
|
| 179 |
+
{
|
| 180 |
+
"action_type": "read_logs",
|
| 181 |
+
"service": "order-service"
|
| 182 |
+
}
|
| 183 |
+
"""
|
| 184 |
|
| 185 |
def observation_to_prompt(obs, task_id):
|
| 186 |
return (
|
|
|
|
| 299 |
if done: break
|
| 300 |
action, _ = generate_action(obs, task_id)
|
| 301 |
if verbose: print(f' Step {step+1}: {action}')
|
| 302 |
+
action = sanitize_action(action)
|
| 303 |
result = env_step(action)
|
| 304 |
total_reward += result.get('reward', 0.0)
|
| 305 |
obs = result.get('observation', obs)
|