# Source: updated-policy / inference_agent.py
# Last commit: 8c26ecf ("add blog.md") by srinjoyd
"""
inference_agent.py — real-world SRE incident investigation.
Takes a GitHub issue (URL or raw text) + a local git repo, then drives the
trained GRPO model through the same Phase-2 action loop it was trained on
(list_dir → read_file → search_code → get_git_log → get_file_diff →
propose_patch / declare_root_cause).
The model sees an observation dict in exactly the format it was trained on;
the only change is that tool calls hit a real repository instead of a snapshot.
Usage:
# From a GitHub issue URL (requires `gh` CLI authenticated)
python inference_agent.py \\
--model srinjoyd/qwen2.5-7b-sre-grpo \\
--repo /path/to/cloned/repo \\
--issue https://github.com/owner/repo/issues/42
# From raw issue text
python inference_agent.py \\
--model srinjoyd/qwen2.5-7b-sre-grpo \\
--repo /path/to/cloned/repo \\
--issue-text "OrderService crashes with OOM after deploy abc1234"
# Clone the repo automatically
python inference_agent.py \\
--model srinjoyd/qwen2.5-7b-sre-grpo \\
--repo-url https://github.com/owner/repo \\
--issue https://github.com/owner/repo/issues/42
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from server.real_code_workspace import RealCodeWorkspace, RealCodeWorkspaceError
# ──────────────────────────────────────────────────────────────────────
# Issue fetching
# ──────────────────────────────────────────────────────────────────────
def fetch_github_issue(issue_url: str) -> str:
    """
    Fetch a GitHub issue with the `gh` CLI and render it as plain text.

    The returned text has the form::

        Issue: <title>
        Author: <login>
        Labels: <comma-separated names>      (only when the issue has labels)

        <body, truncated to 3000 characters>  (only when the body is non-empty)

    Args:
        issue_url: An issue URL such as ``https://github.com/owner/repo/issues/42``.
            A trailing slash, a ``?query`` or a ``#fragment`` (e.g. a link to a
            specific comment) is ignored.

    Returns:
        The rendered issue summary.

    Raises:
        ValueError: If ``issue_url`` is not of the form
            ``…/<owner>/<repo>/issues/<number>``.
        RuntimeError: If ``gh`` cannot be executed, times out, or exits non-zero.

    Requires the `gh` CLI to be installed and authenticated (`gh auth login`).
    """
    # Drop any "#issuecomment-…" fragment or "?…" query before splitting so
    # links copied from the browser still resolve to the issue number.
    path = issue_url.split("#", 1)[0].split("?", 1)[0].rstrip("/")
    parts = path.split("/")
    # Expected: ["https:", "", "github.com", owner, repo, "issues", number]
    if len(parts) < 7 or parts[-2] != "issues" or not parts[-1].isdigit():
        raise ValueError(f"Cannot parse issue URL: {issue_url}")
    owner_repo = f"{parts[-4]}/{parts[-3]}"
    number = parts[-1]

    cmd = [
        "gh", "issue", "view", number, "--repo", owner_repo,
        "--json", "title,body,labels,state,createdAt,author",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    except FileNotFoundError as err:
        raise RuntimeError(
            "gh issue view failed: the `gh` executable was not found.\n"
            "Make sure `gh` is installed and authenticated (`gh auth login`)."
        ) from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(
            f"gh issue view timed out after {err.timeout} seconds."
        ) from err
    if result.returncode != 0:
        raise RuntimeError(
            f"gh issue view failed:\n{result.stderr}\n"
            "Make sure `gh` is installed and authenticated (`gh auth login`)."
        )

    data = json.loads(result.stdout)
    title = data.get("title") or ""
    body = (data.get("body") or "").strip()
    # Treat null "labels" / "author" like missing ones so a sparse payload
    # cannot raise TypeError / AttributeError here.
    labels = ", ".join(
        label.get("name", "") for label in (data.get("labels") or [])
    )
    author = (data.get("author") or {}).get("login") or "unknown"

    summary = f"Issue: {title}\nAuthor: {author}"
    if labels:
        summary += f"\nLabels: {labels}"
    if body:
        summary += f"\n\n{body[:3000]}"  # cap body at 3000 characters
    return summary
def clone_repo(repo_url: str, target_dir: str, depth: int = 50) -> str:
    """
    Clone ``repo_url`` into ``target_dir`` and return ``target_dir``.

    Args:
        repo_url: Repository location handed to ``git clone``.
        target_dir: Destination directory for the clone.
        depth: Number of commits of history to fetch (``git clone --depth``).
            Older history is not fetched, so git-log / diff tools cannot see
            it. Pass ``0`` (or a negative value) for a full clone.

    Returns:
        ``target_dir``, for convenient chaining.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits non-zero.
        subprocess.TimeoutExpired: If the clone takes longer than 120 seconds.
    """
    print(f"Cloning {repo_url} → {target_dir}")
    cmd = ["git", "clone"]
    if depth > 0:
        cmd += ["--depth", str(depth)]
    # "--" ends option parsing so a URL/path starting with "-" is never
    # interpreted as a git option.
    cmd += ["--", repo_url, target_dir]
    subprocess.run(cmd, check=True, timeout=120)
    return target_dir
# ──────────────────────────────────────────────────────────────────────
# Minimal model inference
# ──────────────────────────────────────────────────────────────────────
_SYSTEM_PROMPT = """\
You are an SRE root-cause analyst. You will receive observations from a code \
investigation environment. At each step respond with ONE JSON action only β€” \
no prose, no markdown fences.
Valid actions:
{"action_type": "list_dir", "parameters": {"path": "<rel_path>"}}
{"action_type": "read_file", "parameters": {"path": "<rel_path>"}}
{"action_type": "search_code", "parameters": {"query": "<text>", "file_pattern": "*.py"}}
{"action_type": "get_git_log", "parameters": {"n_commits": 10, "path": ""}}
{"action_type": "get_file_diff", "parameters": {"commit_sha": "<sha>", "path": ""}}
{"action_type": "propose_patch", "parameters": {"diff": "<unified_diff>"}}
{"action_type": "declare_root_cause", "parameters": {"root_cause": "<explanation>"}}
{"action_type": "declare_no_change", "parameters": {}}
Investigation strategy:
1. list_dir(".") to orient yourself.
2. get_git_log to find recent commits.
3. get_file_diff on suspicious commits.
4. read_file / search_code to understand the bug.
5. propose_patch with a minimal unified diff once you have the fix.
declare_no_change only if there is genuinely no code bug.\
"""
def _parse_action(text: str) -> Dict[str, Any]:
text = text.strip().lstrip("`")
if text.startswith("json"):
text = text[4:].strip()
a, b = text.find("{"), text.rfind("}")
if a == -1 or b <= a:
return {"action_type": "declare_no_change", "parameters": {}}
try:
obj = json.loads(text[a : b + 1])
obj.setdefault("parameters", {})
return obj if isinstance(obj, dict) else {"action_type": "declare_no_change", "parameters": {}}
except Exception:
return {"action_type": "declare_no_change", "parameters": {}}
class LocalModelAgent:
    """
    Wrap a local / Hugging Face causal LM as a step-by-step investigation agent.

    The agent keeps a chat transcript: the system prompt followed by
    alternating user (observation) and assistant (model reply) turns.
    :meth:`act` appends an observation, generates a reply, records it and
    returns the parsed action; :meth:`reset` starts a new transcript.
    """

    def __init__(
        self,
        model_id: str,
        load_in_4bit: bool = False,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        max_history: int = 10,
    ) -> None:
        """
        Load the tokenizer and model.

        Args:
            model_id: Hub id or local path passed to ``from_pretrained``.
            load_in_4bit: Load 4-bit NF4-quantised weights with bf16 compute
                (``BitsAndBytesConfig``); otherwise load in bf16.
            max_new_tokens: Maximum tokens generated per action.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            max_history: Roughly how many user/assistant exchanges are kept
                after the system message; see :meth:`_trim`.
        """
        print(f"Loading model: {model_id}")
        tok_kwargs: Dict[str, Any] = {"use_fast": True}
        self._tok = AutoTokenizer.from_pretrained(model_id, **tok_kwargs)
        if self._tok.pad_token is None:
            # Some tokenizers define no pad token; reuse EOS so one exists.
            self._tok.pad_token = self._tok.eos_token

        model_kwargs: Dict[str, Any] = {"device_map": "auto"}
        if load_in_4bit:
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
        self._model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
        self._model.eval()

        self._max_new = max_new_tokens
        self._temperature = temperature
        self._max_history = max_history
        self._msgs: List[Dict[str, str]] = []
        # Seed the transcript so act() includes the system prompt even if the
        # caller never calls reset() explicitly.
        self.reset()

    def reset(self) -> None:
        """Start a fresh transcript containing only the system prompt."""
        self._msgs = [{"role": "system", "content": _SYSTEM_PROMPT}]

    def act(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """
        Feed one observation to the model and return its next action.

        The observation is serialised with ``json.dumps(default=str)`` and
        cut to 5000 characters, then appended as a user turn, after which
        the transcript is trimmed. The decoded reply is appended as an
        assistant turn and parsed with ``_parse_action``.
        """
        payload = json.dumps(observation, default=str)[:5000]
        self._msgs.append({"role": "user", "content": payload})
        self._trim()
        prompt = self._format_chat()

        try:
            # With device_map="auto" the weights may span several devices;
            # place the inputs where the embedding layer lives.
            device = self._model.get_input_embeddings().weight.device
        except Exception:
            device = next(self._model.parameters()).device
        inputs = self._tok(prompt, return_tensors="pt").to(device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self._max_new,
            "eos_token_id": self._tok.eos_token_id,
            # Explicit pad id so generate() does not have to pick a default.
            "pad_token_id": self._tok.pad_token_id,
        }
        if self._temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=self._temperature)
        else:
            # Greedy decoding: temperature has no effect, so don't pass it
            # (a 0.0 temperature only produces generation-config warnings).
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            out = self._model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = out[0][inputs["input_ids"].shape[-1]:]
        text = self._tok.decode(new_tokens, skip_special_tokens=True).strip()
        self._msgs.append({"role": "assistant", "content": text})
        return _parse_action(text)

    def _format_chat(self) -> str:
        """
        Render the transcript as one prompt string ending in an open
        assistant turn.

        Uses the tokenizer's chat template when possible; if applying it
        fails for any reason, falls back to a plain ``<|system|>`` /
        ``<|user|>`` / ``<|assistant|>`` tagged layout.
        """
        try:
            return self._tok.apply_chat_template(
                self._msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            # Deliberately broad: a missing template, a template error or a
            # tokenizer without apply_chat_template all use the fallback.
            pass

        tags = {"system": "<|system|>", "user": "<|user|>"}
        chunks = [
            f"{tags.get(m['role'], '<|assistant|>')}\n{m['content']}\n"
            for m in self._msgs
        ]
        chunks.append("<|assistant|>\n")
        return "".join(chunks)

    def _trim(self) -> None:
        """
        Bound the transcript to the system message plus the most recent turns.

        Keeps at most ``2 * max_history`` turns after the system message,
        and always at least one, so the newest observation is never dropped.
        A leading assistant turn is discarded so the kept history starts with
        a user turn; some chat templates reject conversations whose first
        non-system turn comes from the assistant.
        """
        limit = max(self._max_history * 2, 1)
        if len(self._msgs) <= 1 + limit:
            return
        tail = self._msgs[1:][-limit:]
        while len(tail) > 1 and tail[0]["role"] == "assistant":
            tail = tail[1:]
        self._msgs = self._msgs[:1] + tail
# ──────────────────────────────────────────────────────────────────────
# Dispatch: action → RealCodeWorkspace call
# ──────────────────────────────────────────────────────────────────────
def _dispatch(
action: Dict[str, Any],
workspace: RealCodeWorkspace,
) -> Dict[str, Any]:
atype = action.get("action_type", "")
params = action.get("parameters", {}) or {}
try:
if atype == "list_dir":
return workspace.list_dir(params.get("path", "."))
if atype == "read_file":
return workspace.read_file(params.get("path", ""))
if atype == "search_code":
return workspace.search_code(
query = params.get("query", ""),
file_pattern = params.get("file_pattern", "*.py"),
max_hits = params.get("max_hits"),
)
if atype == "get_git_log":
return workspace.get_git_log(
path = params.get("path", ""),
n_commits = int(params.get("n_commits", 10)),
)
if atype == "get_file_diff":
return workspace.get_file_diff(
commit_sha = params.get("commit_sha", "HEAD"),
path = params.get("path", ""),
)
# Terminal actions β€” handled in the main loop
return {}
except RealCodeWorkspaceError as e:
return {"error": str(e)}
# ──────────────────────────────────────────────────────────────────────
# Main investigation loop
# ──────────────────────────────────────────────────────────────────────
# Actions that investigate() intercepts itself instead of sending to
# _dispatch(): propose_patch / declare_no_change stop the loop, while
# declare_root_cause narrows the menu to the other two.
_TERMINAL_ACTION_NAMES = ("propose_patch", "declare_root_cause", "declare_no_change")
TERMINAL_ACTIONS = set(_TERMINAL_ACTION_NAMES)

# Action menu advertised in the seed observation's "valid_actions": the
# repository-inspection tools first, then the terminal actions.
_TOOL_ACTION_NAMES = ("list_dir", "read_file", "search_code", "get_git_log", "get_file_diff")
VALID_ACTIONS = [*_TOOL_ACTION_NAMES, *_TERMINAL_ACTION_NAMES]
def investigate(
    agent: LocalModelAgent,
    workspace: RealCodeWorkspace,
    incident_summary: str,
    max_steps: int = 20,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Drive ``agent`` through the Phase-2 investigation loop on ``workspace``.

    At each step the agent receives the current observation and returns an
    action:

    * ``declare_root_cause`` records the explanation and restricts
      ``valid_actions`` to ``propose_patch`` / ``declare_no_change``, but the
      loop continues so the model can still submit a patch;
    * ``propose_patch`` records the diff and ends the loop;
    * ``declare_no_change`` ends the loop;
    * anything else is executed via ``_dispatch`` and its result becomes the
      next observation's ``action_result``.

    Args:
        agent: Model wrapper; its transcript is reset first.
        workspace: Repository the tool actions run against.
        incident_summary: Issue text shown to the model.
        max_steps: Upper bound on model calls; values ``<= 0`` run no steps.
        verbose: Print each action and the terminal events.

    Returns:
        A dict with ``root_cause`` and ``proposed_patch`` (``None`` if never
        declared), ``steps_taken`` (number of model calls made) and
        ``action_log`` (a list of ``{"step", "action"}`` entries).
    """
    agent.reset()

    # Seed observation — mirrors the format from Phase-2 training
    obs: Dict[str, Any] = {
        "current_phase": 2,
        "incident_summary": incident_summary,
        "bad_commit_sha": workspace.bad_commit_sha or "",
        "valid_actions": VALID_ACTIONS,
        "action_result": {},
        "step": 0,
        "repo_tree": workspace.file_tree(max_depth=2),
    }

    action_log: List[Dict[str, Any]] = []
    root_cause: Optional[str] = None
    proposed_patch: Optional[str] = None
    # Tracked separately from the loop variable so the result is well-defined
    # even when max_steps <= 0 and the loop body never runs.
    steps_taken = 0

    for step in range(1, max_steps + 1):
        steps_taken = step
        obs["step"] = step
        action = agent.act(obs)
        atype = action.get("action_type", "")
        if not isinstance(atype, str):
            atype = str(atype)  # keeps the verbose ":<22s" format safe
        raw_params = action.get("parameters", {})
        # Tolerate a null / non-object "parameters" value from the model.
        params: Dict[str, Any] = raw_params if isinstance(raw_params, dict) else {}

        if verbose:
            params_str = json.dumps(raw_params, default=str)[:120]
            print(f" step {step:2d} {atype:<22s} {params_str}")
        action_log.append({"step": step, "action": action})

        if atype == "declare_root_cause":
            root_cause = params.get("root_cause", "")
            if verbose:
                print(f"\nRoot cause declared:\n {root_cause}")
            # Don't break yet — model may follow up with propose_patch
            obs = {
                **obs,
                "action_result": {"acknowledged": True},
                "valid_actions": ["propose_patch", "declare_no_change"],
            }
            continue

        if atype == "propose_patch":
            proposed_patch = params.get("diff") or ""
            if verbose:
                print(f"\nPatch proposed ({len(proposed_patch)} chars)")
            break

        if atype == "declare_no_change":
            if verbose:
                print("\nAgent declared: no code change needed.")
            break

        # Execute the tool call
        result = _dispatch(action, workspace)
        obs = {
            **obs,
            "action_result": result,
        }

    return {
        "root_cause": root_cause,
        "proposed_patch": proposed_patch,
        "steps_taken": steps_taken,
        "action_log": action_log,
    }
