Mist-ic committed on
Commit
00225fe
·
1 Parent(s): 0e4dd30

Add baseline inference script and Dockerfile

Browse files

- inference.py: LLM-powered baseline agent using OpenAI client with
structured SRE diagnostic prompting, JSON action parsing, and
automatic grading. Reads API_BASE_URL, MODEL_NAME, HF_TOKEN env vars.
- Dockerfile: multi-stage build (builder with uv, slim runtime) on
port 7860 for HF Spaces deployment
- pyproject.toml: added project.scripts entry and httpx dependency

Files changed (3) hide show
  1. Dockerfile +40 -0
  2. inference.py +259 -0
  3. pyproject.toml +4 -0
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Build stage: install the locked dependencies into /app/.venv ---
FROM python:3.11-slim AS builder

WORKDIR /app

# Install uv for fast dependency management
RUN pip install --no-cache-dir uv

# Copy dependency files first for cache efficiency
COPY pyproject.toml uv.lock ./

# Install dependencies only. The application source is not copied into this
# stage, so uv must not try to build/install the project itself here
# (--no-install-project); the runtime stage imports the code via PYTHONPATH.
RUN uv sync --frozen --no-dev --no-install-project

# --- Runtime stage ---
FROM python:3.11-slim

WORKDIR /app

# Copy installed packages from builder
COPY --from=builder /app/.venv /app/.venv

# Copy application code
COPY models.py .
COPY server/ server/
COPY inference.py .
COPY openenv.yaml .
COPY pyproject.toml .
COPY README.md .

# Use the venv
ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app"

# Non-root user
RUN useradd -m appuser
USER appuser

EXPOSE 7860

CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
inference.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script — SevZero Baseline Agent
3
+ ==========================================
4
+ MANDATORY
5
+ - Before submitting, ensure the following variables are defined in your environment configuration:
6
+ API_BASE_URL The API endpoint for the LLM.
7
+ MODEL_NAME The model identifier to use for inference.
8
+ HF_TOKEN Your Hugging Face / API key.
9
+
10
+ - The inference script must be named `inference.py` and placed in the root directory of the project
11
+ - Participants must use OpenAI Client for all LLM calls using above variables
12
+ """
13
+
14
+ import json
15
+ import os
16
+ import textwrap
17
+ from typing import Any, Dict, List, Optional
18
+
19
+ from openai import OpenAI
20
+
21
# Base URL of the OpenAI-compatible endpoint; read from API_BASE_URL with a default.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# API key: HF_TOKEN is preferred; API_KEY is used when HF_TOKEN is unset or empty.
# May be None when neither variable is set.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Model identifier passed to the chat completion call; read from MODEL_NAME with a default.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
24
+
25
# System message placed first in every episode's conversation. It describes the
# SRE role, a log-pattern -> remediation playbook, the required single-JSON-object
# reply format, and the list of action names the model may choose from.
SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert Site Reliability Engineer (SRE) responding to a production incident.
You are managing a microservice cluster experiencing failures.

Your goal: restore all services to healthy SLO compliance as efficiently as possible.

Strategy:
1. First, inspect logs of services showing the highest error rates or critical alerts
2. Diagnose the root cause from log patterns:
- OOMKilled/CrashLoopBackOff → restart_service
- NullPointerException/TypeError + recent deploy → rollback_service
- "password authentication failed"/"config not found" → tune_config with the broken key
- Thread pool exhaustion/timeout from downstream → fix the downstream dependency first
- Memory climbing linearly → restart_service (resource leak)
- HikariPool exhaustion/slow queries → scale_service or restart_service on the DB
- CLUSTERDOWN/cache miss → clear_cache
- DNS/network errors → rebalance_traffic (if multi-region)
3. Apply the correct remediation action
4. Verify recovery with inspect_logs or inspect_metrics

Respond with EXACTLY one JSON object:
{"action_type": "...", "params": {...}}

Available actions: inspect_logs, inspect_metrics, inspect_traces, restart_service,
rollback_service, scale_service, tune_config, clear_cache, rebalance_traffic, pause_job, noop
""")
51
+
52
+
53
def _fmt_metric(value: Any, spec: str, suffix: str = "") -> str:
    """Format a numeric metric with *spec* plus *suffix*; return "n/a" if it is not a number."""
    if not isinstance(value, (int, float)):
        return "n/a"
    return f"{value:{spec}}{suffix}"


