"""Custom CrewAI tools for CodeTribunal agents."""
import subprocess
from pathlib import Path
from typing import Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
# Shared runtime context for every tool in this module; set by configure_tools().
# Directory that tool file paths are joined onto and that pattern searches run in.
_target_dir: str = ""
# Code graph object queried by CodeGraphQueryTool; stays None until configure_tools() runs.
_code_graph = None
def configure_tools(target_dir: str, code_graph) -> None:
    """Install the runtime context shared by every tool in this module.

    Args:
        target_dir: Directory the tools join file paths onto and search in.
        code_graph: Code graph object served by the ``code_graph_query`` tool.
    """
    global _target_dir, _code_graph
    _target_dir, _code_graph = target_dir, code_graph
class FileReaderInput(BaseModel):
    # Argument schema for the ``file_reader`` tool.

    # File to read, joined onto the configured target directory.
    filepath: str = Field(
        description="Path to the source file to read (relative to project root)",
    )
    # First line to return, counted from zero.
    start_line: int = Field(
        default=0,
        description="Start line (0-indexed). Default: 0 (beginning)",
    )
    # Exclusive end of the range; -1 selects everything up to the end of the file.
    end_line: int = Field(
        default=-1,
        description="End line (-1 = entire file)",
    )
class FileReaderTool(BaseTool):
    """Read source code from a specific file with line numbers."""

    name: str = "file_reader"
    description: str = (
        "Read source code from a specific file. Returns the file content with line numbers. "
        "Use this to examine specific files, functions, or code sections that you need to analyze. "
        "You can specify a line range to focus on specific sections."
    )
    args_schema: Type[BaseModel] = FileReaderInput

    def _run(self, filepath: str, start_line: int = 0, end_line: int = -1) -> str:
        """Return a numbered listing of ``filepath`` for the requested line range.

        Args:
            filepath: Path joined onto the configured target directory.
            start_line: Zero-based index of the first line to return; negative
                values are clamped to 0.
            end_line: Zero-based exclusive end of the range. Any negative value
                means "through the end of the file"; values past the end are
                clamped to the file length.

        Returns:
            The selected lines, each prefixed with its 1-based line number, or
            an ``Error: ...`` message when the file is missing, unreadable, or
            ``start_line`` lies beyond the last line.
        """
        full_path = Path(_target_dir) / filepath
        if not full_path.exists():
            return f"Error: File not found: {filepath}"

        try:
            # Decode as UTF-8 explicitly so the result does not depend on the
            # host locale; undecodable bytes are replaced rather than raising.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        # Treat every negative end_line as "to the end"; previously only -1 did,
        # and other negatives silently produced an empty listing.
        end = total if end_line < 0 else min(end_line, total)
        start = max(0, start_line)
        if start >= total:
            return f"Error: start_line {start_line} exceeds file length ({total} lines)"

        return "\n".join(
            f"{number:4d} | {text}"
            for number, text in enumerate(lines[start:end], start=start + 1)
        )
class PatternSearchInput(BaseModel):
    # Argument schema for the ``pattern_search`` tool.

    # Pattern text handed to the grit CLI.
    pattern: str = Field(
        description="GritQL pattern to search for (e.g., 'eval($_)' or 'TODO: $_')",
    )
    # Optional language passed to grit via ``--language``; omitted when None or empty.
    language: str | None = Field(
        default=None,
        description="Language filter: python, javascript, etc.",
    )
class PatternSearchTool(BaseTool):
    """Search for code patterns using GritQL syntax."""

    name: str = "pattern_search"
    description: str = (
        "Search for code patterns using GritQL syntax. Use this to find specific code constructs "
        "like function calls, variable assignments, security patterns, or code smells that you "
        "want to investigate further. Examples: 'eval($_)' to find eval usage, "
        "'$PASS = $_' to find password assignments."
    )
    args_schema: Type[BaseModel] = PatternSearchInput

    def _run(self, pattern: str, language: str | None = None) -> str:
        """Run ``grit apply --dry-run`` for ``pattern`` over the target directory.

        Args:
            pattern: Pattern passed to grit as a single argument (no shell is
                involved, so it is not subject to shell interpretation).
            language: Optional value for grit's ``--language`` flag.

        Returns:
            grit's stdout when it reports matches, a "no matches" message when
            it reports none, or an ``Error: ...`` message when grit is missing,
            times out, or fails without producing any output.
        """
        cmd = ["grit", "apply", "--dry-run", pattern, _target_dir]
        if language:
            cmd += ["--language", language]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                # Decode deterministically and never raise on odd bytes in
                # matched source text.
                encoding="utf-8",
                errors="replace",
                timeout=30,
            )
        except FileNotFoundError:
            return "Error: grit CLI not found."
        except subprocess.TimeoutExpired:
            return "Error: Pattern search timed out after 30 seconds."

        output = result.stdout.strip()
        if output and "found 0 matches" not in output:
            return output

        # A failed run (e.g. an invalid pattern) leaves stdout empty; report the
        # failure instead of claiming there were no matches.
        if not output and result.returncode != 0:
            detail = (result.stderr or "").strip()
            if detail and "found 0 matches" not in detail:
                return f"Error: grit exited with status {result.returncode}: {detail}"

        return "No matches found for this pattern."
