"""
inference_agent.py — real-world SRE incident investigation.

Takes a GitHub issue (URL or raw text) + a local git repo, then drives the
trained GRPO model through the same Phase-2 action loop it was trained on
(list_dir → read_file → search_code → get_git_log → get_file_diff →
propose_patch / declare_root_cause).

The model sees an observation dict in exactly the format it was trained on;
the only change is that tool calls hit a real repository instead of a
snapshot.

Usage:
    # From a GitHub issue URL (requires `gh` CLI authenticated)
    python inference_agent.py \\
        --model srinjoyd/qwen2.5-7b-sre-grpo \\
        --repo /path/to/cloned/repo \\
        --issue https://github.com/owner/repo/issues/42

    # From raw issue text
    python inference_agent.py \\
        --model srinjoyd/qwen2.5-7b-sre-grpo \\
        --repo /path/to/cloned/repo \\
        --issue-text "OrderService crashes with OOM after deploy abc1234"

    # Clone the repo automatically
    python inference_agent.py \\
        --model srinjoyd/qwen2.5-7b-sre-grpo \\
        --repo-url https://github.com/owner/repo \\
        --issue https://github.com/owner/repo/issues/42
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from server.real_code_workspace import RealCodeWorkspace, RealCodeWorkspaceError


# ──────────────────────────────────────────────────────────────────────
# Issue fetching
# ──────────────────────────────────────────────────────────────────────

def fetch_github_issue(issue_url: str) -> str:
    """
    Use `gh issue view` to fetch the issue title + body.

    Requires `gh` CLI authenticated (`gh auth login`).

    Args:
        issue_url: URL such as https://github.com/owner/repo/issues/42.
            A trailing slash, query string or ``#fragment`` (e.g. a link to
            a specific comment) is ignored.

    Returns:
        Plain-text summary: title, author, labels (if any) and the issue
        body truncated to 3000 characters.

    Raises:
        ValueError: the URL does not look like a GitHub issue URL.
        RuntimeError: `gh` is missing, timed out, or exited non-zero.
    """
    # Parse owner/repo/number from the URL *path* only, so that
    # .../issues/42#issuecomment-1 or .../issues/42?x=y still parse.
    path_parts = urlparse(issue_url.strip()).path.rstrip("/").split("/")
    # e.g. "/owner/repo/issues/42" -> ["", "owner", "repo", "issues", "42"]
    if (
        len(path_parts) < 4
        or path_parts[-2] != "issues"
        or not path_parts[-1].isdecimal()
        or not path_parts[-4]
        or not path_parts[-3]
    ):
        raise ValueError(f"Cannot parse issue URL: {issue_url}")
    owner_repo = f"{path_parts[-4]}/{path_parts[-3]}"
    number = path_parts[-1]

    try:
        result = subprocess.run(
            ["gh", "issue", "view", number, "--repo", owner_repo,
             "--json", "title,body,labels,state,createdAt,author"],
            capture_output=True, text=True, timeout=30,
        )
    except FileNotFoundError as exc:
        # subprocess raises this (not a non-zero exit) when `gh` is absent.
        raise RuntimeError(
            "gh issue view failed: `gh` executable not found.\n"
            "Make sure `gh` is installed and authenticated (`gh auth login`)."
        ) from exc
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(
            f"gh issue view timed out after {exc.timeout}s"
        ) from exc

    if result.returncode != 0:
        raise RuntimeError(
            f"gh issue view failed:\n{result.stderr}\n"
            "Make sure `gh` is installed and authenticated (`gh auth login`)."
        )

    data = json.loads(result.stdout)
    title = data.get("title", "")
    body = (data.get("body") or "").strip()
    labels = ", ".join(
        label.get("name", "") for label in (data.get("labels") or [])
    )
    # `or {}` guards against an explicit null author in the JSON.
    author = (data.get("author") or {}).get("login", "unknown")

    summary = f"Issue: {title}\nAuthor: {author}"
    if labels:
        summary += f"\nLabels: {labels}"
    if body:
        summary += f"\n\n{body[:3000]}"  # cap body at 3000 characters
    return summary


def clone_repo(repo_url: str, target_dir: str) -> str:
    """
    Shallow-clone ``repo_url`` (last 50 commits) into ``target_dir``.

    Returns:
        ``target_dir``.

    Raises:
        subprocess.CalledProcessError: git exited non-zero.
        subprocess.TimeoutExpired: the clone took longer than 120 s.
    """
    print(f"Cloning {repo_url} → {target_dir}")
    subprocess.run(
        # "--" stops git from treating a URL that starts with "-" as an option.
        ["git", "clone", "--depth", "50", "--", repo_url, target_dir],
        check=True,
        timeout=120,
    )
    return target_dir


# ──────────────────────────────────────────────────────────────────────
# Minimal model inference
# ──────────────────────────────────────────────────────────────────────

_SYSTEM_PROMPT = """\
You are an SRE root-cause analyst. You will receive observations from a code \
investigation environment. At each step respond with ONE JSON action only — \
no prose, no markdown fences.