def build_observation_prompt(obs: Dict[str, Any]) -> str:
    """Build a concise text prompt from an observation dict.

    The prompt is a sequence of ``## ...`` sections joined by blank lines. The
    "Incident Status" section is always present; every other section (alerts,
    degraded services, recent deploys, recent actions, logs, trace errors,
    available actions) is emitted only when the observation carries data for it.

    Missing or null fields never raise: absent identifiers are rendered as
    ``?`` and absent or non-numeric metrics as ``n/a``, so a partially
    populated observation still produces a usable prompt instead of aborting
    the episode.

    Args:
        obs: Observation dictionary, as returned by the environment.

    Returns:
        The prompt text for the next user message.
    """
    parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]

    # Alerts (most important); capped at the first 10.
    alerts = obs.get("alerts") or []
    if alerts:
        alert_lines = [
            f" [{str(a.get('severity', 'unknown')).upper()}] {a.get('message', '')}"
            for a in alerts[:10]
        ]
        parts.append("## Active Alerts\n" + "\n".join(alert_lines))

    # Service states (condensed): only services that are not healthy.
    services = obs.get("services") or []
    degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
    if degraded:
        svc_lines = [
            f" {s.get('id', '?')} [{s['status']}]: "
            f"error={_fmt_metric(s.get('error_rate'), '.1%')}, "
            f"p99={_fmt_metric(s.get('latency_p99_ms'), '.0f', 'ms')}, "
            f"cpu={_fmt_metric(s.get('cpu_pct'), '.0f', '%')}, "
            f"mem={_fmt_metric(s.get('memory_pct'), '.0f', '%')}, "
            f"pool={_fmt_metric(s.get('connection_pool_usage_pct'), '.0f', '%')}"
            for s in degraded
        ]
        parts.append("## Degraded Services\n" + "\n".join(svc_lines))

    # Recent deploys
    deploys = obs.get("recent_deploys") or []
    if deploys:
        dep_lines = [
            f" {d.get('service', '?')} → {d.get('version', '?')} ({d.get('ticks_ago', '?')} ticks ago)"
            for d in deploys
        ]
        parts.append("## Recent Deploys\n" + "\n".join(dep_lines))

    # Actions taken; only the last 5 are shown.
    actions = obs.get("actions_taken") or []
    if actions:
        act_lines = []
        for a in actions[-5:]:
            # Do not report a missing outcome as a failure.
            if "success" in a:
                outcome = "OK" if a["success"] else "FAIL"
            else:
                outcome = "?"
            act_lines.append(
                f" tick {a.get('tick', '?')}: {a.get('action', '?')}({a.get('target', '')}) → {outcome}"
            )
        parts.append("## Recent Actions\n" + "\n".join(act_lines))

    # Logs (if available from inspect)
    logs = obs.get("logs")
    if logs:
        parts.append(f"## Logs\n{logs}")

    # Traces (if available): only spans whose status is "ERROR", at most 5.
    traces = obs.get("traces")
    if traces:
        error_spans = [s for s in (traces.get("spans") or []) if s.get("status") == "ERROR"]
        if error_spans:
            trace_lines = [
                f" {s.get('service', '?')}: "
                f"{(s.get('tags') or {}).get('error.message', 'ERROR')} "
                f"({s.get('duration_ms', '?')}ms)"
                for s in error_spans[:5]
            ]
            parts.append("## Trace Errors\n" + "\n".join(trace_lines))

    # Legal actions, each with at most 5 example targets.
    legal = obs.get("legal_actions") or []
    if legal:
        legal_strs = [
            f" {la.get('action_type', '?')}: targets={(la.get('valid_targets') or [])[:5]}"
            for la in legal
        ]
        parts.append("## Available Actions\n" + "\n".join(legal_strs))

    return "\n\n".join(parts)