class CodeGraphQueryInput(BaseModel):
    # Argument schema for the ``code_graph_query`` tool.

    # Selects which graph lookup to perform.
    query_type: str = Field(
        description=(
            "Type of query to run. Options: "
            "'trace' — trace function call chain, "
            "'callers' — find who calls a function, "
            "'imports' — list imports for a file, "
            "'summary' — get file summary, "
            "'source' — get function source code"
        ),
    )
    # Subject of the lookup; its meaning depends on ``query_type``.
    target: str = Field(
        description="Function name, file path, or node ID to query",
    )
class CodeGraphQueryTool(BaseTool):
    """Query the code dependency graph to understand code structure."""

    name: str = "code_graph_query"
    description: str = (
        "Query the code dependency graph to understand code structure and relationships. "
        "Can trace function call chains, find callers, list imports, get file summaries, "
        "or retrieve function source code. Use this to understand how vulnerable code connects "
        "to the rest of the application."
    )
    args_schema: Type[BaseModel] = CodeGraphQueryInput

    def _run(self, query_type: str, target: str) -> str:
        """Dispatch ``query_type`` to the matching lookup on the shared code graph.

        Args:
            query_type: One of ``trace``, ``callers``, ``imports``, ``summary``
                or ``source``.
            target: Value forwarded to the selected graph lookup.

        Returns:
            The lookup result as text, or a message when the graph has not been
            configured, nothing was found, or ``query_type`` is not recognised.
        """
        graph = _code_graph
        if graph is None:
            return "Error: Code graph not built yet."

        if query_type == "trace":
            return graph.trace_calls(target, depth=3)

        if query_type == "callers":
            found = graph.get_callers(target)
            if found:
                listing = "\n".join(f" - {c}" for c in found)
                return f"Callers of '{target}':\n" + listing
            return f"No callers found for '{target}'."

        if query_type == "imports":
            found = graph.get_imports(target)
            if found:
                listing = "\n".join(f" - {i}" for i in found)
                return f"Imports in '{target}':\n" + listing
            return f"No imports found in '{target}'."

        if query_type == "summary":
            return graph.get_file_summary(target)

        if query_type == "source":
            # First function node whose name matches decides which file to read from.
            hit = next(
                (
                    candidate
                    for candidate in graph.nodes.values()
                    if candidate.kind == "function" and candidate.name == target
                ),
                None,
            )
            if hit is None:
                return f"Function '{target}' not found in code graph."
            return graph.get_function_source(hit.file, target)

        return f"Unknown query type: '{query_type}'. Use: trace, callers, imports, summary, source"
class FindingContextInput(BaseModel):
    # Argument schema for the ``finding_context`` tool.

    # File containing the finding.
    filepath: str = Field(
        description="File path of the finding",
    )
    # Flagged line, counted from one.
    line: int = Field(
        description="Line number of the finding (1-indexed)",
    )
    # How many lines to show on each side of the flagged line.
    context_lines: int = Field(
        default=10,
        description="Number of context lines before and after",
    )
class FindingContextTool(BaseTool):
    """Get surrounding code context for a specific finding."""

    name: str = "finding_context"
    description: str = (
        "Get surrounding code context for a specific finding. Shows code before and after "
        "the flagged line to help understand the full context of a vulnerability or issue. "
        "Use this to assess the real severity and impact of each finding."
    )
    args_schema: Type[BaseModel] = FindingContextInput

    def _run(self, filepath: str, line: int, context_lines: int = 10) -> str:
        """Return the lines around ``line`` in ``filepath`` with the flagged line marked.

        Args:
            filepath: Path joined onto the configured target directory. When
                that path does not exist, the target directory is searched
                recursively for a file with the same base name.
            line: 1-based number of the flagged line.
            context_lines: Lines to include before and after the flagged line;
                negative values are treated as 0.

        Returns:
            A numbered excerpt in which the flagged line carries a ``>>>``
            marker, or an ``Error: ...`` message when the file cannot be found
            or read, or ``line`` is outside the file.
        """
        root = Path(_target_dir)
        full_path = root / filepath
        if not full_path.exists():
            # Fall back to a base-name search. Only regular files qualify, and
            # sorting makes the choice deterministic when several files match.
            base_name = Path(filepath).name
            candidates = (
                sorted(p for p in root.rglob(base_name) if p.is_file())
                if base_name  # rglob rejects an empty pattern
                else []
            )
            if not candidates:
                return f"Error: File not found: {filepath}"
            full_path = candidates[0]

        try:
            # Explicit UTF-8 keeps decoding independent of the host locale.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        if not 1 <= line <= total:
            # Previously this produced an empty or unmarked excerpt with no
            # indication that the requested line does not exist.
            return f"Error: line {line} is out of range (file has {total} lines)"

        span = max(0, context_lines)
        start = max(0, line - 1 - span)
        end = min(total, line + span)

        flagged = " >>>"
        plain = " " * len(flagged)  # same width so the "|" column stays aligned
        return "\n".join(
            f"{number:4d}{flagged if number == line else plain} | {text}"
            for number, text in enumerate(lines[start:end], start=start + 1)
        )
|