Mist-ic commited on
Commit
641125f
·
1 Parent(s): bc82eac

Improve inference script robustness and fallback strategy

Browse files

- Rolling 3-turn conversation history prevents context explosion on hard task
- Groq llama-3.3-70b-versatile as primary (free, 1000 RPD, no CC)
- Groq llama-3.1-8b-instant as tier-1 fallback (same key, 14400 RPD)
- HF Inference Router as tier-2 fallback
- Coerce replicas param to int (models sometimes return strings)
- Guard against empty/non-JSON server responses
- Fix UnicodeEncodeError on Windows (arrow -> ASCII)
- Fix param format: service_id always single string, not list

Files changed (1) hide show
  1. inference.py +148 -65
inference.py CHANGED
@@ -9,18 +9,41 @@ MANDATORY
9
 
10
  - The inference script must be named `inference.py` and placed in the root directory of the project
11
  - Participants must use OpenAI Client for all LLM calls using above variables
 
 
 
 
 
12
  """
13
 
14
  import json
15
  import os
 
16
  import textwrap
17
- from typing import Any, Dict, List, Optional
18
 
19
  from openai import OpenAI
20
 
21
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
22
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
23
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  SYSTEM_PROMPT = textwrap.dedent("""\
26
  You are an expert Site Reliability Engineer (SRE) responding to a production incident.
@@ -31,25 +54,71 @@ SYSTEM_PROMPT = textwrap.dedent("""\
31
  Strategy:
32
  1. First, inspect logs of services showing the highest error rates or critical alerts
33
  2. Diagnose the root cause from log patterns:
34
- - OOMKilled/CrashLoopBackOff restart_service
35
- - NullPointerException/TypeError + recent deploy rollback_service
36
- - "password authentication failed"/"config not found" tune_config with the broken key
37
- - Thread pool exhaustion/timeout from downstream fix the downstream dependency first
38
- - Memory climbing linearly restart_service (resource leak)
39
- - HikariPool exhaustion/slow queries scale_service or restart_service on the DB
40
- - CLUSTERDOWN/cache miss clear_cache
41
- - DNS/network errors → rebalance_traffic (if multi-region)
 
42
  3. Apply the correct remediation action
43
  4. Verify recovery with inspect_logs or inspect_metrics
44
 
45
- Respond with EXACTLY one JSON object:
46
  {"action_type": "...", "params": {...}}
47
 
48
- Available actions: inspect_logs, inspect_metrics, inspect_traces, restart_service,
49
- rollback_service, scale_service, tune_config, clear_cache, rebalance_traffic, pause_job, noop
 
 
 
 
 
 
 
50
  """)
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def build_observation_prompt(obs: Dict[str, Any]) -> str:
54
  """Build a concise prompt from the observation."""
55
  parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
@@ -57,34 +126,37 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
57
  # Alerts (most important)
58
  alerts = obs.get("alerts", [])
59
  if alerts:
60
- alert_lines = []
61
- for a in alerts[:10]:
62
- alert_lines.append(f" [{a['severity'].upper()}] {a['message']}")
63
  parts.append("## Active Alerts\n" + "\n".join(alert_lines))
64
 
65
- # Service states (condensed)
66
  services = obs.get("services", [])
67
  degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
68
  if degraded:
69
- svc_lines = []
70
- for s in degraded:
71
- svc_lines.append(
72
- f" {s['id']} [{s['status']}]: error={s['error_rate']:.1%}, "
73
- f"p99={s['latency_p99_ms']:.0f}ms, cpu={s['cpu_pct']:.0f}%, "
74
- f"mem={s['memory_pct']:.0f}%, pool={s['connection_pool_usage_pct']:.0f}%"
75
- )
76
  parts.append("## Degraded Services\n" + "\n".join(svc_lines))
77
 
78
  # Recent deploys
79
  deploys = obs.get("recent_deploys", [])
80
  if deploys:
81
- dep_lines = [f" {d['service']} → {d['version']} ({d['ticks_ago']} ticks ago)" for d in deploys]
 
 
 
82
  parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
83
 
84
  # Actions taken
85
  actions = obs.get("actions_taken", [])
86
  if actions:
87
- act_lines = [f" tick {a['tick']}: {a['action']}({a.get('target', '')}) → {'OK' if a['success'] else 'FAIL'}" for a in actions[-5:]]
 
 
 
88
  parts.append("## Recent Actions\n" + "\n".join(act_lines))
