File size: 15,809 Bytes
408d02c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140d024
408d02c
 
 
140d024
408d02c
 
0b900fd
 
 
 
 
408d02c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abb0dea
 
 
140d024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408d02c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140d024
408d02c
140d024
408d02c
140d024
408d02c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abb0dea
408d02c
 
 
 
 
fdd0183
140d024
408d02c
140d024
 
 
 
 
 
 
 
 
 
408d02c
 
 
 
 
 
 
 
 
 
 
 
140d024
 
 
 
408d02c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
"""
Inference Script - ReleaseOps-Env

Env vars:
    API_BASE_URL      LLM API endpoint   (falls back to OPENAI_BASE_URL; default: https://router.huggingface.co/v1)
    MODEL_NAME        Model identifier   (default: Qwen/Qwen2.5-72B-Instruct)
    API_KEY           API key            (falls back to OPENAI_API_KEY, then HF_TOKEN)
    OPENAI_API_KEY    API key fallback
    HF_TOKEN          API key fallback
    ENV_URL           Environment URL    (default: http://localhost:7860)

Install: pip install openai requests  (the releaseops_env package must also be importable)

Structured stdout logs follow [START], [STEP], [END] format per hackathon requirements.
"""

import json
import os
import textwrap
import time
from typing import List, Optional, TypedDict

import requests
from openai import OpenAI
from releaseops_env.scoring import format_score, normalize_score

# ── Config ──────────────────────────────────────────────────────────────────────
# Runtime configuration, read once at import time. The validator normally
# injects these variables; the defaults below only apply when a variable is
# unset. An empty-string value is treated the same as unset, so a blank
# variable never produces an empty endpoint, model name or environment URL.
API_BASE_URL = (
    os.environ.get("API_BASE_URL")
    or os.environ.get("OPENAI_BASE_URL")
    or "https://router.huggingface.co/v1"
)
# Left as None when none of the key variables is set.
API_KEY = (
    os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
    or os.environ.get("HF_TOKEN")
)
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# Trailing slashes are stripped so endpoint paths can be appended with "/".
ENV_URL = (os.environ.get("ENV_URL") or "http://localhost:7860").rstrip("/")