Valid actions:
{"action_type": "list_dir", "parameters": {"path": ""}}
{"action_type": "read_file", "parameters": {"path": ""}}
{"action_type": "search_code", "parameters": {"query": "", "file_pattern": "*.py"}}
{"action_type": "get_git_log", "parameters": {"n_commits": 10, "path": ""}}
{"action_type": "get_file_diff", "parameters": {"commit_sha": "", "path": ""}}
{"action_type": "propose_patch", "parameters": {"diff": ""}}
{"action_type": "declare_root_cause", "parameters": {"root_cause": ""}}
{"action_type": "declare_no_change", "parameters": {}}

Investigation strategy:
1. list_dir(".") to orient yourself.
2. get_git_log to find recent commits.
3. get_file_diff on suspicious commits.
4. read_file / search_code to understand the bug.
5. propose_patch with a minimal unified diff once you have the fix.
declare_no_change only if there is genuinely no code bug.\
"""


def _parse_action(text: str) -> Dict[str, Any]:
    """
    Extract one JSON action object from raw model output.

    Leading markdown fences (``` / ```json) are stripped and the span from
    the first "{" to the last "}" is parsed, so surrounding prose is
    tolerated.  The returned dict always has a dict-valued "parameters" key.
    Output that cannot be parsed into a JSON object falls back to a
    ``declare_no_change`` action.
    """
    fallback: Dict[str, Any] = {"action_type": "declare_no_change", "parameters": {}}

    text = text.strip().lstrip("`")
    if text.startswith("json"):
        text = text[4:].strip()

    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return fallback
    try:
        obj = json.loads(text[start : end + 1])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return fallback
    if not isinstance(obj, dict):
        return fallback

    # Normalise so callers can always do action["parameters"].get(...),
    # even when the model emitted "parameters": null / a list / a string.
    if not isinstance(obj.get("parameters"), dict):
        obj["parameters"] = {}
    return obj


class LocalModelAgent:
    """
    Loads a local / HF model and drives the investigation loop.

    Keeps a chat transcript (system prompt + user/assistant turns): each
    observation becomes a user turn and each raw model reply an assistant
    turn.  The transcript is bounded by ``max_history`` exchanges.
    """

    def __init__(
        self,
        model_id: str,
        load_in_4bit: bool = False,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        max_history: int = 10,
    ) -> None:
        """
        Args:
            model_id: HF hub ID or local path, passed to ``from_pretrained``.
            load_in_4bit: Load weights in 4-bit NF4 via bitsandbytes
                (bf16 compute) instead of plain bf16.
            max_new_tokens: Generation budget per step.
            temperature: Sampling temperature; <= 0 selects greedy decoding.
            max_history: Number of recent exchanges kept in the transcript.
        """
        print(f"Loading model: {model_id}")
        tok_kwargs: Dict[str, Any] = {"use_fast": True}
        self._tok = AutoTokenizer.from_pretrained(model_id, **tok_kwargs)
        if self._tok.pad_token is None:
            self._tok.pad_token = self._tok.eos_token

        model_kwargs: Dict[str, Any] = {"device_map": "auto"}
        if load_in_4bit:
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
        self._model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
        self._model.eval()

        self._max_new = max_new_tokens
        self._temperature = temperature
        self._max_history = max_history
        # Whether the last _format_chat() used the tokenizer's chat template.
        self._templated = False
        self._msgs: List[Dict[str, str]] = []
        # Start with the system prompt so act() works even without reset().
        self.reset()

    def reset(self) -> None:
        """Discard the transcript and start over with only the system prompt."""
        self._msgs = [{"role": "system", "content": _SYSTEM_PROMPT}]

    def act(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """
        Feed one observation to the model and return its parsed action.

        The observation is JSON-serialised (non-JSON values via ``str``),
        truncated to 5000 characters and appended as a user turn; the raw
        model reply is appended as an assistant turn so later steps see it.
        """
        payload = json.dumps(observation, default=str)[:5000]
        self._msgs.append({"role": "user", "content": payload})
        self._trim()
        prompt = self._format_chat()

        # With device_map="auto" the weights may be spread over devices;
        # inputs go to wherever the embedding layer lives.
        try:
            device = self._model.get_input_embeddings().weight.device
        except Exception:
            device = next(self._model.parameters()).device

        # A rendered chat template already contains its BOS/special tokens;
        # letting the tokenizer add them again would duplicate them.
        inputs = self._tok(
            prompt,
            return_tensors="pt",
            add_special_tokens=not self._templated,
        ).to(device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self._max_new,
            "eos_token_id": self._tok.eos_token_id,
            "pad_token_id": self._tok.pad_token_id,
        }
        if self._temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = self._temperature
        else:
            # Greedy: temperature is meaningless (and warned about) here.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            out = self._model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self._tok.decode(
            out[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        ).strip()
        self._msgs.append({"role": "assistant", "content": text})
        return _parse_action(text)

    def _format_chat(self) -> str:
        """
        Render the transcript as one generation prompt.

        Prefers the tokenizer's own chat template; if it has none (or it
        rejects the transcript) falls back to a plain ``<|role|>`` tagged
        format.  Records which path was used in ``self._templated``.
        """
        try:
            prompt = self._tok.apply_chat_template(
                self._msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            # Deliberate best-effort: any template failure -> plain format.
            pass
        else:
            self._templated = True
            return prompt

        self._templated = False
        parts: List[str] = []
        for message in self._msgs:
            role = message["role"]
            tag = role if role in ("system", "user") else "assistant"
            parts.append(f"<|{tag}|>\n{message['content']}\n")
        parts.append("<|assistant|>\n")
        return "".join(parts)

    def _trim(self) -> None:
        """
        Bound the transcript to the system prompt plus recent turns.

        Keeps at most ``2 * max_history`` non-system messages and always at
        least the newest one.  Leading assistant turns are then dropped so the
        kept window starts on a user turn; some chat templates that enforce
        strict user/assistant alternation reject system -> assistant.
        """
        # max(..., 1): with max_history <= 0 the old "[-0:]" slice kept the
        # whole list (duplicating it); now the latest observation is kept.
        keep = max(self._max_history * 2, 1)
        if len(self._msgs) <= 1 + keep:
            return
        recent = self._msgs[-keep:]
        while len(recent) > 1 and recent[0]["role"] != "user":
            recent = recent[1:]
        self._msgs = self._msgs[:1] + recent
# ──────────────────────────────────────────────────────────────────────
# Dispatch: action → RealCodeWorkspace call
# ──────────────────────────────────────────────────────────────────────

def _dispatch(
    action: Dict[str, Any],
    workspace: RealCodeWorkspace,
) -> Dict[str, Any]:
    """
    Execute one tool action against the workspace.

    Args:
        action: Parsed model action ``{"action_type": ..., "parameters": ...}``.
        workspace: The repository the tools operate on.

    Returns:
        The tool result on success; ``{}`` for terminal actions (handled by
        the main loop); otherwise ``{"error": ...}`` when the action type is
        unknown, its parameters are malformed, or the workspace raises
        ``RealCodeWorkspaceError``.  Errors are returned rather than raised
        so the model can see them and recover.
    """
    # Model output is untrusted; coerce to a hashable string.
    atype = str(action.get("action_type") or "")
    if atype in TERMINAL_ACTIONS:
        return {}

    params = action.get("parameters") or {}
    if not isinstance(params, dict):
        return {"error": f"'parameters' must be a JSON object, got {type(params).__name__}"}

    try:
        if atype == "list_dir":
            return workspace.list_dir(params.get("path", "."))
        if atype == "read_file":
            return workspace.read_file(params.get("path", ""))
        if atype == "search_code":
            return workspace.search_code(
                query=params.get("query", ""),
                file_pattern=params.get("file_pattern", "*.py"),
                max_hits=params.get("max_hits"),
            )
        if atype == "get_git_log":
            raw_n = params.get("n_commits", 10)
            try:
                n_commits = int(raw_n)
            except (TypeError, ValueError):
                return {"error": f"'n_commits' must be an integer, got {raw_n!r}"}
            return workspace.get_git_log(
                path=params.get("path", ""),
                n_commits=n_commits,
            )
        if atype == "get_file_diff":
            return workspace.get_file_diff(
                commit_sha=params.get("commit_sha", "HEAD"),
                path=params.get("path", ""),
            )
    except RealCodeWorkspaceError as e:
        return {"error": str(e)}

    # Previously {} — the model got no signal that its action was invalid.
    return {
        "error": f"Unknown action_type {atype!r}; "
                 f"expected one of: {', '.join(VALID_ACTIONS)}"
    }


# ──────────────────────────────────────────────────────────────────────
# Main investigation loop
# ──────────────────────────────────────────────────────────────────────

TERMINAL_ACTIONS = {"propose_patch", "declare_root_cause", "declare_no_change"}

VALID_ACTIONS = [
    "list_dir", "read_file", "search_code", "get_git_log", "get_file_diff",
    "propose_patch", "declare_root_cause", "declare_no_change",
]


def investigate(
    agent: LocalModelAgent,
    workspace: RealCodeWorkspace,
    incident_summary: str,
    max_steps: int = 20,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Drive the agent through the investigation loop.

    Each step the agent sees the current observation and returns one action.
    Tool actions are executed via ``_dispatch`` and their result becomes the
    next observation's ``action_result``.  ``declare_root_cause`` records the
    cause but keeps going (the model may follow up with ``propose_patch``);
    ``propose_patch`` and ``declare_no_change`` end the loop.

    Args:
        agent: Object with ``reset()`` and ``act(observation) -> action``.
        workspace: Repository the tool calls run against.
        incident_summary: Issue text shown to the model on every step.
        max_steps: Maximum number of agent actions; < 1 runs no steps.
        verbose: Print each step and the terminal outcome.

    Returns a result dict with: root_cause, proposed_patch, steps_taken, action_log
    """
    agent.reset()

    # Seed observation — mirrors the format from Phase-2 training
    obs: Dict[str, Any] = {
        "current_phase": 2,
        "incident_summary": incident_summary,
        "bad_commit_sha": workspace.bad_commit_sha or "",
        "valid_actions": VALID_ACTIONS,
        "action_result": {},
        "step": 0,
        "repo_tree": workspace.file_tree(max_depth=2),
    }

    action_log: List[Dict[str, Any]] = []
    root_cause: Optional[str] = None
    proposed_patch: Optional[str] = None
    step = 0  # stays 0 when max_steps < 1 (was an UnboundLocalError)

    for step in range(1, max_steps + 1):
        obs["step"] = step
        action = agent.act(obs)

        # Model output is untrusted: coerce the fields formatted/indexed below.
        atype = str(action.get("action_type") or "")
        params = action.get("parameters")
        if not isinstance(params, dict):
            params = {}

        if verbose:
            params_str = json.dumps(params, default=str)[:120]
            print(f" step {step:2d} {atype:<22s} {params_str}")

        action_log.append({"step": step, "action": action})

        if atype == "declare_root_cause":
            root_cause = params.get("root_cause", "")
            if verbose:
                print(f"\nRoot cause declared:\n {root_cause}")
            # Don't break yet — model may follow up with propose_patch
            obs = {
                **obs,
                "action_result": {"acknowledged": True},
                "valid_actions": ["propose_patch", "declare_no_change"],
            }
            continue

        if atype == "propose_patch":
            proposed_patch = str(params.get("diff") or "")
            if verbose:
                print(f"\nPatch proposed ({len(proposed_patch)} chars)")
            break

        if atype == "declare_no_change":
            if verbose:
                print("\nAgent declared: no code change needed.")
            break

        # Execute the tool call; its result is the next observation.
        obs = {**obs, "action_result": _dispatch(action, workspace)}

    return {
        "root_cause": root_cause,
        "proposed_patch": proposed_patch,
        "steps_taken": step,
        "action_log": action_log,
    }


