"""Custom CrewAI tools for CodeTribunal agents."""

import glob
import subprocess
from pathlib import Path
from typing import Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field

# Shared runtime context, populated by configure_tools() before the tools run.
_target_dir: str = ""
_code_graph = None

# Upper bound on a single grit invocation, in seconds.
_GRIT_TIMEOUT_SECONDS = 30


def configure_tools(target_dir: str, code_graph) -> None:
    """Set the shared runtime context for all tools.

    Args:
        target_dir: Root directory of the project under analysis. File paths
            given to the tools are resolved relative to it, and file access is
            confined to it.
        code_graph: Code dependency graph object queried by CodeGraphQueryTool.
    """
    global _target_dir, _code_graph
    _target_dir = target_dir
    _code_graph = code_graph


class _PathResolutionError(Exception):
    """Raised when an agent-supplied path cannot be mapped to a project file.

    The exception message is the user-facing error string the tool returns.
    """


def _project_root() -> Path:
    """Return the resolved project root (the current directory when unset)."""
    return Path(_target_dir).resolve()


def _shared_suffix_len(candidate: Path, wanted: tuple[str, ...]) -> int:
    """Count how many trailing path components of *candidate* equal *wanted*'s."""
    count = 0
    for got, want in zip(reversed(candidate.parts), reversed(wanted)):
        if got != want:
            break
        count += 1
    return count


def _resolve_file(filepath: str, *, search_by_name: bool = False) -> Path:
    """Map an agent-supplied path to a resolved path inside the project root.

    Paths that escape the root (``..`` components, absolute paths elsewhere,
    symlinks pointing outside) are rejected. The analyzed code is untrusted and
    could steer an agent towards reading arbitrary host files.

    Args:
        filepath: Path as given by the agent, normally relative to the root.
        search_by_name: When the path does not exist, search the project for
            files with the same base name. The candidate sharing the most
            trailing path components with *filepath* wins; ties are broken by
            sorted path order so the choice is deterministic.

    Returns:
        The resolved path. It exists, but may be a directory when found directly.

    Raises:
        _PathResolutionError: If the path is outside the root or not found.
    """
    root = _project_root()
    candidate = (root / filepath).resolve()
    if not candidate.is_relative_to(root):
        raise _PathResolutionError(f"Error: Path is outside the project root: {filepath}")
    if candidate.exists():
        return candidate

    name = Path(filepath).name
    if search_by_name and name:
        # glob.escape: file names may legitimately contain *, ? or [.
        matches = sorted(
            path.resolve()
            for path in root.rglob(glob.escape(name))
            if path.is_file() and path.resolve().is_relative_to(root)
        )
        if matches:
            wanted = Path(filepath).parts
            # max() returns the first maximal element, i.e. the sorted-first tie.
            return max(matches, key=lambda p: _shared_suffix_len(p.relative_to(root), wanted))

    raise _PathResolutionError(f"Error: File not found: {filepath}")


def _read_lines(path: Path) -> list[str]:
    """Read *path* as UTF-8 (undecodable bytes replaced) and split it into lines."""
    return path.read_text(encoding="utf-8", errors="replace").splitlines()


class FileReaderInput(BaseModel):
    filepath: str = Field(description="Path to the source file to read (relative to project root)")
    start_line: int = Field(default=0, description="Start line (0-indexed). Default: 0 (beginning)")
    end_line: int = Field(default=-1, description="End line (-1 = entire file)")