# ── Structured logging helpers ──────────────────────────────────────────────────
def log_start(task_id: str, model_name: str, benchmark: str = "releaseops"):
    """Print the single-line [START] record that opens a task's log.

    A falsy ``model_name`` is shown as ``unknown``.
    """
    shown_model = model_name if model_name else "unknown"
    line = f"[START] task={task_id} env={benchmark} model={shown_model}"
    print(line, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None):
    """Emit the single-line [STEP] record for one environment step.

    Args:
        step: Step index as counted by the caller.
        action: Short description of the action taken.
        reward: Reward for the step; printed with two decimals.
        done: Whether the episode ended; printed as ``true``/``false``.
        error: Optional error message; printed as ``null`` when None, empty
            or whitespace-only.

    Runs of whitespace (including newlines) in ``action`` and ``error`` are
    collapsed to single spaces so every record stays on exactly one line and
    line-oriented log parsers never see a split record.
    """
    cleaned_error = " ".join(str(error).split()) if error else ""
    error_val = cleaned_error or "null"
    action_val = " ".join(str(action).split())
    done_val = str(done).lower()
    print(f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}", flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Print the single-line [END] record summarising a finished task.

    ``score`` is shown with three decimals; ``rewards`` becomes a
    comma-separated list with two decimals per entry (empty when no steps).
    """
    reward_csv = ",".join(map("{:.2f}".format, rewards))
    success_flag = str(success).lower()
    print(
        f"[END] success={success_flag} steps={steps} score={score:.3f} rewards={reward_csv}",
        flush=True,
    )


# Per-task summary record (a plain dict at runtime):
#   task_id      - identifier of the task that was run
#   final_score  - the task's score
#   steps_taken  - number of steps reported for the task
#   done         - whether the episode was reported as finished
#   errors       - error messages collected while running the task
TaskResult = TypedDict(
    "TaskResult",
    {
        "task_id": str,
        "final_score": float,
        "steps_taken": int,
        "done": bool,
        "errors": List[str],
    },
)


def make_task_result(
    task_id: str, final_score: float, steps_taken: int, done: bool, errors: List[str]
) -> TaskResult:
    """Assemble a TaskResult, coercing fields to their declared types.

    ``final_score`` is passed through ``normalize_score``; ``steps_taken`` and
    ``done`` are coerced with ``int``/``bool``. ``errors`` is stored as given
    (the same list object, not a copy).
    """
    return TaskResult(
        task_id=task_id,
        final_score=normalize_score(final_score),
        steps_taken=int(steps_taken),
        done=bool(done),
        errors=errors,
    )


def emit_task_result(result: TaskResult) -> None:
    """Print one machine-parseable JSON line tagged ``"type": "task_result"``.

    Keys are sorted; a ``type`` key inside ``result`` would override the tag.
    """
    record = dict(type="task_result")
    record.update(result)
    print(json.dumps(record, sort_keys=True), flush=True)

# Task ids to run, in execution order.
TASKS = [
    "easy_001",
    "easy_002",
    "medium_001",
    "medium_002",
    "hard_001",
    "hard_002",
]
# Maximum number of agent steps per task.
MAX_STEPS = 14
# Zero temperature, intended to keep runs reproducible.
TEMPERATURE = 0.0

# ── Prompt ──────────────────────────────────────────────────────────────────────
# System message sent with every LLM call. Kept to plain ASCII: the dashes and
# arrows previously in this text had been mis-decoded (mojibake) and reached
# the model as garbage characters.
SYSTEM_PROMPT = textwrap.dedent("""
    You are an SRE agent reviewing a proposed software change for production rollout.
    Investigate thoroughly, then submit a final decision.

    Investigation tools (use as needed - not all are required for every change):
      inspect_change       section: "diff"|"tests"|"approvals"|"files_changed"
      inspect_services     service: <service_name_exactly_as_shown>
      inspect_dependencies (no extra params) - reveals blast radius
      search_incidents     keywords: ["word1", "word2"] - queries real incident DB
      check_policy         (no extra params) - checks rollout policy rules
      query_telemetry      metric: "p99"|"error_rate"|"p95"|"cpu"|"rps"|"queue_depth"
                           service: <name>, window: "5m"|"15m"|"1h"
                           - check CURRENT live metrics before deciding; risky changes
                             (concurrency, rate limiting, connection pools) often show
                             pre-existing anomalies that static inspection misses
      request_artifact     artifact_type: "load_test"|"rollback_plan"|"approval"|"runbook"
      control_rollout      decision: "start_canary"|"promote"|"pause"|"rollback"
      submit_decision      final_decision: "approve"|"request_changes"|"block"|"escalate"
                           reason_codes: [<signal_id>, ...]  - use signal_ids from known_risk_signals

    Investigation strategy:
    - ALWAYS start: inspect_change(diff) -> inspect_change(tests) -> inspect_change(approvals) -> inspect_dependencies
    - inspect_dependencies reveals the EXACT service names you MUST use for inspect_services and query_telemetry
    - NEVER guess service names - only use names returned by inspect_dependencies
    - For changes touching high-traffic services, query live telemetry - pre-existing
      degradation is a blocker even if tests pass
    - Search incidents with keywords from the change type (e.g. "retry", "rate_limit", "pool")
    - Check policy after gathering evidence
    - reason_codes MUST be string signal_ids from known_risk_signals (e.g. "missing_load_test"), never numbers

    Respond with ONLY a valid JSON object. No explanation. No markdown.
    Example: {"action_type": "query_telemetry", "metric": "error_rate", "service": "api-gateway", "window": "5m"}