110
+
111
+
112
def parse_action(response_text: str) -> Dict[str, Any]:
    """Parse the model's reply into an action dict.

    The reply may be bare JSON, JSON inside a markdown code fence, or JSON
    surrounded by free text. The text is scanned left to right and the first
    JSON object that decodes cleanly and has a string ``action_type`` is
    returned. Decoding starts at each ``{`` and stops at the end of that
    object, so stray braces before or after the action do not break parsing.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        A dict with at least ``action_type`` (str) and ``params`` (dict).
        ``params`` is replaced by ``{}`` when missing or not a dict. When no
        usable object is found, ``{"action_type": "noop", "params": {}}``.
    """
    text = (response_text or "").strip()

    # Handle markdown code blocks: keep only the content of the first fence.
    if "```json" in text:
        text = text.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in text:
        text = text.split("```", 1)[1].split("```", 1)[0].strip()

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start >= 0:
        try:
            # raw_decode parses exactly one JSON value and ignores what follows it.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict) and isinstance(candidate.get("action_type"), str):
            if not isinstance(candidate.get("params"), dict):
                candidate["params"] = {}
            return candidate
        start = text.find("{", start + 1)

    return {"action_type": "noop", "params": {}}
133
+
134
+
135
def run_episode(
    client: OpenAI,
    env_url: str,
    task_id: str,
    seed: int = 42,
) -> Dict[str, Any]:
    """Run one episode against the environment's HTTP API and grade it.

    Flow: POST ``/reset`` with the seed and task id, then for up to
    ``max_steps`` iterations (taken from the reset observation, default 10)
    build a prompt from the current observation, ask the model for an action,
    and POST it to ``/step``. The loop stops early when an observation has a
    truthy ``done``. Afterwards GET ``/state`` and POST a summary to
    ``/grader``.

    All requests share one ``httpx.Client`` so connections are reused and
    closed on exit. Every response is status-checked, so an HTTP error from
    the environment is raised instead of being parsed as an observation.

    An LLM call that raises is reported on stdout and replaced by a ``noop``
    action; it does not abort the episode.

    Args:
        client: OpenAI client used for chat completions.
        env_url: Base URL of the environment server (trailing ``/`` allowed).
        task_id: Task identifier sent to ``/reset``.
        seed: Seed sent to ``/reset``.

    Returns:
        Dict with ``task_id``, ``seed``, ``total_reward``, ``score``,
        ``slo_recovery``, ``steps_taken`` and ``termination_reason``.

    Raises:
        httpx.HTTPStatusError: If any environment endpoint answers with a
            4xx/5xx status.
    """
    import httpx

    base = env_url.rstrip("/")

    messages: List[Dict[str, Any]] = [
        {"role": "system", "content": SYSTEM_PROMPT},
    ]
    total_reward = 0.0

    with httpx.Client(timeout=30.0) as http:
        # Reset
        reset_resp = http.post(
            f"{base}/reset",
            json={"seed": seed, "task_id": task_id},
        )
        reset_resp.raise_for_status()
        obs = reset_resp.json()

        max_steps = obs.get("max_steps", 10)

        for step_num in range(max_steps):
            if obs.get("done", False):
                break

            user_msg = build_observation_prompt(obs)
            messages.append({"role": "user", "content": user_msg})

            # Call the LLM; a failed call degrades to a noop for this step.
            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    temperature=0.2,
                    max_tokens=200,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as e:
                print(f" LLM error at step {step_num}: {e}")
                response_text = '{"action_type": "noop", "params": {}}'

            action = parse_action(response_text)
            messages.append({"role": "assistant", "content": response_text})

            # Guard against null/missing fields in the model's JSON.
            action_type = action.get("action_type") or "noop"
            params = action.get("params") or {}

            print(f" Step {step_num}: {action_type}({params})")

            # Step the environment
            step_resp = http.post(
                f"{base}/step",
                json={"action_type": action_type, "params": params},
            )
            step_resp.raise_for_status()
            obs = step_resp.json()
            total_reward += obs.get("reward") or 0.0

        # Get final state
        state_resp = http.get(f"{base}/state", timeout=10.0)
        state_resp.raise_for_status()
        final_state = state_resp.json()

        # Grade
        grade_resp = http.post(
            f"{base}/grader",
            json={
                "final_slo_score": final_state.get("global_slo_score", 0.0),
                "steps_taken": final_state.get("step_count", 0),
                "max_steps": max_steps,
                "actions_taken": obs.get("actions_taken", []),
                "terminated": final_state.get("terminated", True),
                "termination_reason": final_state.get("termination_reason"),
            },
            timeout=10.0,
        )
        grade_resp.raise_for_status()
        grade = grade_resp.json()

    return {
        "task_id": task_id,
        "seed": seed,
        "total_reward": total_reward,
        "score": grade.get("score", 0.0),
        "slo_recovery": grade.get("slo_recovery", 0.0),
        "steps_taken": final_state.get("step_count", 0),
        "termination_reason": final_state.get("termination_reason"),
    }
