Spaces:
Sleeping
Sleeping
| """ | |
| inference_agent.py — real-world SRE incident investigation. | |
| Takes a GitHub issue (URL or raw text) + a local git repo, then drives the | |
| trained GRPO model through the same Phase-2 action loop it was trained on | |
| (list_dir → read_file → search_code → get_git_log → get_file_diff → | |
| propose_patch / declare_root_cause). | |
| The model sees an observation dict in exactly the format it was trained on; | |
| the only change is that tool calls hit a real repository instead of a snapshot. | |
| Usage: | |
| # From a GitHub issue URL (requires `gh` CLI authenticated) | |
| python inference_agent.py \\ | |
| --model srinjoyd/qwen2.5-7b-sre-grpo \\ | |
| --repo /path/to/cloned/repo \\ | |
| --issue https://github.com/owner/repo/issues/42 | |
| # From raw issue text | |
| python inference_agent.py \\ | |
| --model srinjoyd/qwen2.5-7b-sre-grpo \\ | |
| --repo /path/to/cloned/repo \\ | |
| --issue-text "OrderService crashes with OOM after deploy abc1234" | |
| # Clone the repo automatically | |
| python inference_agent.py \\ | |
| --model srinjoyd/qwen2.5-7b-sre-grpo \\ | |
| --repo-url https://github.com/owner/repo \\ | |
| --issue https://github.com/owner/repo/issues/42 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from server.real_code_workspace import RealCodeWorkspace, RealCodeWorkspaceError | |
| # ────────────────────────────────────────────────────────────────────── | |
| # Issue fetching | |
| # ────────────────────────────────────────────────────────────────────── | |
def fetch_github_issue(issue_url: str) -> str:
    """
    Fetch a GitHub issue with `gh issue view` and render it as plain text.

    Requires the `gh` CLI to be installed and authenticated (`gh auth login`).

    Args:
        issue_url: URL of the form
            ``https://github.com/<owner>/<repo>/issues/<number>``.
            A trailing slash, query string or ``#fragment`` is ignored.

    Returns:
        A text summary: title and author, labels when present, and the issue
        body capped at 3000 characters.

    Raises:
        ValueError: if ``issue_url`` is not a recognisable issue URL.
        RuntimeError: if `gh` is missing, times out, exits non-zero, or
            prints something that is not a JSON object.
    """
    # Function-scope import so the module's import block stays untouched.
    from urllib.parse import urlsplit

    # Only the URL path matters: /owner/repo/issues/42. Using urlsplit drops
    # "?query" and "#issuecomment-…" suffixes that would otherwise end up in
    # the issue number.
    segments = [seg for seg in urlsplit(issue_url.strip()).path.split("/") if seg]
    if (
        len(segments) < 4
        or segments[-2] != "issues"
        or not (segments[-1].isascii() and segments[-1].isdigit())
    ):
        raise ValueError(f"Cannot parse issue URL: {issue_url}")
    owner_repo = f"{segments[-4]}/{segments[-3]}"
    number = segments[-1]

    install_hint = "Make sure `gh` is installed and authenticated (`gh auth login`)."
    try:
        result = subprocess.run(
            ["gh", "issue", "view", number, "--repo", owner_repo,
             "--json", "title,body,labels,state,createdAt,author"],
            capture_output=True, text=True, timeout=30,
        )
    except FileNotFoundError as exc:
        # Without this the user gets a bare FileNotFoundError and never sees
        # the installation hint.
        raise RuntimeError(f"`gh` CLI not found.\n{install_hint}") from exc
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError("gh issue view timed out after 30 seconds.") from exc

    if result.returncode != 0:
        raise RuntimeError(
            f"gh issue view failed:\n{result.stderr}\n{install_hint}"
        )

    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError("gh issue view returned invalid JSON.") from exc
    if not isinstance(data, dict):
        raise RuntimeError("gh issue view returned an unexpected JSON payload.")

    # Every field may be missing or null, so normalise with `or`.
    title = data.get("title") or ""
    body = (data.get("body") or "").strip()
    label_names = [
        label.get("name", "")
        for label in (data.get("labels") or [])
        if isinstance(label, dict)
    ]
    labels = ", ".join(name for name in label_names if name)
    author = (data.get("author") or {}).get("login") or "unknown"

    summary = f"Issue: {title}\nAuthor: {author}"
    if labels:
        summary += f"\nLabels: {labels}"
    if body:
        summary += f"\n\n{body[:3000]}"  # cap body at 3000 characters
    return summary
def clone_repo(repo_url: str, target_dir: str) -> str:
    """
    Shallow-clone ``repo_url`` into ``target_dir`` with `git clone`.

    Args:
        repo_url: Git URL (or local path) to clone.
        target_dir: Destination directory handed to `git clone`.

    Returns:
        ``target_dir``, unchanged.

    Raises:
        subprocess.CalledProcessError: if `git clone` exits non-zero.
        subprocess.TimeoutExpired: if the clone takes longer than 120 seconds.
        FileNotFoundError: if the `git` executable is not on PATH.
    """
    print(f"Cloning {repo_url} → {target_dir}")
    subprocess.run(
        # "--" ends option parsing, so a URL starting with "-" cannot be
        # interpreted by git as a command-line option.
        # NOTE: only the 50 most recent commits are fetched; anything older
        # is not present in the clone.
        ["git", "clone", "--depth", "50", "--", repo_url, target_dir],
        check=True, timeout=120,
    )
    return target_dir
| # ────────────────────────────────────────────────────────────────────── | |
| # Minimal model inference | |
| # ────────────────────────────────────────────────────────────────────── | |
| _SYSTEM_PROMPT = """\ | |
| You are an SRE root-cause analyst. You will receive observations from a code \ | |
| investigation environment. At each step respond with ONE JSON action only β \ | |
| no prose, no markdown fences. | |
| Valid actions: | |
| {"action_type": "list_dir", "parameters": {"path": "<rel_path>"}} | |
| {"action_type": "read_file", "parameters": {"path": "<rel_path>"}} | |
| {"action_type": "search_code", "parameters": {"query": "<text>", "file_pattern": "*.py"}} | |
| {"action_type": "get_git_log", "parameters": {"n_commits": 10, "path": ""}} | |
| {"action_type": "get_file_diff", "parameters": {"commit_sha": "<sha>", "path": ""}} | |
| {"action_type": "propose_patch", "parameters": {"diff": "<unified_diff>"}} | |
| {"action_type": "declare_root_cause", "parameters": {"root_cause": "<explanation>"}} | |
| {"action_type": "declare_no_change", "parameters": {}} | |
| Investigation strategy: | |
| 1. list_dir(".") to orient yourself. | |
| 2. get_git_log to find recent commits. | |
| 3. get_file_diff on suspicious commits. | |
| 4. read_file / search_code to understand the bug. | |
| 5. propose_patch with a minimal unified diff once you have the fix. | |
| declare_no_change only if there is genuinely no code bug.\ | |
| """ | |
| def _parse_action(text: str) -> Dict[str, Any]: | |
| text = text.strip().lstrip("`") | |
| if text.startswith("json"): | |
| text = text[4:].strip() | |
| a, b = text.find("{"), text.rfind("}") | |
| if a == -1 or b <= a: | |
| return {"action_type": "declare_no_change", "parameters": {}} | |
| try: | |
| obj = json.loads(text[a : b + 1]) | |
| obj.setdefault("parameters", {}) | |
| return obj if isinstance(obj, dict) else {"action_type": "declare_no_change", "parameters": {}} | |
| except Exception: | |
| return {"action_type": "declare_no_change", "parameters": {}} | |
class LocalModelAgent:
    """
    Loads a local / HF causal-LM checkpoint and drives the investigation loop.

    The agent keeps a chat history (system prompt plus alternating user
    observations and assistant replies) and turns each observation into one
    parsed action dict.
    """

    def __init__(
        self,
        model_id: str,
        load_in_4bit: bool = False,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        max_history: int = 10,
    ) -> None:
        """
        Load the tokenizer and model.

        Args:
            model_id: HF model ID or local checkpoint path.
            load_in_4bit: Load with a 4-bit NF4 BitsAndBytes config instead
                of plain bfloat16 weights.
            max_new_tokens: Generation budget per step.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            max_history: Number of user/assistant exchanges kept in the prompt.
        """
        print(f"Loading model: {model_id}")
        self._tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if self._tok.pad_token is None:
            self._tok.pad_token = self._tok.eos_token

        model_kwargs: Dict[str, Any] = {"device_map": "auto"}
        if load_in_4bit:
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
        self._model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
        self._model.eval()

        self._max_new = max_new_tokens
        self._temperature = temperature
        self._max_history = max_history
        self._msgs: List[Dict[str, str]] = []

    def reset(self) -> None:
        """Start a fresh conversation containing only the system prompt."""
        self._msgs = [{"role": "system", "content": _SYSTEM_PROMPT}]

    def act(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """
        Feed one observation to the model and return its parsed action.

        Args:
            observation: JSON-serialisable observation dict (non-serialisable
                values are stringified).

        Returns:
            The action dict produced by ``_parse_action`` from the model's reply.
        """
        if not self._msgs:
            # act() before reset() would otherwise run without a system prompt.
            self.reset()

        # NOTE(review): slicing the serialised JSON at 5000 characters can cut
        # it mid-structure; kept as-is so the prompt format does not change.
        payload = json.dumps(observation, default=str)[:5000]
        self._msgs.append({"role": "user", "content": payload})
        self._trim()
        prompt = self._format_chat()

        # With device_map="auto" the model may be sharded across devices;
        # inputs go to wherever the input embeddings live.
        try:
            device = self._model.get_input_embeddings().weight.device
        except Exception:
            device = next(self._model.parameters()).device
        inputs = self._tok(prompt, return_tensors="pt").to(device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self._max_new,
            "eos_token_id": self._tok.eos_token_id,
            # Explicit pad id so generate() does not have to guess one.
            "pad_token_id": self._tok.pad_token_id,
        }
        if self._temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = self._temperature
        else:
            # Greedy decoding: temperature is a sampling-only setting, so a
            # non-positive value is not forwarded to generate().
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            out = self._model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self._tok.decode(
            out[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        ).strip()
        self._msgs.append({"role": "assistant", "content": text})
        return _parse_action(text)

    def _format_chat(self) -> str:
        """Render the history with the tokenizer's chat template, or a plain-text fallback."""
        # Best effort: any failure of the tokenizer's own template (for
        # example, none is defined) falls through to the generic format.
        try:
            return self._tok.apply_chat_template(
                self._msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass

        tags = {"system": "<|system|>", "user": "<|user|>"}
        parts: List[str] = [
            f"{tags.get(m['role'], '<|assistant|>')}\n{m['content']}\n"
            for m in self._msgs
        ]
        parts.append("<|assistant|>\n")
        return "".join(parts)

    def _trim(self) -> None:
        """
        Bound the history to the first message plus the most recent turns.

        At most ``max_history * 2`` trailing messages are kept (never fewer
        than one, so the current observation survives ``max_history=0``).
        Leading non-user messages in the kept tail are dropped so the
        conversation after the system prompt always starts with a user turn.
        """
        keep = max(self._max_history * 2, 1)
        head, tail = self._msgs[:1], self._msgs[1:]
        if len(tail) <= keep:
            return
        tail = tail[-keep:]
        while tail and tail[0]["role"] != "user":
            tail = tail[1:]
        self._msgs = head + tail
| # ────────────────────────────────────────────────────────────────────── | |
| # Dispatch: action → RealCodeWorkspace call | |
| # ────────────────────────────────────────────────────────────────────── | |
| def _dispatch( | |
| action: Dict[str, Any], | |
| workspace: RealCodeWorkspace, | |
| ) -> Dict[str, Any]: | |
| atype = action.get("action_type", "") | |
| params = action.get("parameters", {}) or {} | |
| try: | |
| if atype == "list_dir": | |
| return workspace.list_dir(params.get("path", ".")) | |
| if atype == "read_file": | |
| return workspace.read_file(params.get("path", "")) | |
| if atype == "search_code": | |
| return workspace.search_code( | |
| query = params.get("query", ""), | |
| file_pattern = params.get("file_pattern", "*.py"), | |
| max_hits = params.get("max_hits"), | |
| ) | |
| if atype == "get_git_log": | |
| return workspace.get_git_log( | |
| path = params.get("path", ""), | |
| n_commits = int(params.get("n_commits", 10)), | |
| ) | |
| if atype == "get_file_diff": | |
| return workspace.get_file_diff( | |
| commit_sha = params.get("commit_sha", "HEAD"), | |
| path = params.get("path", ""), | |
| ) | |
| # Terminal actions β handled in the main loop | |
| return {} | |
| except RealCodeWorkspaceError as e: | |
| return {"error": str(e)} | |
| # ────────────────────────────────────────────────────────────────────── | |
| # Main investigation loop | |
| # ────────────────────────────────────────────────────────────────────── | |
# Actions that record a conclusion instead of calling a workspace tool.
# A frozenset keeps this module-level constant immutable; it still compares
# equal to a plain set and supports `in`.
TERMINAL_ACTIONS = frozenset(
    {"propose_patch", "declare_root_cause", "declare_no_change"}
)

# Every action name advertised to the model in the seed observation's
# "valid_actions" field: tool calls first, then the terminal actions.
VALID_ACTIONS = [
    "list_dir", "read_file", "search_code",
    "get_git_log", "get_file_diff",
    "propose_patch", "declare_root_cause", "declare_no_change",
]
def investigate(
    agent: LocalModelAgent,
    workspace: RealCodeWorkspace,
    incident_summary: str,
    max_steps: int = 20,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Drive the agent through the investigation loop.

    Each step sends the current observation to ``agent.act``, logs the
    returned action, and either records a conclusion or runs the tool call
    through ``_dispatch`` and feeds the result back as ``action_result``.
    The loop ends on ``propose_patch`` or ``declare_no_change``, or after
    ``max_steps`` steps. ``declare_root_cause`` is recorded but does not end
    the loop, so a patch can follow.

    Args:
        agent: Agent exposing ``reset()`` and ``act(observation)``.
        workspace: Repository wrapper used for the seed observation
            (``bad_commit_sha``, ``file_tree``) and for tool calls.
        incident_summary: Text describing the incident.
        max_steps: Maximum number of agent steps; ``<= 0`` runs none.
        verbose: Print one line per step plus conclusions.

    Returns:
        Dict with ``root_cause``, ``proposed_patch`` (each ``None`` when not
        produced), ``steps_taken`` and ``action_log``.
    """
    agent.reset()

    # Seed observation — mirrors the format from Phase-2 training.
    obs: Dict[str, Any] = {
        "current_phase": 2,
        "incident_summary": incident_summary,
        "bad_commit_sha": workspace.bad_commit_sha or "",
        "valid_actions": VALID_ACTIONS,
        "action_result": {},
        "step": 0,
        "repo_tree": workspace.file_tree(max_depth=2),
    }
    action_log: List[Dict[str, Any]] = []
    root_cause: Optional[str] = None
    proposed_patch: Optional[str] = None
    # Tracked separately from the loop variable so the return value is
    # defined even when the loop body never runs (max_steps <= 0).
    steps_taken = 0

    for step in range(1, max_steps + 1):
        steps_taken = step
        obs["step"] = step
        action = agent.act(obs)

        # Model output is untrusted: normalise before formatting or indexing.
        atype = action.get("action_type", "")
        if not isinstance(atype, str):
            atype = str(atype)
        params = action.get("parameters")
        if not isinstance(params, dict):
            params = {}

        if verbose:
            params_str = json.dumps(params, default=str)[:120]
            print(f" step {step:2d} {atype:<22s} {params_str}")
        action_log.append({"step": step, "action": action})

        if atype == "declare_root_cause":
            root_cause = params.get("root_cause", "")
            if verbose:
                print(f"\nRoot cause declared:\n {root_cause}")
            # Don't break yet — model may follow up with propose_patch.
            obs = {**obs, "action_result": {"acknowledged": True},
                   "valid_actions": ["propose_patch", "declare_no_change"]}
            continue

        if atype == "propose_patch":
            diff = params.get("diff") or ""
            proposed_patch = diff if isinstance(diff, str) else str(diff)
            if verbose:
                print(f"\nPatch proposed ({len(proposed_patch)} chars)")
            break

        if atype == "declare_no_change":
            if verbose:
                print("\nAgent declared: no code change needed.")
            break

        if atype in VALID_ACTIONS:
            # Execute the tool call.
            result = _dispatch(action, workspace)
        else:
            # Tell the model its action was not understood instead of
            # returning an empty result it cannot learn from.
            result = {"error": f"Unknown action_type: {atype!r}"}
        obs = {**obs, "action_result": result}
    else:
        # The loop ran out of steps without propose_patch / declare_no_change.
        if verbose:
            print(f"\nStopped after {steps_taken} steps without a terminal action.")

    return {
        "root_cause": root_cause,
        "proposed_patch": proposed_patch,
        "steps_taken": steps_taken,
        "action_log": action_log,
    }
| # ────────────────────────────────────────────────────────────────────── | |
| # CLI | |
| # ────────────────────────────────────────────────────────────────────── | |
| def _parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser( | |
| description="Run the trained SRE agent on a real GitHub issue." | |
| ) | |
| p.add_argument("--model", required=True, | |
| help="HF model ID or local path of the GRPO checkpoint") | |
| # Repo β one of these required | |
| g = p.add_mutually_exclusive_group(required=True) | |
| g.add_argument("--repo", help="Path to an already-cloned local repo") | |
| g.add_argument("--repo-url", help="Git URL to clone (cloned to a temp dir)") | |
| # Issue β one of these required | |
| ig = p.add_mutually_exclusive_group(required=True) | |
| ig.add_argument("--issue", help="GitHub issue URL (requires gh CLI)") | |
| ig.add_argument("--issue-text", help="Raw issue description text") | |
| p.add_argument("--bad-commit", default="", | |
| help="SHA of the suspected bad commit (optional hint)") | |
| p.add_argument("--max-steps", type=int, default=20) | |
| p.add_argument("--max-new-tokens",type=int, default=512) | |
| p.add_argument("--load-in-4bit", action="store_true", default=False) | |
| p.add_argument("--output", default=None, | |
| help="Write JSON result to this file") | |
| p.add_argument("--quiet", action="store_true") | |
| return p.parse_args() | |
def main() -> None:
    """
    CLI entry point: fetch the issue, prepare the repo, load the model, run
    the investigation and report the result.

    When ``--repo-url`` is used the repository is cloned into a temporary
    directory, which is removed again even if a later step fails.
    """
    import shutil

    args = _parse_args()
    verbose = not args.quiet

    # ── 1. Get incident summary ──────────────────────────────────────
    if args.issue:
        if verbose:
            print(f"Fetching issue: {args.issue}")
        incident_summary = fetch_github_issue(args.issue)
    else:
        incident_summary = args.issue_text
    if verbose:
        print("\n── Incident summary ──")
        print(incident_summary[:600])
        print()

    _tmpdir = None
    try:
        # ── 2. Set up repo ───────────────────────────────────────────
        if args.repo_url:
            _tmpdir = tempfile.mkdtemp(prefix="sre_agent_")
            repo_path = clone_repo(args.repo_url, _tmpdir)
        else:
            repo_path = args.repo
        workspace = RealCodeWorkspace(repo_path, bad_commit_sha=args.bad_commit or "")

        # ── 3. Load model ────────────────────────────────────────────
        agent = LocalModelAgent(
            model_id=args.model,
            load_in_4bit=args.load_in_4bit,
            max_new_tokens=args.max_new_tokens,
        )

        # ── 4. Investigate ───────────────────────────────────────────
        if verbose:
            print("── Investigation ──")
        result = investigate(
            agent=agent,
            workspace=workspace,
            incident_summary=incident_summary,
            max_steps=args.max_steps,
            verbose=verbose,
        )

        # ── 5. Output ────────────────────────────────────────────────
        print("\n" + "─" * 60)
        print("ROOT CAUSE:")
        print(result["root_cause"] or "(not declared)")
        print("\nPROPOSED PATCH:")
        print(result["proposed_patch"] or "(none)")
        print(f"\n({result['steps_taken']} steps taken)")
        if args.output:
            # Explicit encoding: the result may contain non-ASCII text and
            # the platform default is not guaranteed to handle it.
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, default=str)
            print(f"\nFull result written to {args.output}")
    finally:
        # Cleanup temp clone — in `finally` so a failed clone, model load or
        # investigation does not leave the directory behind.
        if _tmpdir:
            shutil.rmtree(_tmpdir, ignore_errors=True)


if __name__ == "__main__":
    main()