""").strip()


# ── Prompt builder ───────────────────────────────────────────────────────────────
def build_prompt(step: int, obs: dict, history: List[str]) -> str:
    """Render the per-step user message from the current observation.

    Args:
        step: Index of the step being planned, shown as ``step/MAX_STEPS``.
        obs: Observation dict from the environment. ``rollout_phase``,
            ``time_remaining``, ``task_id`` and ``change_summary`` are required;
            ``known_risk_signals`` and ``last_tool_result`` are optional.
        history: One line per previous action, oldest first.

    Returns:
        The prompt text, one field per line with no template indentation.

    The prompt is assembled from a list of lines instead of calling
    ``textwrap.dedent`` on an interpolated f-string: dedent runs after
    interpolation, so any multi-line value (risk list, tool output, history)
    with less indentation than the template disabled the dedent and left
    every template line indented.
    """
    risk_lines = "\n".join(
        f"  [{r['severity'].upper()}] {r['signal_id']}: {r['summary']}"
        for r in obs.get("known_risk_signals") or []
    ) or "  (none discovered yet)"

    last = obs.get("last_tool_result")
    if last:
        status = "OK" if last.get("success") else "FAIL"
        # Tool output can be long; cap it to keep the prompt small.
        content = str(last.get("content") or "")[:600]
        last_result = f"[{status}] {last.get('tool_name', 'unknown')}:\n{content}"
    else:
        last_result = "(none)"

    lines = [
        f"Step: {step}/{MAX_STEPS} | Phase: {obs['rollout_phase']} | Budget: {obs['time_remaining']}",
        f"Task: {obs['task_id']}",
        f"Change: {obs['change_summary']}",
        "",
        "Risk signals (use signal_id as reason_codes):",
        risk_lines,
        "",
        "Last result:",
        last_result,
        "",
        "Actions taken so far - DO NOT repeat:",
        "\n".join(history) or "(none)",
        "",
        "Output next action as JSON.",
    ]
    return "\n".join(lines)


# ── Action parsing ───────────────────────────────────────────────────────────────
# Fields allowed per action_type β€” prevents hallucinated fields from failing validation
_ACTION_FIELDS: dict[str, set] = {
    "inspect_change":       {"section"},
    "inspect_services":     {"service"},
    "inspect_dependencies": set(),
    "search_incidents":     {"keywords"},
    "check_policy":         set(),
    "query_telemetry":      {"metric", "service", "window"},
    "request_artifact":     {"artifact_type"},
    "control_rollout":      {"decision"},
    "submit_decision":      {"final_decision", "reason_codes"},
}

# Allowed values for the enum-valued action fields.
_VALID_SECTIONS   = {"diff", "tests", "approvals", "files_changed"}
_VALID_METRICS    = {"p50", "p95", "p99", "error_rate", "queue_depth", "cpu", "rps"}
_VALID_WINDOWS    = {"5m", "15m", "1h"}
_VALID_ARTIFACTS  = {"load_test", "rollback_plan", "approval", "runbook", "security_review", "compliance_check"}
_VALID_ROLLOUT    = {"start_canary", "pause", "promote", "rollback"}
_VALID_DECISIONS  = {"approve", "request_changes", "block", "escalate"}

# Enum field -> allowed values. A field whose value is not one of these
# strings is stripped to avoid Pydantic rejection.
_ENUM_FIELDS: dict[str, set] = {
    "section":        _VALID_SECTIONS,
    "metric":         _VALID_METRICS,
    "window":         _VALID_WINDOWS,
    "artifact_type":  _VALID_ARTIFACTS,
    "decision":       _VALID_ROLLOUT,
    "final_decision": _VALID_DECISIONS,
}

# Fields that are sent as lists of strings.
_LIST_FIELDS = ("keywords", "reason_codes")


def _as_str_list(value) -> list:
    """Coerce a decoded JSON value to a list of strings.

    ``None`` becomes ``[]``, a scalar is wrapped in a one-element list, and
    ``None`` items inside a list are dropped; everything else goes through ``str``.
    """
    if value is None:
        return []
    if not isinstance(value, list):
        value = [value]
    return [str(item) for item in value if item is not None]


def parse_action(text: str) -> Optional[dict]:
    """Parse an LLM reply into a sanitised action dict.

    Tolerates markdown code fences and prose around a single JSON object.

    Returns:
        The action dict, or None when no JSON object with a known
        ``action_type`` can be extracted (including valid JSON that is not
        an object, e.g. ``[1, 2]`` or ``"hi"``).

    For a recognised action:
      * only the fields listed in ``_ACTION_FIELDS`` for its type are kept;
      * enum fields whose value is not one of the allowed strings are dropped;
      * ``keywords`` / ``reason_codes`` are coerced to lists of strings.
    """
    text = text.strip()
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.splitlines() if not line.startswith("```")
        ).strip()
    # Keep only the outermost {...} span so surrounding prose is ignored.
    start, end = text.find("{"), text.rfind("}") + 1
    if start >= 0 and end > start:
        text = text[start:end]
    try:
        data = json.loads(text)
    except (ValueError, RecursionError):
        return None
    if not isinstance(data, dict):
        return None

    action_type = data.get("action_type")
    # Type check first: an unhashable value (list/dict) would make the
    # membership test raise TypeError.
    if not isinstance(action_type, str) or action_type not in _ACTION_FIELDS:
        return None

    # Keep only fields valid for this action_type (drops hallucinated fields).
    result: dict = {"action_type": action_type}
    for field in _ACTION_FIELDS[action_type]:
        if field in data:
            result[field] = data[field]

    for field, allowed in _ENUM_FIELDS.items():
        if field in result:
            value = result[field]
            if not (isinstance(value, str) and value in allowed):
                del result[field]

    for field in _LIST_FIELDS:
        if field in result:
            result[field] = _as_str_list(result[field])

    return result


# ── Simple HTTP Environment Client ──────────────────────────────────────────────
class SimpleEnvClient:
    """Minimal HTTP client for the ReleaseOps-Env server's /reset and /step endpoints."""

    def __init__(self, base_url: str):
        # Normalised once so endpoint names can be appended after a single "/".
        self.base_url = base_url.rstrip("/")

    def _post(self, endpoint: str, payload: dict) -> dict:
        """POST ``payload`` as JSON, raise on an HTTP error status, return the decoded body."""
        response = requests.post(f"{self.base_url}/{endpoint}", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()

    def reset(self, task_id: str) -> dict:
        """Start a new episode for ``task_id`` and return the server's JSON reply."""
        return self._post("reset", {"task_id": task_id})

    def step(self, action: dict) -> dict:
        """Send one action, wrapped as ``{"action": ...}``, and return the JSON reply."""
        return self._post("step", {"action": action})


# ── Task runner ──────────────────────────────────────────────────────────────────
def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when an LLM client exception looks like HTTP 429 / rate limiting."""
    if getattr(exc, "status_code", None) == 429:
        return True
    msg = str(exc).lower()
    # Match the phrase, not bare "rate", which also occurs in words such as
    # "generate" or "moderate" and previously triggered long pointless sleeps.
    return "429" in msg or "rate limit" in msg or "rate_limit" in msg or "ratelimit" in msg


def _request_completion(llm: OpenAI, user_prompt: str, attempts: int = 4) -> tuple[str, Optional[str]]:
    """Ask the model for the next action.

    Rate-limited calls are retried with a linear backoff (15 s, 30 s, ...);
    any other failure stops immediately.

    Returns:
        ``(text, error)``: on success ``error`` is None; when no completion
        was obtained ``text`` is "" and ``error`` holds the last exception
        message.
    """
    last_error: Optional[str] = None
    for attempt in range(attempts):
        try:
            completion = llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=TEMPERATURE,
                max_tokens=200,
            )
            return completion.choices[0].message.content or "", None
        except Exception as exc:  # provider errors vary; classified below
            last_error = str(exc)
            if not _is_rate_limit_error(exc):
                break
            # Sleeping after the final attempt would only waste time.
            if attempt < attempts - 1:
                time.sleep(15 * (attempt + 1))
    return "", last_error


def _forced_submission(obs: dict) -> dict:
    """Build the submit_decision used when the last step is not already a submission.

    Cites every known risk signal (or INSUFFICIENT_EVIDENCE when there are
    none) and requests changes if any signal is high/critical, else approves.
    """
    risks = obs.get("known_risk_signals") or []
    codes = [r["signal_id"] for r in risks] or ["INSUFFICIENT_EVIDENCE"]
    has_high = any(r["severity"] in ("high", "critical") for r in risks)
    return {
        "action_type": "submit_decision",
        "final_decision": "request_changes" if has_high else "approve",
        "reason_codes": codes,
    }


def _describe_action(action: dict) -> str:
    """Render an action compactly for the [STEP] log, e.g. ``inspect_change(section=diff)``."""
    label = f"{action.get('action_type', 'unknown')}"
    if action.get("section"):
        return f"{label}(section={action['section']})"
    if action.get("metric"):
        return f"{label}(metric={action['metric']})"
    for key in ("decision", "final_decision"):
        if action.get(key):
            return f"{label}({action[key]})"
    return label


def run_task(llm: OpenAI, task_id: str) -> dict:
    """Run one task episode and report it.

    Resets the environment for ``task_id``, then for up to MAX_STEPS steps asks
    the model for an action (falling back to ``check_policy`` when the reply
    cannot be parsed), sends it to the environment and logs a [STEP] record.
    On the last step a non-submit action is replaced by a forced
    submit_decision. Exceptions raised during the episode are caught and
    recorded; [END] and a task_result JSON line are emitted in every case.

    Returns:
        The emitted TaskResult dict. ``steps_taken`` counts environment steps
        that actually completed.
    """
    log_start(task_id, MODEL_NAME)

    rewards: List[float] = []
    errors: List[str] = []
    done = False
    success = False
    score = 0.001
    env = SimpleEnvClient(base_url=ENV_URL)

    try:
        result = env.reset(task_id=task_id)
        obs_dict = result.get("observation", result)
        done = result.get("done", False)
        history: List[str] = []

        for step in range(1, MAX_STEPS + 1):
            if done:
                break

            response_text, llm_error = _request_completion(llm, build_prompt(step, obs_dict, history))
            if llm_error is not None:
                # Previously swallowed silently; keep a record in the task result.
                errors.append(f"step {step}: LLM request failed: {llm_error}")

            action = parse_action(response_text)
            if action is None:
                action = {"action_type": "check_policy"}

            # Force submit on last step
            if step == MAX_STEPS and action.get("action_type") != "submit_decision":
                action = _forced_submission(obs_dict)

            result = env.step(action)
            obs_dict = result.get("observation", result)
            last_reward = result.get("reward", 0) or 0
            done = result.get("done", False)
            rewards.append(last_reward)

            log_step(step, _describe_action(action), last_reward, done, error=None)
            history.append(
                f"Step {step}: {action.get('action_type')}"
                f"({action.get('section') or action.get('metric') or action.get('decision') or ''})"
                f" -> reward {last_reward:+.2f}"
            )

        score = normalize_score(obs_dict.get("final_score") or 0.0)
        success = score >= 0.5

    except Exception as e:
        print(f"[DEBUG] Task {task_id} failed with error: {e}", flush=True)
        success = False
        score = 0.001
        errors.append(str(e))

    # Count completed env.step calls, not the loop index: the index overcounts
    # when a step raises or when the episode is already done after reset.
    steps_taken = len(rewards)
    log_end(success, steps_taken, score, rewards)
    task_result = make_task_result(
        task_id=task_id,
        final_score=score,
        steps_taken=steps_taken,
        done=done,
        errors=errors,
    )
    emit_task_result(task_result)
    return task_result


# ── Entry point ──────────────────────────────────────────────────────────────────
def main():
    """Run every task in TASKS and print a per-task score table with the average.

    Returns the list of task results. Any exception is printed with its
    traceback and then re-raised.
    """
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        results = []
        for task in TASKS:
            results.append(run_task(client, task))

        divider = "=" * 60
        print(f"\n{divider}\nResults\n{divider}")
        total = 0.0
        for entry in results:
            total += entry["final_score"]
            shown = format_score(entry["final_score"])
            print(f"  {entry['task_id']:15s}  score={shown}  steps={entry['steps_taken']}")
        average = format_score(total / len(results))
        print(f"  {'AVERAGE':15s}  score={average}")
        return results
    except Exception as exc:
        print(f"[ERROR] Fatal error in main: {exc}", flush=True)
        import traceback
        traceback.print_exc()
        raise


# Script entry point: run all tasks only when executed directly, not on import.
if __name__ == "__main__":
    main()