224
+
225
+
226
def main() -> None:
    """Run the baseline agent on each task/seed pair and print a score report.

    Builds the OpenAI client from the module-level settings, reads the
    environment URL from ``ENV_URL`` (default ``http://localhost:7860``), runs
    one episode per task, then prints per-task results and the average score.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    environment_url = os.getenv("ENV_URL", "http://localhost:7860")

    # Each task is paired positionally with its own fixed seed.
    schedule = list(zip(["easy", "medium", "hard"], [42, 123, 7]))
    rule = "=" * 60

    print(rule)
    print("SevZero Baseline Inference")
    print(rule)
    print(f"Model: {MODEL_NAME}")
    print(f"Environment: {environment_url}")
    print()

    results = []
    for task_id, seed in schedule:
        print(f"--- Task: {task_id} (seed={seed}) ---")
        outcome = run_episode(llm_client, environment_url, task_id, seed)
        results.append(outcome)
        print(
            f" Score: {outcome['score']:.4f} | SLO Recovery: {outcome['slo_recovery']:.4f} | "
            f"Steps: {outcome['steps_taken']} | Outcome: {outcome['termination_reason']}"
        )
        print()

    print(rule)
    print("Summary")
    print(rule)
    for row in results:
        print(f" {row['task_id']:8s} → score={row['score']:.4f} slo={row['slo_recovery']:.4f} steps={row['steps_taken']}")

    if results:
        avg_score = sum(row["score"] for row in results) / len(results)
    else:
        avg_score = 0.0
    print(f"\n Average score: {avg_score:.4f}")
256
+
257
+
258
+ if __name__ == "__main__":
259
+ main()
pyproject.toml CHANGED
@@ -10,6 +10,7 @@ dependencies = [
10
  "uvicorn>=0.24.0",
11
  "pydantic>=2.0.0",
12
  "openai>=1.0.0",
 
13
  ]
14
 
15
  [project.optional-dependencies]
@@ -18,6 +19,9 @@ dev = [
18
  "httpx>=0.24.0",
19
  ]
20
 
 
 
 
21
  [build-system]
22
  requires = ["hatchling"]
23
  build-backend = "hatchling.build"
 
10
  "uvicorn>=0.24.0",
11
  "pydantic>=2.0.0",
12
  "openai>=1.0.0",
13
+ "httpx>=0.24.0",
14
  ]
15
 
16
  [project.optional-dependencies]
 
19
  "httpx>=0.24.0",
20
  ]
21
 
22
+ [project.scripts]
23
+ server = "server.app:main"
24
+
25
  [build-system]
26
  requires = ["hatchling"]
27
  build-backend = "hatchling.build"