repomind-api / agent /tools.py
SouravNath's picture
Initial commit
dc71cad
"""
agent/tools.py
───────────────
Tool definitions for the reflection agent.
Tools available to the agent:
read_file(path) — read a file from the workspace
write_patch(diff) — write a unified diff to the workspace
run_tests(test_ids) — run pytest and return structured output
git_diff() — show the current working-tree diff (`git diff`)
list_files(pattern) — list files matching a glob
Each tool returns a structured ToolResult with success/error.
The agent's LLM sees ToolResult.to_prompt_str() in its context.
"""
from __future__ import annotations
import logging
import re
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# ── Tool result ───────────────────────────────────────────────────────────────
@dataclass
class ToolResult:
    """
    Structured outcome of one tool invocation.

    Fields:
        tool_name: name of the tool that produced this result
        success:   True when the tool completed without error
        output:    main textual payload (may be empty)
        error:     error description (empty when there is none)
        metadata:  free-form extra data; not included in the prompt string
    """

    tool_name: str
    success: bool
    output: str
    error: str = ""
    metadata: dict = field(default_factory=dict)

    def to_prompt_str(self) -> str:
        """
        Render this result as text for inclusion in the LLM prompt.

        The first line is a status header; the output (first 3000 chars)
        and the error (first 500 chars) follow on their own lines when
        they are non-empty.
        """
        label = "SUCCESS" if self.success else "ERROR"
        header = f"[TOOL: {self.tool_name} | {label}]"
        # Both sections are clipped to keep the prompt within token budget.
        body = [self.output[:3000]] if self.output else []
        problem = [f"ERROR: {self.error[:500]}"] if self.error else []
        return "\n".join([header, *body, *problem])
