# Source: amine-yagoub — commit 6a2abaa
# refactor: clean up core modules by removing comment headers and unused code
"""Custom CrewAI tools for CodeTribunal agents."""
import subprocess
from pathlib import Path
from typing import Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
# Shared runtime context used by every tool in this module.
_target_dir: str = ""
_code_graph = None


def configure_tools(target_dir: str, code_graph) -> None:
    """Record the project directory and dependency graph shared by all tools."""
    global _target_dir, _code_graph
    _target_dir, _code_graph = target_dir, code_graph
class FileReaderInput(BaseModel):
    """Argument schema for FileReaderTool, validated by pydantic before _run is called."""
    # Resolved against the module-level _target_dir set by configure_tools().
    filepath: str = Field(description="Path to the source file to read (relative to project root)")
    start_line: int = Field(default=0, description="Start line (0-indexed). Default: 0 (beginning)")
    # -1 is the sentinel for "read to end of file".
    end_line: int = Field(default=-1, description="End line (-1 = entire file)")
class FileReaderTool(BaseTool):
    """Read source code from a specific file with line numbers."""
    name: str = "file_reader"
    description: str = (
        "Read source code from a specific file. Returns the file content with line numbers. "
        "Use this to examine specific files, functions, or code sections that you need to analyze. "
        "You can specify a line range to focus on specific sections."
    )
    args_schema: Type[BaseModel] = FileReaderInput

    def _run(self, filepath: str, start_line: int = 0, end_line: int = -1) -> str:
        """Return lines [start_line, end) of *filepath*, numbered 1-based.

        Failures are reported as "Error: ..." strings rather than raised, so
        the calling agent can read the diagnostic directly.
        """
        full_path = Path(_target_dir) / filepath
        if not full_path.exists():
            return f"Error: File not found: {filepath}"
        try:
            # errors="replace" keeps undecodable bytes from aborting the read.
            lines = full_path.read_text(errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"
        total = len(lines)
        if total == 0:
            # Previously an empty file produced the confusing message
            # "start_line 0 exceeds file length (0 lines)".
            return f"(empty file: {filepath})"
        # Fix: any negative end_line now means "to end of file" — previously
        # only exactly -1 did, and e.g. -2 silently produced empty output.
        end = total if end_line < 0 else min(end_line, total)
        start = max(0, start_line)
        if start >= total:
            return f"Error: start_line {start_line} exceeds file length ({total} lines)"
        if end <= start:
            # Previously an inverted/empty range silently returned "".
            return f"Error: empty range (start_line {start_line}, end_line {end_line})"
        return "\n".join(f"{i + 1:4d} | {lines[i]}" for i in range(start, end))
class PatternSearchInput(BaseModel):
    """Argument schema for PatternSearchTool, validated by pydantic before _run is called."""
    pattern: str = Field(description="GritQL pattern to search for (e.g., 'eval($_)' or 'TODO: $_')")
    # None means no language filter — grit searches every supported language.
    language: str | None = Field(default=None, description="Language filter: python, javascript, etc.")
class PatternSearchTool(BaseTool):
    """Search for code patterns using GritQL syntax."""
    name: str = "pattern_search"
    description: str = (
        "Search for code patterns using GritQL syntax. Use this to find specific code constructs "
        "like function calls, variable assignments, security patterns, or code smells that you "
        "want to investigate further. Examples: 'eval($_)' to find eval usage, "
        "'$PASS = $_' to find password assignments."
    )
    args_schema: Type[BaseModel] = PatternSearchInput

    def _run(self, pattern: str, language: str | None = None) -> str:
        """Run `grit apply --dry-run` over the target dir and return its output.

        Returns an "Error: ..." string (never raises) when the grit CLI is
        missing, times out, or fails.
        """
        # List form (shell=False) — the pattern is never shell-interpreted.
        cmd = ["grit", "apply", "--dry-run", pattern, _target_dir]
        if language:
            cmd += ["--language", language]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        except FileNotFoundError:
            return "Error: grit CLI not found."
        except subprocess.TimeoutExpired:
            return "Error: Pattern search timed out after 30 seconds."
        output = result.stdout.strip()
        # Fix: a failed grit run (nonzero exit, nothing on stdout) used to be
        # misreported as "No matches found" because stderr was ignored.
        if not output and result.returncode != 0:
            return f"Error: grit failed (exit {result.returncode}): {result.stderr.strip()}"
        if not output or "found 0 matches" in output:
            return "No matches found for this pattern."
        return output
class CodeGraphQueryInput(BaseModel):
    """Argument schema for CodeGraphQueryTool, validated by pydantic before _run is called."""
    # Must be one of the five literals dispatched on in CodeGraphQueryTool._run.
    query_type: str = Field(
        description=(
            "Type of query to run. Options: "
            "'trace' — trace function call chain, "
            "'callers' — find who calls a function, "
            "'imports' — list imports for a file, "
            "'summary' — get file summary, "
            "'source' — get function source code"
        )
    )
    # Interpretation depends on query_type: function name for trace/callers/source,
    # file path for imports/summary.
    target: str = Field(description="Function name, file path, or node ID to query")
class CodeGraphQueryTool(BaseTool):
    """Query the code dependency graph to understand code structure."""
    name: str = "code_graph_query"
    description: str = (
        "Query the code dependency graph to understand code structure and relationships. "
        "Can trace function call chains, find callers, list imports, get file summaries, "
        "or retrieve function source code. Use this to understand how vulnerable code connects "
        "to the rest of the application."
    )
    args_schema: Type[BaseModel] = CodeGraphQueryInput

    def _run(self, query_type: str, target: str) -> str:
        """Dispatch *query_type* against the shared code graph and return a text report."""
        if _code_graph is None:
            return "Error: Code graph not built yet."
        if query_type == "trace":
            return _code_graph.trace_calls(target, depth=3)
        if query_type == "callers":
            found = _code_graph.get_callers(target)
            if not found:
                return f"No callers found for '{target}'."
            listing = "\n".join(f"  - {caller}" for caller in found)
            return f"Callers of '{target}':\n{listing}"
        if query_type == "imports":
            found = _code_graph.get_imports(target)
            if not found:
                return f"No imports found in '{target}'."
            listing = "\n".join(f"  - {imp}" for imp in found)
            return f"Imports in '{target}':\n{listing}"
        if query_type == "summary":
            return _code_graph.get_file_summary(target)
        if query_type == "source":
            # First function node whose name matches the target, if any.
            hit = next(
                (
                    node
                    for node in _code_graph.nodes.values()
                    if node.kind == "function" and node.name == target
                ),
                None,
            )
            if hit is not None:
                return _code_graph.get_function_source(hit.file, target)
            return f"Function '{target}' not found in code graph."
        return f"Unknown query type: '{query_type}'. Use: trace, callers, imports, summary, source"
class FindingContextInput(BaseModel):
    """Argument schema for FindingContextTool, validated by pydantic before _run is called."""
    filepath: str = Field(description="File path of the finding")
    # Note: 1-indexed, unlike FileReaderInput.start_line which is 0-indexed.
    line: int = Field(description="Line number of the finding (1-indexed)")
    context_lines: int = Field(default=10, description="Number of context lines before and after")
class FindingContextTool(BaseTool):
    """Get surrounding code context for a specific finding."""
    name: str = "finding_context"
    description: str = (
        "Get surrounding code context for a specific finding. Shows code before and after "
        "the flagged line to help understand the full context of a vulnerability or issue. "
        "Use this to assess the real severity and impact of each finding."
    )
    args_schema: Type[BaseModel] = FindingContextInput

    def _run(self, filepath: str, line: int, context_lines: int = 10) -> str:
        """Return numbered lines around *line* (1-indexed), the flagged line marked '>>>'.

        Failures are reported as "Error: ..." strings rather than raised.
        """
        full_path = Path(_target_dir) / filepath
        if not full_path.exists():
            # Scanner paths don't always match the tree layout; fall back to a
            # filename search. NOTE(review): matches[0] is an arbitrary pick when
            # several files share the name — may show context from the wrong file.
            matches = list(Path(_target_dir).rglob(Path(filepath).name))
            if not matches:
                return f"Error: File not found: {filepath}"
            full_path = matches[0]
        try:
            lines = full_path.read_text(errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"
        total = len(lines)
        # Fix: an out-of-range line used to yield an empty or misleading snippet
        # with no marked line; report it explicitly instead.
        if line < 1 or line > total:
            return f"Error: line {line} is out of range for {filepath} ({total} lines)"
        start = max(0, line - 1 - context_lines)
        end = min(total, line + context_lines)
        result_lines = []
        for i in range(start, end):
            # Fix: pad non-flagged lines to the width of " >>>" so the " | "
            # separators stay vertically aligned.
            marker = " >>>" if i == line - 1 else "    "
            result_lines.append(f"{i + 1:4d}{marker} | {lines[i]}")
        return "\n".join(result_lines)