89
 
90
  # Logs (if available from inspect)
@@ -97,7 +169,10 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
97
  if traces:
98
  error_spans = [s for s in traces.get("spans", []) if s.get("status") == "ERROR"]
99
  if error_spans:
100
- trace_lines = [f" {s['service']}: {s.get('tags', {}).get('error.message', 'ERROR')} ({s['duration_ms']}ms)" for s in error_spans[:5]]
 
 
 
101
  parts.append("## Trace Errors\n" + "\n".join(trace_lines))
102
 
103
  # Legal actions
@@ -111,16 +186,15 @@ def build_observation_prompt(obs: Dict[str, Any]) -> str:
111
 
112
  def parse_action(response_text: str) -> Dict[str, Any]:
113
  """Parse the model's JSON response into an action dict."""
114
- # Try to extract JSON from the response
115
  text = response_text.strip()
116
 
117
- # Handle markdown code blocks
118
  if "```json" in text:
119
  text = text.split("```json")[1].split("```")[0].strip()
120
  elif "```" in text:
121
  text = text.split("```")[1].split("```")[0].strip()
122
 
123
- # Find JSON object
124
  start = text.find("{")
125
  end = text.rfind("}") + 1
126
  if start >= 0 and end > start:
@@ -152,57 +226,61 @@ def run_episode(
152
  resp_data = reset_resp.json()
153
  obs = resp_data.get("observation", resp_data)
154
 
155
- messages: List[Dict[str, Any]] = [
156
- {"role": "system", "content": SYSTEM_PROMPT},
157
- ]
158
-
159
  max_steps = obs.get("max_steps", 10)
160
  total_reward = 0.0
161
-
162
  done = resp_data.get("done", False)
 
 
 
 
 
163
  for step_num in range(max_steps):
164
  if done:
165
  break
166
 
167
  user_msg = build_observation_prompt(obs)
168
- messages.append({"role": "user", "content": user_msg})
169
 
170
- # Call the LLM
171
- try:
172
- completion = client.chat.completions.create(
173
- model=MODEL_NAME,
174
- messages=messages,
175
- temperature=0.2,
176
- max_tokens=200,
177
- )
178
- response_text = completion.choices[0].message.content or ""
179
- except Exception as e:
180
- print(f" LLM error at step {step_num}: {e}")
181
- response_text = '{"action_type": "noop", "params": {}}'
182
 
 
183
  action = parse_action(response_text)
184
- messages.append({"role": "assistant", "content": response_text})
185
 
186
- print(f" Step {step_num}: {action.get('action_type', 'noop')}({action.get('params', {})})")
187
 
188
  # Step the environment
 
 
 
 
 
 
 
 
189
  step_resp = httpx.post(
190
  f"{base}/step",
191
- json={"action": {"action_type": action.get("action_type", "noop"), "params": action.get("params", {})}},
 
 
 
192
  timeout=30.0,
193
  )
194
- resp_data = step_resp.json()
 
 
 
 
195
  obs = resp_data.get("observation", resp_data)
196
  done = resp_data.get("done", False)
197
  reward = obs.get("reward") or resp_data.get("reward") or 0.0
198
  total_reward += reward if reward else 0.0
199
 
200
- # Get final state
201
- state_resp = httpx.get(f"{base}/state", timeout=10.0)
202
- final_state = state_resp.json()
203
-
204
- # Grade
205
- grade_resp = httpx.post(
206
  f"{base}/grader",
207
  json={
208
  "final_slo_score": final_state.get("global_slo_score", 0.0),
@@ -213,8 +291,7 @@ def run_episode(
213
  "termination_reason": final_state.get("termination_reason"),
214
  },
215
  timeout=10.0,
216
- )
217
- grade = grade_resp.json()
218
 
219
  return {
220
  "task_id": task_id,
@@ -222,6 +299,8 @@ def run_episode(
222
  "total_reward": total_reward,
223
  "score": grade.get("score", 0.0),
224
  "slo_recovery": grade.get("slo_recovery", 0.0),
 
 
225
  "steps_taken": final_state.get("step_count", 0),
226
  "termination_reason": final_state.get("termination_reason"),
227
  }
@@ -237,7 +316,8 @@ def main() -> None:
237
  print("=" * 60)
238
  print("SevZero Baseline Inference")
239
  print("=" * 60)
240
- print(f"Model: {MODEL_NAME}")
 
241
  print(f"Environment: {env_url}")
242
  print()
243
 
@@ -246,15 +326,18 @@ def main() -> None:
246
  print(f"--- Task: {task_id} (seed={seed}) ---")
247
  result = run_episode(client, env_url, task_id, seed)
248
  results.append(result)
249
- print(f" Score: {result['score']:.4f} | SLO Recovery: {result['slo_recovery']:.4f} | "
250
- f"Steps: {result['steps_taken']} | Outcome: {result['termination_reason']}")
 
 
 
251
  print()
252
 
253
  print("=" * 60)
254
  print("Summary")
255
  print("=" * 60)
256
  for r in results:
257
- print(f" {r['task_id']:8s} score={r['score']:.4f} slo={r['slo_recovery']:.4f} steps={r['steps_taken']}")
258
  avg_score = sum(r["score"] for r in results) / len(results) if results else 0.0
259
  print(f"\n Average score: {avg_score:.4f}")
260
 
 
9
 
10
  - The inference script must be named `inference.py` and placed in the root directory of the project
11
  - Participants must use OpenAI Client for all LLM calls using above variables
12
+
13
+ Recommended setup (free, no credit card):
14
+ API_BASE_URL=https://api.groq.com/openai/v1
15
+ MODEL_NAME=llama-3.3-70b-versatile
16
+ HF_TOKEN=<your_groq_api_key> # Free at console.groq.com
17
  """
18
 
19
  import json
20
  import os
21
+ import time
22
  import textwrap
23
+ from typing import Any, Dict, List
24
 
25
  from openai import OpenAI
26
 
27
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
28
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
30
+
31
+ # Fallback providers tried in order if the primary hits rate limits or errors.
32
+ # Each uses the same HF_TOKEN env var as the API key — all are OpenAI-compatible.
33
+ _FALLBACK_PROVIDERS = [
34
+ # Tier 1 fallback: same Groq key, lighter model (14,400 RPD free)
35
+ {
36
+ "base_url": "https://api.groq.com/openai/v1",
37
+ "model": "llama-3.1-8b-instant",
38
+ "api_key": API_KEY,
39
+ },
40
+ # Tier 2 fallback: HuggingFace Inference Router
41
+ {
42
+ "base_url": "https://router.huggingface.co/v1",
43
+ "model": "Qwen/Qwen2.5-72B-Instruct",
44
+ "api_key": os.getenv("HF_INFERENCE_TOKEN") or API_KEY,
45
+ },
46
+ ]
47
 
48
  SYSTEM_PROMPT = textwrap.dedent("""\
49
  You are an expert Site Reliability Engineer (SRE) responding to a production incident.
 
54
  Strategy:
55
  1. First, inspect logs of services showing the highest error rates or critical alerts
56
  2. Diagnose the root cause from log patterns:
57
+ - OOMKilled/CrashLoopBackOff -> restart_service
58
+ - NullPointerException/TypeError + recent deploy -> rollback_service
59
+ - "password authentication failed"/"config not found" -> tune_config with the broken key
60
+ (the logs will show: "Configuration diagnostic: key '<KEY>' has invalid value")
61
+ - Thread pool exhaustion/timeout from downstream -> fix the downstream dependency first
62
+ - Memory climbing linearly -> restart_service (resource leak)
63
+ - HikariPool exhaustion/slow queries -> scale_service or restart_service on the DB
64
+ - CLUSTERDOWN/cache miss -> clear_cache
65
+ - DNS/network errors -> rebalance_traffic (if multi-region)
66
  3. Apply the correct remediation action
67
  4. Verify recovery with inspect_logs or inspect_metrics
68
 
69
+ Respond with EXACTLY one JSON object — no explanation, no markdown, just raw JSON:
70
  {"action_type": "...", "params": {...}}
71
 
72
+ Param rules (STRICT single service only, never a list):
73
+ - inspect_logs / inspect_metrics / inspect_traces / restart_service / rollback_service / scale_service:
74
+ {"action_type": "X", "params": {"service_id": "order-service"}}
75
+ - tune_config:
76
+ {"action_type": "tune_config", "params": {"service_id": "order-service", "key": "api_endpoint", "value": "correct"}}
77
+ - clear_cache:
78
+ {"action_type": "clear_cache", "params": {"cache_name": "redis-cache"}}
79
+ - noop:
80
+ {"action_type": "noop", "params": {}}
81
  """)
82
 
83
 
84
+ def _call_llm(
85
+ messages: List[Dict[str, Any]],
86
+ primary_client: OpenAI,
87
+ primary_model: str,
88
+ ) -> str:
89
+ """Call LLM with automatic fallback on rate limit or error."""
90
+ providers = [{"client": primary_client, "model": primary_model}] + [
91
+ {
92
+ "client": OpenAI(base_url=p["base_url"], api_key=p["api_key"]),
93
+ "model": p["model"],
94
+ }
95
+ for p in _FALLBACK_PROVIDERS
96
+ ]
97
+
98
+ last_err = None
99
+ for i, provider in enumerate(providers):
100
+ try:
101
+ completion = provider["client"].chat.completions.create(
102
+ model=provider["model"],
103
+ messages=messages,
104
+ temperature=0.2,
105
+ max_tokens=200,
106
+ )
107
+ return completion.choices[0].message.content or ""
108
+ except Exception as e:
109
+ last_err = e
110
+ is_rate_limit = any(x in str(e).lower() for x in ("429", "rate_limit", "quota", "credits", "402"))
111
+ label = "fallback" if i > 0 else "primary"
112
+ print(f" [{label} {provider['model']}] error: {e}")
113
+ if is_rate_limit and i < len(providers) - 1:
114
+ time.sleep(3)
115
+ continue
116
+ if i < len(providers) - 1:
117
+ continue
118
+ print(f" All providers failed. Last error: {last_err}")
119
+ return '{"action_type": "noop", "params": {}}'
120
+
121
+
122
  def build_observation_prompt(obs: Dict[str, Any]) -> str:
123
  """Build a concise prompt from the observation."""
124
  parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
 
126
  # Alerts (most important)
127
  alerts = obs.get("alerts", [])
128
  if alerts:
129
+ alert_lines = [f" [{a['severity'].upper()}] {a['message']}" for a in alerts[:10]]
 
 
130
  parts.append("## Active Alerts\n" + "\n".join(alert_lines))
131
 
132
+ # Service states (condensed — degraded only)
133
  services = obs.get("services", [])
134
  degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
135
  if degraded:
136
+ svc_lines = [
137
+ f" {s['id']} [{s['status']}]: error={s['error_rate']:.1%}, "
138
+ f"p99={s['latency_p99_ms']:.0f}ms, cpu={s['cpu_pct']:.0f}%, "
139
+ f"mem={s['memory_pct']:.0f}%, pool={s['connection_pool_usage_pct']:.0f}%"
140
+ for s in degraded
141
+ ]
 
142
  parts.append("## Degraded Services\n" + "\n".join(svc_lines))
143
 
144
  # Recent deploys
145
  deploys = obs.get("recent_deploys", [])
146
  if deploys:
147
+ dep_lines = [
148
+ f" {d['service']} -> {d['version']} ({d['ticks_ago']} ticks ago)"
149
+ for d in deploys
150
+ ]
151
  parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
152
 
153
  # Actions taken
154
  actions = obs.get("actions_taken", [])
155
  if actions:
156
+ act_lines = [
157
+ f" tick {a['tick']}: {a['action']}({a.get('target', '')}) -> {'OK' if a['success'] else 'FAIL'}"
158
+ for a in actions[-5:]
159
+ ]
160
  parts.append("## Recent Actions\n" + "\n".join(act_lines))
161
 
162
  # Logs (if available from inspect)
 
169
  if traces:
170
  error_spans = [s for s in traces.get("spans", []) if s.get("status") == "ERROR"]
171
  if error_spans:
172
+ trace_lines = [
173
+ f" {s['service']}: {s.get('tags', {}).get('error.message', 'ERROR')} ({s['duration_ms']}ms)"
174
+ for s in error_spans[:5]
175
+ ]
176
  parts.append("## Trace Errors\n" + "\n".join(trace_lines))
177
 
178
  # Legal actions
 
186
 
187
  def parse_action(response_text: str) -> Dict[str, Any]:
188
  """Parse the model's JSON response into an action dict."""
 
189
  text = response_text.strip()
190
 
191
+ # Strip markdown code blocks
192
  if "```json" in text:
193
  text = text.split("```json")[1].split("```")[0].strip()
194
  elif "```" in text:
195
  text = text.split("```")[1].split("```")[0].strip()
196
 
197
+ # Extract JSON object
198
  start = text.find("{")
199
  end = text.rfind("}") + 1
200
  if start >= 0 and end > start:
 
226
  resp_data = reset_resp.json()
227
  obs = resp_data.get("observation", resp_data)
228
 
 
 
 
 
229
  max_steps = obs.get("max_steps", 10)
230
  total_reward = 0.0
 
231
  done = resp_data.get("done", False)
232
+
233
+ # Rolling conversation: system prompt + last 6 messages (3 turns).
234
+ # Prevents context explosion on hard tasks (50 steps x ~800 tokens/step).
235
+ conversation_history: List[Dict[str, Any]] = []
236
+
237
  for step_num in range(max_steps):
238
  if done:
239
  break
240
 
241
  user_msg = build_observation_prompt(obs)
242
+ conversation_history.append({"role": "user", "content": user_msg})
243
 
244
+ # Keep only last 6 messages (3 user+assistant turns) to bound context size
245
+ trimmed = conversation_history[-6:]
246
+ messages_to_send = [{"role": "system", "content": SYSTEM_PROMPT}] + trimmed
 
 
 
 
 
 
 
 
 
247
 
248
+ response_text = _call_llm(messages_to_send, client, MODEL_NAME)
249
  action = parse_action(response_text)
250
+ conversation_history.append({"role": "assistant", "content": response_text})
251
 
252
+ print(f" Step {step_num + 1}: {action.get('action_type', 'noop')}({action.get('params', {})})")
253
 
254
  # Step the environment
255
+ params = action.get("params", {})
256
+ # Coerce replicas to int if model sends a string
257
+ if "replicas" in params:
258
+ try:
259
+ params["replicas"] = int(params["replicas"])
260
+ except (ValueError, TypeError):
261
+ params["replicas"] = 2
262
+
263
  step_resp = httpx.post(
264
  f"{base}/step",
265
+ json={"action": {
266
+ "action_type": action.get("action_type", "noop"),
267
+ "params": params,
268
+ }},
269
  timeout=30.0,
270
  )
271
+ try:
272
+ resp_data = step_resp.json()
273
+ except Exception:
274
+ # Empty or non-JSON response (server error) — treat as noop
275
+ resp_data = {}
276
  obs = resp_data.get("observation", resp_data)
277
  done = resp_data.get("done", False)
278
  reward = obs.get("reward") or resp_data.get("reward") or 0.0
279
  total_reward += reward if reward else 0.0
280
 
281
+ # Final state + grade
282
+ final_state = httpx.get(f"{base}/state", timeout=10.0).json()
283
+ grade = httpx.post(
 
 
 
284
  f"{base}/grader",
285
  json={
286
  "final_slo_score": final_state.get("global_slo_score", 0.0),
 
291
  "termination_reason": final_state.get("termination_reason"),
292
  },
293
  timeout=10.0,
294
+ ).json()
 
295
 
296
  return {
297
  "task_id": task_id,
 
299
  "total_reward": total_reward,
300
  "score": grade.get("score", 0.0),
301
  "slo_recovery": grade.get("slo_recovery", 0.0),
302
+ "action_efficiency": grade.get("action_efficiency", 0.0),
303
+ "time_efficiency": grade.get("time_efficiency", 0.0),
304
  "steps_taken": final_state.get("step_count", 0),
305
  "termination_reason": final_state.get("termination_reason"),
306
  }
 
316
  print("=" * 60)
317
  print("SevZero Baseline Inference")
318
  print("=" * 60)
319
+ print(f"Model: {MODEL_NAME}")
320
+ print(f"API: {API_BASE_URL}")
321
  print(f"Environment: {env_url}")
322
  print()
323
 
 
326
  print(f"--- Task: {task_id} (seed={seed}) ---")
327
  result = run_episode(client, env_url, task_id, seed)
328
  results.append(result)
329
+ print(
330
+ f" Score: {result['score']:.4f} | SLO: {result['slo_recovery']:.4f} | "
331
+ f"AE: {result['action_efficiency']:.4f} | TE: {result['time_efficiency']:.4f} | "
332
+ f"Steps: {result['steps_taken']} | Outcome: {result['termination_reason']}"
333
+ )
334
  print()
335
 
336
  print("=" * 60)
337
  print("Summary")
338
  print("=" * 60)
339
  for r in results:
340
+ print(f" {r['task_id']:8s} score={r['score']:.4f} slo={r['slo_recovery']:.4f} steps={r['steps_taken']}")
341
  avg_score = sum(r["score"] for r in results) / len(results) if results else 0.0
342
  print(f"\n Average score: {avg_score:.4f}")
343