Improve inference script robustness and fallback strategy
Browse files- Rolling 3-turn conversation history prevents context explosion on hard task
- Groq llama-3.3-70b-versatile as primary (free, 1000 RPD, no CC)
- Groq llama-3.1-8b-instant as tier-1 fallback (same key, 14400 RPD)
- HF Inference Router as tier-2 fallback
- Coerce replicas param to int (models sometimes return strings)
- Guard against empty/non-JSON server responses
- Fix UnicodeEncodeError on Windows (arrow -> ASCII)
- Fix param format: service_id always single string, not list
- inference.py +148 -65
inference.py
CHANGED
|
@@ -9,18 +9,41 @@ MANDATORY
|
|
| 9 |
|
| 10 |
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 11 |
- Participants must use OpenAI Client for all LLM calls using above variables
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import json
|
| 15 |
import os
|
|
|
|
| 16 |
import textwrap
|
| 17 |
-
from typing import Any, Dict, List
|
| 18 |
|
| 19 |
from openai import OpenAI
|
| 20 |
|
| 21 |
-
API_BASE_URL = os.getenv("API_BASE_URL", "https://
|
| 22 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 23 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
SYSTEM_PROMPT = textwrap.dedent("""\
|
| 26 |
You are an expert Site Reliability Engineer (SRE) responding to a production incident.
|
|
@@ -31,25 +54,71 @@ SYSTEM_PROMPT = textwrap.dedent("""\
|
|
| 31 |
Strategy:
|
| 32 |
1. First, inspect logs of services showing the highest error rates or critical alerts
|
| 33 |
2. Diagnose the root cause from log patterns:
|
| 34 |
-
- OOMKilled/CrashLoopBackOff
|
| 35 |
-
- NullPointerException/TypeError + recent deploy
|
| 36 |
-
- "password authentication failed"/"config not found"
|
| 37 |
-
|
| 38 |
-
-
|
| 39 |
-
-
|
| 40 |
-
-
|
| 41 |
-
-
|
|
|
|
| 42 |
3. Apply the correct remediation action
|
| 43 |
4. Verify recovery with inspect_logs or inspect_metrics
|
| 44 |
|
| 45 |
-
Respond with EXACTLY one JSON object:
|
| 46 |
{"action_type": "...", "params": {...}}
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
""")
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
| 54 |
"""Build a concise prompt from the observation."""
|
| 55 |
parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
|
|
@@ -57,34 +126,37 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
|
| 57 |
# Alerts (most important)
|
| 58 |
alerts = obs.get("alerts", [])
|
| 59 |
if alerts:
|
| 60 |
-
alert_lines = []
|
| 61 |
-
for a in alerts[:10]:
|
| 62 |
-
alert_lines.append(f" [{a['severity'].upper()}] {a['message']}")
|
| 63 |
parts.append("## Active Alerts\n" + "\n".join(alert_lines))
|
| 64 |
|
| 65 |
-
# Service states (condensed)
|
| 66 |
services = obs.get("services", [])
|
| 67 |
degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
|
| 68 |
if degraded:
|
| 69 |
-
svc_lines = [
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
)
|
| 76 |
parts.append("## Degraded Services\n" + "\n".join(svc_lines))
|
| 77 |
|
| 78 |
# Recent deploys
|
| 79 |
deploys = obs.get("recent_deploys", [])
|
| 80 |
if deploys:
|
| 81 |
-
dep_lines = [
|
|
|
|
|
|
|
|
|
|
| 82 |
parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
|
| 83 |
|
| 84 |
# Actions taken
|
| 85 |
actions = obs.get("actions_taken", [])
|
| 86 |
if actions:
|
| 87 |
-
act_lines = [
|
|
|
|
|
|
|
|
|
|
| 88 |
parts.append("## Recent Actions\n" + "\n".join(act_lines))
|
| 89 |
|
| 90 |
# Logs (if available from inspect)
|
|
@@ -97,7 +169,10 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
|
| 97 |
if traces:
|
| 98 |
error_spans = [s for s in traces.get("spans", []) if s.get("status") == "ERROR"]
|
| 99 |
if error_spans:
|
| 100 |
-
trace_lines = [
|
|
|
|
|
|
|
|
|
|
| 101 |
parts.append("## Trace Errors\n" + "\n".join(trace_lines))
|
| 102 |
|
| 103 |
# Legal actions
|
|
@@ -111,16 +186,15 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
|
| 111 |
|
| 112 |
def parse_action(response_text: str) -> Dict[str, Any]:
|
| 113 |
"""Parse the model's JSON response into an action dict."""
|
| 114 |
-
# Try to extract JSON from the response
|
| 115 |
text = response_text.strip()
|
| 116 |
|
| 117 |
-
#
|
| 118 |
if "```json" in text:
|
| 119 |
text = text.split("```json")[1].split("```")[0].strip()
|
| 120 |
elif "```" in text:
|
| 121 |
text = text.split("```")[1].split("```")[0].strip()
|
| 122 |
|
| 123 |
-
#
|
| 124 |
start = text.find("{")
|
| 125 |
end = text.rfind("}") + 1
|
| 126 |
if start >= 0 and end > start:
|
|
@@ -152,57 +226,61 @@ def run_episode(
|
|
| 152 |
resp_data = reset_resp.json()
|
| 153 |
obs = resp_data.get("observation", resp_data)
|
| 154 |
|
| 155 |
-
messages: List[Dict[str, Any]] = [
|
| 156 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 157 |
-
]
|
| 158 |
-
|
| 159 |
max_steps = obs.get("max_steps", 10)
|
| 160 |
total_reward = 0.0
|
| 161 |
-
|
| 162 |
done = resp_data.get("done", False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
for step_num in range(max_steps):
|
| 164 |
if done:
|
| 165 |
break
|
| 166 |
|
| 167 |
user_msg = build_observation_prompt(obs)
|
| 168 |
-
|
| 169 |
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
model=MODEL_NAME,
|
| 174 |
-
messages=messages,
|
| 175 |
-
temperature=0.2,
|
| 176 |
-
max_tokens=200,
|
| 177 |
-
)
|
| 178 |
-
response_text = completion.choices[0].message.content or ""
|
| 179 |
-
except Exception as e:
|
| 180 |
-
print(f" LLM error at step {step_num}: {e}")
|
| 181 |
-
response_text = '{"action_type": "noop", "params": {}}'
|
| 182 |
|
|
|
|
| 183 |
action = parse_action(response_text)
|
| 184 |
-
|
| 185 |
|
| 186 |
-
print(f" Step {step_num}: {action.get('action_type', 'noop')}({action.get('params', {})})")
|
| 187 |
|
| 188 |
# Step the environment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
step_resp = httpx.post(
|
| 190 |
f"{base}/step",
|
| 191 |
-
json={"action": {
|
|
|
|
|
|
|
|
|
|
| 192 |
timeout=30.0,
|
| 193 |
)
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
obs = resp_data.get("observation", resp_data)
|
| 196 |
done = resp_data.get("done", False)
|
| 197 |
reward = obs.get("reward") or resp_data.get("reward") or 0.0
|
| 198 |
total_reward += reward if reward else 0.0
|
| 199 |
|
| 200 |
-
#
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
# Grade
|
| 205 |
-
grade_resp = httpx.post(
|
| 206 |
f"{base}/grader",
|
| 207 |
json={
|
| 208 |
"final_slo_score": final_state.get("global_slo_score", 0.0),
|
|
@@ -213,8 +291,7 @@ def run_episode(
|
|
| 213 |
"termination_reason": final_state.get("termination_reason"),
|
| 214 |
},
|
| 215 |
timeout=10.0,
|
| 216 |
-
)
|
| 217 |
-
grade = grade_resp.json()
|
| 218 |
|
| 219 |
return {
|
| 220 |
"task_id": task_id,
|
|
@@ -222,6 +299,8 @@ def run_episode(
|
|
| 222 |
"total_reward": total_reward,
|
| 223 |
"score": grade.get("score", 0.0),
|
| 224 |
"slo_recovery": grade.get("slo_recovery", 0.0),
|
|
|
|
|
|
|
| 225 |
"steps_taken": final_state.get("step_count", 0),
|
| 226 |
"termination_reason": final_state.get("termination_reason"),
|
| 227 |
}
|
|
@@ -237,7 +316,8 @@ def main() -> None:
|
|
| 237 |
print("=" * 60)
|
| 238 |
print("SevZero Baseline Inference")
|
| 239 |
print("=" * 60)
|
| 240 |
-
print(f"Model:
|
|
|
|
| 241 |
print(f"Environment: {env_url}")
|
| 242 |
print()
|
| 243 |
|
|
@@ -246,15 +326,18 @@ def main() -> None:
|
|
| 246 |
print(f"--- Task: {task_id} (seed={seed}) ---")
|
| 247 |
result = run_episode(client, env_url, task_id, seed)
|
| 248 |
results.append(result)
|
| 249 |
-
print(
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
| 251 |
print()
|
| 252 |
|
| 253 |
print("=" * 60)
|
| 254 |
print("Summary")
|
| 255 |
print("=" * 60)
|
| 256 |
for r in results:
|
| 257 |
-
print(f" {r['task_id']:8s}
|
| 258 |
avg_score = sum(r["score"] for r in results) / len(results) if results else 0.0
|
| 259 |
print(f"\n Average score: {avg_score:.4f}")
|
| 260 |
|
|
|
|
| 9 |
|
| 10 |
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 11 |
- Participants must use OpenAI Client for all LLM calls using above variables
|
| 12 |
+
|
| 13 |
+
Recommended setup (free, no credit card):
|
| 14 |
+
API_BASE_URL=https://api.groq.com/openai/v1
|
| 15 |
+
MODEL_NAME=llama-3.3-70b-versatile
|
| 16 |
+
HF_TOKEN=<your_groq_api_key> # Free at console.groq.com
|
| 17 |
"""
|
| 18 |
|
| 19 |
import json
|
| 20 |
import os
|
| 21 |
+
import time
|
| 22 |
import textwrap
|
| 23 |
+
from typing import Any, Dict, List
|
| 24 |
|
| 25 |
from openai import OpenAI
|
| 26 |
|
| 27 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 28 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 29 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
|
| 30 |
+
|
| 31 |
+
# Fallback providers tried in order if the primary hits rate limits or errors.
|
| 32 |
+
# Each uses the same HF_TOKEN env var as the API key — all are OpenAI-compatible.
|
| 33 |
+
_FALLBACK_PROVIDERS = [
|
| 34 |
+
# Tier 1 fallback: same Groq key, lighter model (14,400 RPD free)
|
| 35 |
+
{
|
| 36 |
+
"base_url": "https://api.groq.com/openai/v1",
|
| 37 |
+
"model": "llama-3.1-8b-instant",
|
| 38 |
+
"api_key": API_KEY,
|
| 39 |
+
},
|
| 40 |
+
# Tier 2 fallback: HuggingFace Inference Router
|
| 41 |
+
{
|
| 42 |
+
"base_url": "https://router.huggingface.co/v1",
|
| 43 |
+
"model": "Qwen/Qwen2.5-72B-Instruct",
|
| 44 |
+
"api_key": os.getenv("HF_INFERENCE_TOKEN") or API_KEY,
|
| 45 |
+
},
|
| 46 |
+
]
|
| 47 |
|
| 48 |
SYSTEM_PROMPT = textwrap.dedent("""\
|
| 49 |
You are an expert Site Reliability Engineer (SRE) responding to a production incident.
|
|
|
|
| 54 |
Strategy:
|
| 55 |
1. First, inspect logs of services showing the highest error rates or critical alerts
|
| 56 |
2. Diagnose the root cause from log patterns:
|
| 57 |
+
- OOMKilled/CrashLoopBackOff -> restart_service
|
| 58 |
+
- NullPointerException/TypeError + recent deploy -> rollback_service
|
| 59 |
+
- "password authentication failed"/"config not found" -> tune_config with the broken key
|
| 60 |
+
(the logs will show: "Configuration diagnostic: key '<KEY>' has invalid value")
|
| 61 |
+
- Thread pool exhaustion/timeout from downstream -> fix the downstream dependency first
|
| 62 |
+
- Memory climbing linearly -> restart_service (resource leak)
|
| 63 |
+
- HikariPool exhaustion/slow queries -> scale_service or restart_service on the DB
|
| 64 |
+
- CLUSTERDOWN/cache miss -> clear_cache
|
| 65 |
+
- DNS/network errors -> rebalance_traffic (if multi-region)
|
| 66 |
3. Apply the correct remediation action
|
| 67 |
4. Verify recovery with inspect_logs or inspect_metrics
|
| 68 |
|
| 69 |
+
Respond with EXACTLY one JSON object — no explanation, no markdown, just raw JSON:
|
| 70 |
{"action_type": "...", "params": {...}}
|
| 71 |
|
| 72 |
+
Param rules (STRICT — single service only, never a list):
|
| 73 |
+
- inspect_logs / inspect_metrics / inspect_traces / restart_service / rollback_service / scale_service:
|
| 74 |
+
{"action_type": "X", "params": {"service_id": "order-service"}}
|
| 75 |
+
- tune_config:
|
| 76 |
+
{"action_type": "tune_config", "params": {"service_id": "order-service", "key": "api_endpoint", "value": "correct"}}
|
| 77 |
+
- clear_cache:
|
| 78 |
+
{"action_type": "clear_cache", "params": {"cache_name": "redis-cache"}}
|
| 79 |
+
- noop:
|
| 80 |
+
{"action_type": "noop", "params": {}}
|
| 81 |
""")
|
| 82 |
|
| 83 |
|
| 84 |
+
def _call_llm(
|
| 85 |
+
messages: List[Dict[str, Any]],
|
| 86 |
+
primary_client: OpenAI,
|
| 87 |
+
primary_model: str,
|
| 88 |
+
) -> str:
|
| 89 |
+
"""Call LLM with automatic fallback on rate limit or error."""
|
| 90 |
+
providers = [{"client": primary_client, "model": primary_model}] + [
|
| 91 |
+
{
|
| 92 |
+
"client": OpenAI(base_url=p["base_url"], api_key=p["api_key"]),
|
| 93 |
+
"model": p["model"],
|
| 94 |
+
}
|
| 95 |
+
for p in _FALLBACK_PROVIDERS
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
last_err = None
|
| 99 |
+
for i, provider in enumerate(providers):
|
| 100 |
+
try:
|
| 101 |
+
completion = provider["client"].chat.completions.create(
|
| 102 |
+
model=provider["model"],
|
| 103 |
+
messages=messages,
|
| 104 |
+
temperature=0.2,
|
| 105 |
+
max_tokens=200,
|
| 106 |
+
)
|
| 107 |
+
return completion.choices[0].message.content or ""
|
| 108 |
+
except Exception as e:
|
| 109 |
+
last_err = e
|
| 110 |
+
is_rate_limit = any(x in str(e).lower() for x in ("429", "rate_limit", "quota", "credits", "402"))
|
| 111 |
+
label = "fallback" if i > 0 else "primary"
|
| 112 |
+
print(f" [{label} {provider['model']}] error: {e}")
|
| 113 |
+
if is_rate_limit and i < len(providers) - 1:
|
| 114 |
+
time.sleep(3)
|
| 115 |
+
continue
|
| 116 |
+
if i < len(providers) - 1:
|
| 117 |
+
continue
|
| 118 |
+
print(f" All providers failed. Last error: {last_err}")
|
| 119 |
+
return '{"action_type": "noop", "params": {}}'
|
| 120 |
+
|
| 121 |
+
|
| 122 |
def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
| 123 |
"""Build a concise prompt from the observation."""
|
| 124 |
parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
|
|
|
|
| 126 |
# Alerts (most important)
|
| 127 |
alerts = obs.get("alerts", [])
|
| 128 |
if alerts:
|
| 129 |
+
alert_lines = [f" [{a['severity'].upper()}] {a['message']}" for a in alerts[:10]]
|
|
|
|
|
|
|
| 130 |
parts.append("## Active Alerts\n" + "\n".join(alert_lines))
|
| 131 |
|
| 132 |
+
# Service states (condensed — degraded only)
|
| 133 |
services = obs.get("services", [])
|
| 134 |
degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
|
| 135 |
if degraded:
|
| 136 |
+
svc_lines = [
|
| 137 |
+
f" {s['id']} [{s['status']}]: error={s['error_rate']:.1%}, "
|
| 138 |
+
f"p99={s['latency_p99_ms']:.0f}ms, cpu={s['cpu_pct']:.0f}%, "
|
| 139 |
+
f"mem={s['memory_pct']:.0f}%, pool={s['connection_pool_usage_pct']:.0f}%"
|
| 140 |
+
for s in degraded
|
| 141 |
+
]
|
|
|
|
| 142 |
parts.append("## Degraded Services\n" + "\n".join(svc_lines))
|
| 143 |
|
| 144 |
# Recent deploys
|
| 145 |
deploys = obs.get("recent_deploys", [])
|
| 146 |
if deploys:
|
| 147 |
+
dep_lines = [
|
| 148 |
+
f" {d['service']} -> {d['version']} ({d['ticks_ago']} ticks ago)"
|
| 149 |
+
for d in deploys
|
| 150 |
+
]
|
| 151 |
parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
|
| 152 |
|
| 153 |
# Actions taken
|
| 154 |
actions = obs.get("actions_taken", [])
|
| 155 |
if actions:
|
| 156 |
+
act_lines = [
|
| 157 |
+
f" tick {a['tick']}: {a['action']}({a.get('target', '')}) -> {'OK' if a['success'] else 'FAIL'}"
|
| 158 |
+
for a in actions[-5:]
|
| 159 |
+
]
|
| 160 |
parts.append("## Recent Actions\n" + "\n".join(act_lines))
|
| 161 |
|
| 162 |
# Logs (if available from inspect)
|
|
|
|
| 169 |
if traces:
|
| 170 |
error_spans = [s for s in traces.get("spans", []) if s.get("status") == "ERROR"]
|
| 171 |
if error_spans:
|
| 172 |
+
trace_lines = [
|
| 173 |
+
f" {s['service']}: {s.get('tags', {}).get('error.message', 'ERROR')} ({s['duration_ms']}ms)"
|
| 174 |
+
for s in error_spans[:5]
|
| 175 |
+
]
|
| 176 |
parts.append("## Trace Errors\n" + "\n".join(trace_lines))
|
| 177 |
|
| 178 |
# Legal actions
|
|
|
|
| 186 |
|
| 187 |
def parse_action(response_text: str) -> Dict[str, Any]:
|
| 188 |
"""Parse the model's JSON response into an action dict."""
|
|
|
|
| 189 |
text = response_text.strip()
|
| 190 |
|
| 191 |
+
# Strip markdown code blocks
|
| 192 |
if "```json" in text:
|
| 193 |
text = text.split("```json")[1].split("```")[0].strip()
|
| 194 |
elif "```" in text:
|
| 195 |
text = text.split("```")[1].split("```")[0].strip()
|
| 196 |
|
| 197 |
+
# Extract JSON object
|
| 198 |
start = text.find("{")
|
| 199 |
end = text.rfind("}") + 1
|
| 200 |
if start >= 0 and end > start:
|
|
|
|
| 226 |
resp_data = reset_resp.json()
|
| 227 |
obs = resp_data.get("observation", resp_data)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
max_steps = obs.get("max_steps", 10)
|
| 230 |
total_reward = 0.0
|
|
|
|
| 231 |
done = resp_data.get("done", False)
|
| 232 |
+
|
| 233 |
+
# Rolling conversation: system prompt + last 6 messages (3 turns).
|
| 234 |
+
# Prevents context explosion on hard tasks (50 steps x ~800 tokens/step).
|
| 235 |
+
conversation_history: List[Dict[str, Any]] = []
|
| 236 |
+
|
| 237 |
for step_num in range(max_steps):
|
| 238 |
if done:
|
| 239 |
break
|
| 240 |
|
| 241 |
user_msg = build_observation_prompt(obs)
|
| 242 |
+
conversation_history.append({"role": "user", "content": user_msg})
|
| 243 |
|
| 244 |
+
# Keep only last 6 messages (3 user+assistant turns) to bound context size
|
| 245 |
+
trimmed = conversation_history[-6:]
|
| 246 |
+
messages_to_send = [{"role": "system", "content": SYSTEM_PROMPT}] + trimmed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
+
response_text = _call_llm(messages_to_send, client, MODEL_NAME)
|
| 249 |
action = parse_action(response_text)
|
| 250 |
+
conversation_history.append({"role": "assistant", "content": response_text})
|
| 251 |
|
| 252 |
+
print(f" Step {step_num + 1}: {action.get('action_type', 'noop')}({action.get('params', {})})")
|
| 253 |
|
| 254 |
# Step the environment
|
| 255 |
+
params = action.get("params", {})
|
| 256 |
+
# Coerce replicas to int if model sends a string
|
| 257 |
+
if "replicas" in params:
|
| 258 |
+
try:
|
| 259 |
+
params["replicas"] = int(params["replicas"])
|
| 260 |
+
except (ValueError, TypeError):
|
| 261 |
+
params["replicas"] = 2
|
| 262 |
+
|
| 263 |
step_resp = httpx.post(
|
| 264 |
f"{base}/step",
|
| 265 |
+
json={"action": {
|
| 266 |
+
"action_type": action.get("action_type", "noop"),
|
| 267 |
+
"params": params,
|
| 268 |
+
}},
|
| 269 |
timeout=30.0,
|
| 270 |
)
|
| 271 |
+
try:
|
| 272 |
+
resp_data = step_resp.json()
|
| 273 |
+
except Exception:
|
| 274 |
+
# Empty or non-JSON response (server error) — treat as noop
|
| 275 |
+
resp_data = {}
|
| 276 |
obs = resp_data.get("observation", resp_data)
|
| 277 |
done = resp_data.get("done", False)
|
| 278 |
reward = obs.get("reward") or resp_data.get("reward") or 0.0
|
| 279 |
total_reward += reward if reward else 0.0
|
| 280 |
|
| 281 |
+
# Final state + grade
|
| 282 |
+
final_state = httpx.get(f"{base}/state", timeout=10.0).json()
|
| 283 |
+
grade = httpx.post(
|
|
|
|
|
|
|
|
|
|
| 284 |
f"{base}/grader",
|
| 285 |
json={
|
| 286 |
"final_slo_score": final_state.get("global_slo_score", 0.0),
|
|
|
|
| 291 |
"termination_reason": final_state.get("termination_reason"),
|
| 292 |
},
|
| 293 |
timeout=10.0,
|
| 294 |
+
).json()
|
|
|
|
| 295 |
|
| 296 |
return {
|
| 297 |
"task_id": task_id,
|
|
|
|
| 299 |
"total_reward": total_reward,
|
| 300 |
"score": grade.get("score", 0.0),
|
| 301 |
"slo_recovery": grade.get("slo_recovery", 0.0),
|
| 302 |
+
"action_efficiency": grade.get("action_efficiency", 0.0),
|
| 303 |
+
"time_efficiency": grade.get("time_efficiency", 0.0),
|
| 304 |
"steps_taken": final_state.get("step_count", 0),
|
| 305 |
"termination_reason": final_state.get("termination_reason"),
|
| 306 |
}
|
|
|
|
| 316 |
print("=" * 60)
|
| 317 |
print("SevZero Baseline Inference")
|
| 318 |
print("=" * 60)
|
| 319 |
+
print(f"Model: {MODEL_NAME}")
|
| 320 |
+
print(f"API: {API_BASE_URL}")
|
| 321 |
print(f"Environment: {env_url}")
|
| 322 |
print()
|
| 323 |
|
|
|
|
| 326 |
print(f"--- Task: {task_id} (seed={seed}) ---")
|
| 327 |
result = run_episode(client, env_url, task_id, seed)
|
| 328 |
results.append(result)
|
| 329 |
+
print(
|
| 330 |
+
f" Score: {result['score']:.4f} | SLO: {result['slo_recovery']:.4f} | "
|
| 331 |
+
f"AE: {result['action_efficiency']:.4f} | TE: {result['time_efficiency']:.4f} | "
|
| 332 |
+
f"Steps: {result['steps_taken']} | Outcome: {result['termination_reason']}"
|
| 333 |
+
)
|
| 334 |
print()
|
| 335 |
|
| 336 |
print("=" * 60)
|
| 337 |
print("Summary")
|
| 338 |
print("=" * 60)
|
| 339 |
for r in results:
|
| 340 |
+
print(f" {r['task_id']:8s} score={r['score']:.4f} slo={r['slo_recovery']:.4f} steps={r['steps_taken']}")
|
| 341 |
avg_score = sum(r["score"] for r in results) / len(results) if results else 0.0
|
| 342 |
print(f"\n Average score: {avg_score:.4f}")
|
| 343 |
|