# ──────────────────────────────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────────────────────────────

def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    p = argparse.ArgumentParser(
        description="Run the trained SRE agent on a real GitHub issue."
    )
    p.add_argument("--model", required=True,
                   help="HF model ID or local path of the GRPO checkpoint")

    # Repo — one of these required
    repo_group = p.add_mutually_exclusive_group(required=True)
    repo_group.add_argument("--repo", help="Path to an already-cloned local repo")
    repo_group.add_argument("--repo-url", help="Git URL to clone (cloned to a temp dir)")

    # Issue — one of these required
    issue_group = p.add_mutually_exclusive_group(required=True)
    issue_group.add_argument("--issue", help="GitHub issue URL (requires gh CLI)")
    issue_group.add_argument("--issue-text", help="Raw issue description text")

    p.add_argument("--bad-commit", default="",
                   help="SHA of the suspected bad commit (optional hint)")
    p.add_argument("--max-steps", type=int, default=20)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.2,
                   help="Sampling temperature; <= 0 means greedy decoding")
    p.add_argument("--load-in-4bit", action="store_true", default=False)
    p.add_argument("--output", default=None, help="Write JSON result to this file")
    p.add_argument("--quiet", action="store_true")
    return p.parse_args()


def main() -> None:
    """CLI entry point: fetch the issue, prepare the repo, run the agent, report."""
    import shutil

    args = _parse_args()
    verbose = not args.quiet

    # ── 1. Get incident summary ──────────────────────────────────────
    # `is not None`: `--issue ""` must not fall through to issue_text=None.
    if args.issue is not None:
        if verbose:
            print(f"Fetching issue: {args.issue}")
        try:
            incident_summary = fetch_github_issue(args.issue)
        except (ValueError, RuntimeError, OSError, subprocess.SubprocessError) as exc:
            sys.exit(f"error: could not fetch issue: {exc}")
    else:
        incident_summary = args.issue_text

    if verbose:
        print("\n── Incident summary ──")
        print(incident_summary[:600])
        print()

    tmpdir: Optional[str] = None
    try:
        # ── 2. Set up repo ───────────────────────────────────────────
        if args.repo_url is not None:
            tmpdir = tempfile.mkdtemp(prefix="sre_agent_")
            try:
                repo_path = clone_repo(args.repo_url, tmpdir)
            except (OSError, subprocess.SubprocessError) as exc:
                sys.exit(f"error: could not clone {args.repo_url}: {exc}")
        else:
            repo_path = args.repo

        try:
            workspace = RealCodeWorkspace(repo_path, bad_commit_sha=args.bad_commit or "")
        except RealCodeWorkspaceError as exc:
            sys.exit(f"error: {exc}")

        # ── 3. Load model ────────────────────────────────────────────
        agent = LocalModelAgent(
            model_id=args.model,
            load_in_4bit=args.load_in_4bit,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )

        # ── 4. Investigate ───────────────────────────────────────────
        if verbose:
            print("── Investigation ──")
        result = investigate(
            agent=agent,
            workspace=workspace,
            incident_summary=incident_summary,
            max_steps=args.max_steps,
            verbose=verbose,
        )
    finally:
        # Remove the temp clone even if cloning, model loading or the
        # investigation failed (sys.exit also passes through here).
        if tmpdir:
            shutil.rmtree(tmpdir, ignore_errors=True)

    # ── 5. Output ────────────────────────────────────────────────────
    print("\n" + "═" * 60)
    print("ROOT CAUSE:")
    print(result["root_cause"] or "(not declared)")
    print("\nPROPOSED PATCH:")
    print(result["proposed_patch"] or "(none)")
    print(f"\n({result['steps_taken']} steps taken)")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, default=str)
        print(f"\nFull result written to {args.output}")


if __name__ == "__main__":
    main()