# ──────────────────────────────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────────────────────────────
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Run the trained SRE agent on a real GitHub issue."
)
p.add_argument("--model", required=True,
help="HF model ID or local path of the GRPO checkpoint")
# Repo β€” one of these required
g = p.add_mutually_exclusive_group(required=True)
g.add_argument("--repo", help="Path to an already-cloned local repo")
g.add_argument("--repo-url", help="Git URL to clone (cloned to a temp dir)")
# Issue β€” one of these required
ig = p.add_mutually_exclusive_group(required=True)
ig.add_argument("--issue", help="GitHub issue URL (requires gh CLI)")
ig.add_argument("--issue-text", help="Raw issue description text")
p.add_argument("--bad-commit", default="",
help="SHA of the suspected bad commit (optional hint)")
p.add_argument("--max-steps", type=int, default=20)
p.add_argument("--max-new-tokens",type=int, default=512)
p.add_argument("--load-in-4bit", action="store_true", default=False)
p.add_argument("--output", default=None,
help="Write JSON result to this file")
p.add_argument("--quiet", action="store_true")
return p.parse_args()
def main() -> None:
    """
    Command-line entry point.

    Builds the incident summary (from ``--issue`` via ``gh``, or from
    ``--issue-text``). Points a ``RealCodeWorkspace`` at ``--repo`` or at a
    fresh clone of ``--repo-url``, loads the model and runs
    :func:`investigate`. Finally it prints the root cause and patch, and
    optionally writes the full result as JSON to ``--output``.

    A clone made for ``--repo-url`` lives in a temporary directory that is
    deleted when ``main`` exits. This also happens when cloning, model
    loading or the investigation raises (including Ctrl-C).
    """
    import shutil  # only needed for temp-clone cleanup

    args = _parse_args()
    verbose = not args.quiet

    # ── 1. Get incident summary ──────────────────────────────────────
    # `is not None` so an explicit empty --issue "" is rejected by
    # fetch_github_issue instead of silently falling back to a None summary.
    if args.issue is not None:
        if verbose:
            print(f"Fetching issue: {args.issue}")
        incident_summary = fetch_github_issue(args.issue)
    else:
        incident_summary = args.issue_text
    if verbose:
        print("\n── Incident summary ──")
        print(incident_summary[:600])
        print()

    # ── 2. Set up repo ───────────────────────────────────────────────
    tmpdir: Optional[str] = None
    if args.repo_url is not None:
        tmpdir = tempfile.mkdtemp(prefix="sre_agent_")
    try:
        if tmpdir is not None:
            repo_path = clone_repo(args.repo_url, tmpdir)
        else:
            repo_path = args.repo
        workspace = RealCodeWorkspace(repo_path, bad_commit_sha=args.bad_commit or "")

        # ── 3. Load model ────────────────────────────────────────────
        agent = LocalModelAgent(
            model_id=args.model,
            load_in_4bit=args.load_in_4bit,
            max_new_tokens=args.max_new_tokens,
        )

        # ── 4. Investigate ───────────────────────────────────────────
        if verbose:
            print("── Investigation ──")
        result = investigate(
            agent=agent,
            workspace=workspace,
            incident_summary=incident_summary,
            max_steps=args.max_steps,
            verbose=verbose,
        )

        # ── 5. Output ────────────────────────────────────────────────
        print("\n" + "═" * 60)
        print("ROOT CAUSE:")
        print(result["root_cause"] or "(not declared)")
        print("\nPROPOSED PATCH:")
        print(result["proposed_patch"] or "(none)")
        print(f"\n({result['steps_taken']} steps taken)")
        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, default=str)
            print(f"\nFull result written to {args.output}")
    finally:
        # Cleanup temp clone, whether or not the steps above succeeded.
        if tmpdir is not None:
            shutil.rmtree(tmpdir, ignore_errors=True)
if __name__ == "__main__":
main()