class FileReaderTool(BaseTool):
    """Read source code from a specific file with line numbers."""

    name: str = "file_reader"
    description: str = (
        "Read source code from a specific file. Returns the file content with line numbers. "
        "Use this to examine specific files, functions, or code sections that you need to analyze. "
        "You can specify a line range to focus on specific sections."
    )
    args_schema: Type[BaseModel] = FileReaderInput

    def _run(self, filepath: str, start_line: int = 0, end_line: int = -1) -> str:
        """Return lines ``[start_line, end_line)`` of *filepath* with 1-indexed numbers.

        Args:
            filepath: File path relative to the project root.
            start_line: 0-indexed first line, inclusive. Negative values are
                clamped to 0.
            end_line: 0-indexed end line, exclusive, clamped to the file length.
                ``-1`` means read to the end of the file.

        Returns:
            Numbered lines formatted as ``"NNNN | text"``, or an ``Error: ...``
            or empty-file message.
        """
        try:
            full_path = _resolve_file(filepath)
        except _PathResolutionError as err:
            return str(err)
        try:
            lines = _read_lines(full_path)
        except OSError as e:  # includes IsADirectoryError / PermissionError
            return f"Error reading file: {e}"

        total = len(lines)
        if total == 0:
            return f"File is empty: {filepath}"
        start = max(0, start_line)
        end = total if end_line == -1 else min(end_line, total)
        if start >= total:
            return f"Error: start_line {start_line} exceeds file length ({total} lines)"
        if end <= start:
            # Previously returned "" silently; tell the agent what went wrong.
            return (
                f"Error: empty line range (start_line={start_line}, end_line={end_line}); "
                "end_line must be greater than start_line, or -1 for the whole file"
            )
        return "\n".join(f"{i + 1:4d} | {lines[i]}" for i in range(start, end))


class PatternSearchInput(BaseModel):
    pattern: str = Field(description="GritQL pattern to search for (e.g., 'eval($_)' or 'TODO: $_')")
    language: str | None = Field(default=None, description="Language filter: python, javascript, etc.")


class PatternSearchTool(BaseTool):
    """Search for code patterns using GritQL syntax."""

    name: str = "pattern_search"
    description: str = (
        "Search for code patterns using GritQL syntax. Use this to find specific code constructs "
        "like function calls, variable assignments, security patterns, or code smells that you "
        "want to investigate further. Examples: 'eval($_)' to find eval usage, "
        "'$PASS = $_' to find password assignments."
    )
    args_schema: Type[BaseModel] = PatternSearchInput

    def _run(self, pattern: str, language: str | None = None) -> str:
        """Run ``grit apply --dry-run`` over the project root and return its output.

        Args:
            pattern: GritQL pattern. It is passed as a single argv element (no
                shell), so it cannot inject shell syntax.
            language: Optional value for grit's ``--language`` flag.

        Returns:
            grit's stdout, a "No matches" message, or an ``Error: ...`` string.
        """
        cmd = ["grit", "apply", "--dry-run", pattern, _target_dir]
        if language:
            cmd += ["--language", language]
        try:
            # Explicit UTF-8 + replace: the locale default can mis-decode grit's
            # output or raise UnicodeDecodeError out of run().
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=_GRIT_TIMEOUT_SECONDS,
            )
        except FileNotFoundError:
            return "Error: grit CLI not found."
        except subprocess.TimeoutExpired:
            return f"Error: Pattern search timed out after {_GRIT_TIMEOUT_SECONDS} seconds."
        except OSError as e:  # e.g. PermissionError launching the binary
            return f"Error: could not run grit: {e}"

        output = result.stdout.strip()
        if not output and result.returncode != 0:
            # A failed run (e.g. invalid pattern) must not masquerade as "no matches".
            detail = result.stderr.strip() or f"exit code {result.returncode}"
            return f"Error: grit failed: {detail}"
        if not output or "found 0 matches" in output:
            return "No matches found for this pattern."
        return output


class CodeGraphQueryInput(BaseModel):
    query_type: str = Field(
        description=(
            "Type of query to run. Options: "
            "'trace' — trace function call chain, "
            "'callers' — find who calls a function, "
            "'imports' — list imports for a file, "
            "'summary' — get file summary, "
            "'source' — get function source code"
        )
    )
    target: str = Field(description="Function name, file path, or node ID to query")