# ── Individual tools ──────────────────────────────────────────────────────────
class AgentTools:
    """
    Collection of tools available to the reflection agent.

    All file operations are scoped to workspace_dir (sandbox root).
    Each tool method reports its outcome as a ToolResult.
    """

    def __init__(self, workspace_dir: Path, sandbox=None):
        """
        Args:
            workspace_dir: root directory the tools operate in
            sandbox: optional executor used by run_tests; when falsy,
                tests run in a local subprocess instead
        """
        self.workspace_dir = Path(workspace_dir)
        self.sandbox = sandbox  # SandboxExecutor instance (optional)

    def read_file(self, path: str, max_lines: int = 200) -> ToolResult:
        """
        Read the contents of a file relative to workspace_dir.

        Args:
            path: relative file path within the workspace
            max_lines: truncate to this many lines (token budget control);
                negative values are treated as 0

        Returns:
            ToolResult whose output is the (possibly truncated) text and
            whose metadata holds "total_lines" and "truncated".
        """
        full_path = self.workspace_dir / path
        # Prevent path traversal: the resolved target must stay inside the
        # resolved workspace root (this also rejects absolute paths).
        try:
            full_path.resolve().relative_to(self.workspace_dir.resolve())
        except ValueError:
            return ToolResult("read_file", False, "", f"Path traversal rejected: {path}")
        if not full_path.exists():
            return ToolResult("read_file", False, "", f"File not found: {path}")
        if not full_path.is_file():
            # Directories etc. would otherwise surface as a raw OS error.
            return ToolResult("read_file", False, "", f"Not a file: {path}")

        # A negative limit would slice from the end and corrupt the
        # "more lines truncated" count, so clamp it.
        max_lines = max(max_lines, 0)
        try:
            # Explicit UTF-8 so decoding does not depend on the host locale;
            # undecodable bytes are replaced rather than raising.
            content = full_path.read_text(encoding="utf-8", errors="replace")
        except Exception as e:
            return ToolResult("read_file", False, "", str(e))

        lines = content.splitlines()
        truncated = len(lines) > max_lines
        visible = "\n".join(lines[:max_lines])
        if truncated:
            visible += f"\n... [{len(lines) - max_lines} more lines truncated]"
        return ToolResult(
            "read_file", True, visible,
            metadata={"total_lines": len(lines), "truncated": truncated},
        )

    def write_patch(self, diff_text: str) -> ToolResult:
        """
        Write a unified diff to a staging file for git apply.

        Does NOT apply the patch — call the sandbox apply_patch() separately.

        Args:
            diff_text: unified diff text (git format)

        Returns:
            ToolResult whose metadata holds "patch_path" on success.
        """
        stripped = diff_text.strip()
        if not stripped:
            return ToolResult("write_patch", False, "", "Empty patch text")
        # Basic validation: must start with --- or diff --git
        if not stripped.startswith(("---", "diff --git")):
            return ToolResult(
                "write_patch", False, "",
                "Patch must start with '---' or 'diff --git'",
            )
        # A diff whose last line is not newline-terminated is incomplete;
        # terminate it so the staged file is a well-formed patch.
        if not diff_text.endswith("\n"):
            diff_text += "\n"

        patch_file = self.workspace_dir / "_agent_patch.diff"
        try:
            patch_file.write_text(diff_text, encoding="utf-8")
        except Exception as e:
            return ToolResult("write_patch", False, "", str(e))
        return ToolResult(
            "write_patch", True,
            f"Patch written to {patch_file.name} ({len(diff_text)} chars)",
            metadata={"patch_path": str(patch_file)},
        )

    def run_tests(self, test_ids: list[str], timeout: int = 60) -> ToolResult:
        """
        Run pytest on specific test IDs.

        Returns structured output including PASSED/FAILED counts and
        failure context (for reflection), with up to 5000 chars of the
        raw output in metadata["full_output"].

        Args:
            test_ids: pytest node IDs / paths to run; must be non-empty
            timeout: seconds allowed for the local subprocess fallback
                (not passed to the sandbox executor)
        """
        if not test_ids:
            return ToolResult("run_tests", False, "", "No test IDs provided")

        if self.sandbox:
            test_result = self.sandbox.run_tests(self.workspace_dir, test_ids)
            output = test_result.raw_output
            success = test_result.all_passed
        else:
            # Local subprocess fallback. Use the interpreter running this
            # process so pytest comes from the same environment; a bare
            # "python" may be missing from PATH or be a different install.
            cmd = [
                sys.executable or "python",
                "-m", "pytest", "-v", "--tb=short", "--no-header", "-rN",
                *test_ids,
            ]
            try:
                proc = subprocess.run(
                    cmd, capture_output=True, text=True, errors="replace",
                    timeout=timeout, cwd=str(self.workspace_dir),
                )
            except subprocess.TimeoutExpired:
                return ToolResult("run_tests", False, "", f"Tests timed out after {timeout}s")
            except Exception as e:
                return ToolResult("run_tests", False, "", str(e))
            output = proc.stdout + proc.stderr
            success = proc.returncode == 0

        # Extract key info for the agent
        summary = _extract_test_summary(output)
        return ToolResult(
            "run_tests", success,
            summary,
            metadata={"full_output": output[:5000]},
        )

    def git_diff(self) -> ToolResult:
        """
        Show the output of `git diff` run in the workspace.

        This lets the agent review what it has changed. Plain `git diff`
        compares the working tree with the index, so staged changes and
        untracked files are not shown. Output is capped at 3000 chars.
        """
        try:
            result = subprocess.run(
                ["git", "diff"], capture_output=True, text=True, errors="replace",
                cwd=str(self.workspace_dir), timeout=10,
            )
        except Exception as e:
            return ToolResult("git_diff", False, "", str(e))

        # A failing git (e.g. not a repository) produces no stdout; do not
        # misreport that as a clean tree.
        if result.returncode != 0:
            reason = result.stderr.strip() or f"git diff exited with status {result.returncode}"
            return ToolResult("git_diff", False, "", reason)

        diff = result.stdout or "(no changes)"
        return ToolResult("git_diff", True, diff[:3000])

    def list_files(self, pattern: str = "**/*.py", max_results: int = 50) -> ToolResult:
        """
        List files in the workspace matching a glob pattern.

        Entries under a `__pycache__` or `.git` directory are skipped, as
        are entries that resolve outside the workspace (e.g. via `..` in
        the pattern).

        Args:
            pattern: glob pattern relative to workspace_dir
            max_results: maximum number of paths returned

        Returns:
            ToolResult whose output is newline-separated relative paths
            and whose metadata holds "count".
        """
        # Match whole path components of the workspace-relative path, so
        # names like `.gitignore` / `.github` and a workspace directory
        # whose own path contains ".git" are not filtered out by accident.
        excluded_parts = {"__pycache__", ".git"}
        try:
            root = self.workspace_dir.resolve()
            rel_files: list[str] = []
            for candidate in sorted(self.workspace_dir.glob(pattern)):
                if len(rel_files) >= max_results:
                    break
                rel = candidate.relative_to(self.workspace_dir)
                if not excluded_parts.isdisjoint(rel.parts):
                    continue
                # Keep the listing inside the sandbox root.
                try:
                    candidate.resolve().relative_to(root)
                except ValueError:
                    continue
                rel_files.append(str(rel))
        except Exception as e:
            return ToolResult("list_files", False, "", str(e))

        output = "\n".join(rel_files) or "(no files found)"
        return ToolResult("list_files", True, output,
                          metadata={"count": len(rel_files)})
# ── Helpers ───────────────────────────────────────────────────────────────────
def _extract_test_summary(pytest_output: str) -> str:
"""
Extract a concise test summary from raw pytest output.
Includes: pass/fail counts + first failure traceback.
"""
lines = pytest_output.splitlines()
summary_lines = []
in_failure_section = False
failure_lines: list[str] = []
for line in lines:
# Capture summary line
if re.search(r"\d+ (passed|failed|error)", line):
summary_lines.append(line)
# Capture short failure tracebacks
if line.startswith("FAILED") or "AssertionError" in line or "Error" in line:
failure_lines.append(line)
# Short traceback block
if line.startswith("_ " * 3) or "FAILURES" in line:
in_failure_section = True
if in_failure_section:
failure_lines.append(line)
if len(failure_lines) > 40: # cap failure context
break
parts = summary_lines + ["---"] + failure_lines[:40] if failure_lines else summary_lines
return "\n".join(parts) or pytest_output[:1000]