class CodeGraphQueryTool(BaseTool):
    """Query the code dependency graph to understand code structure."""

    name: str = "code_graph_query"
    description: str = (
        "Query the code dependency graph to understand code structure and relationships. "
        "Can trace function call chains, find callers, list imports, get file summaries, "
        "or retrieve function source code. Use this to understand how vulnerable code connects "
        "to the rest of the application."
    )
    args_schema: Type[BaseModel] = CodeGraphQueryInput

    def _run(self, query_type: str, target: str) -> str:
        """Dispatch *query_type* to the configured code graph.

        Args:
            query_type: One of ``trace``, ``callers``, ``imports``, ``summary`` or
                ``source``. Surrounding whitespace and letter case are ignored.
            target: Function name, file path or node ID, forwarded to the graph.

        Returns:
            A human-readable result string, or an ``Error``/``Unknown`` message.
        """
        if _code_graph is None:
            return "Error: Code graph not built yet."

        # LLM callers often send "Trace" or " callers"; normalize before matching.
        kind = query_type.strip().lower()

        if kind == "trace":
            return _code_graph.trace_calls(target, depth=3)
        if kind == "callers":
            callers = _code_graph.get_callers(target)
            if not callers:
                return f"No callers found for '{target}'."
            return f"Callers of '{target}':\n" + "\n".join(f" - {c}" for c in callers)
        if kind == "imports":
            imports = _code_graph.get_imports(target)
            if not imports:
                return f"No imports found in '{target}'."
            return f"Imports in '{target}':\n" + "\n".join(f" - {i}" for i in imports)
        if kind == "summary":
            return _code_graph.get_file_summary(target)
        if kind == "source":
            # Returns the first function node with this name; duplicate names
            # in different files are not disambiguated.
            for node in _code_graph.nodes.values():
                if node.kind == "function" and node.name == target:
                    return _code_graph.get_function_source(node.file, target)
            return f"Function '{target}' not found in code graph."
        return f"Unknown query type: '{query_type}'. Use: trace, callers, imports, summary, source"


class FindingContextInput(BaseModel):
    filepath: str = Field(description="File path of the finding")
    line: int = Field(description="Line number of the finding (1-indexed)")
    context_lines: int = Field(default=10, description="Number of context lines before and after")


class FindingContextTool(BaseTool):
    """Get surrounding code context for a specific finding."""

    name: str = "finding_context"
    description: str = (
        "Get surrounding code context for a specific finding. Shows code before and after "
        "the flagged line to help understand the full context of a vulnerability or issue. "
        "Use this to assess the real severity and impact of each finding."
    )
    args_schema: Type[BaseModel] = FindingContextInput

    def _run(self, filepath: str, line: int, context_lines: int = 10) -> str:
        """Return up to *context_lines* lines on each side of *line*, marking it.

        If *filepath* does not exist, files with the same base name are searched
        for under the project root.

        Args:
            filepath: File path of the finding, ideally relative to the root.
            line: 1-indexed line number of the finding.
            context_lines: Lines to show before and after. Negative values are
                treated as 0.

        Returns:
            Numbered lines with the flagged line marked ``>>>``, or an
            ``Error: ...`` string.
        """
        try:
            full_path = _resolve_file(filepath, search_by_name=True)
        except _PathResolutionError as err:
            return str(err)
        try:
            lines = _read_lines(full_path)
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        if not 1 <= line <= total:
            return f"Error: line {line} is out of range for {filepath} ({total} lines)"

        context = max(0, context_lines)
        start = max(0, line - 1 - context)
        end = min(total, line + context)
        result_lines = []
        for i in range(start, end):
            # Blank gutter has the same width as " >>>" so the "|" columns align.
            marker = " >>>" if i == line - 1 else "    "
            result_lines.append(f"{i + 1:4d}{marker} | {lines[i]}")
        return "\n".join(result_lines)