File size: 7,457 Bytes
1cdb3e3
 
 
 
 
 
 
 
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
 
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
 
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
 
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
 
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Custom CrewAI tools for CodeTribunal agents."""

import subprocess
from pathlib import Path
from typing import Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


# Shared runtime context for every tool in this module; both values are
# assigned by configure_tools() and read by the tools' _run() methods.
# Directory that tool file paths are joined onto (empty string until configured).
_target_dir: str = ""
# Code graph object used by CodeGraphQueryTool (None until configured).
_code_graph = None


def configure_tools(target_dir: str, code_graph) -> None:
    """Install the runtime context shared by all tools in this module.

    Args:
        target_dir: Directory that tool file paths are joined onto.
        code_graph: Graph object consulted by the code-graph query tool.
    """
    global _target_dir, _code_graph
    _target_dir, _code_graph = target_dir, code_graph


class FileReaderInput(BaseModel):
    # Argument schema for FileReaderTool (referenced by its args_schema attribute).
    # The tool joins this path onto the module-level _target_dir.
    filepath: str = Field(description="Path to the source file to read (relative to project root)")
    # Index into the list of file lines; the tool clamps negative values to 0.
    # Note that the line numbers printed in the tool's output are index + 1.
    start_line: int = Field(default=0, description="Start line (0-indexed). Default: 0 (beginning)")
    # -1 selects through the last line; a non-negative value is capped at the
    # file length and acts as the exclusive end of the index range.
    end_line: int = Field(default=-1, description="End line (-1 = entire file)")


class FileReaderTool(BaseTool):
    """Read source code from a specific file with line numbers."""

    name: str = "file_reader"
    description: str = (
        "Read source code from a specific file. Returns the file content with line numbers. "
        "Use this to examine specific files, functions, or code sections that you need to analyze. "
        "You can specify a line range to focus on specific sections."
    )
    args_schema: Type[BaseModel] = FileReaderInput

    def _run(self, filepath: str, start_line: int = 0, end_line: int = -1) -> str:
        """Return a slice of ``filepath`` with each line prefixed by its 1-based number.

        Args:
            filepath: File to read, interpreted relative to the configured
                target directory. Paths that resolve outside that directory
                are rejected.
            start_line: 0-indexed first line to include; negative values are
                clamped to 0.
            end_line: 0-indexed exclusive end of the range. Any negative value
                means "through the end of the file"; larger values are capped
                at the file length.

        Returns:
            The selected lines formatted as ``"<number> | <text>"`` joined by
            newlines, or a human-readable message (prefixed with ``"Error:"``
            for failures) explaining why no lines were returned.
        """
        # Resolve both sides so ".." segments, absolute paths and symlinks
        # cannot be used to read files outside the project root.
        try:
            root = Path(_target_dir).resolve()
            full_path = (root / filepath).resolve()
        except (OSError, RuntimeError, ValueError) as e:
            return f"Error: Invalid path {filepath!r}: {e}"
        if not full_path.is_relative_to(root):
            return f"Error: Path is outside the project root: {filepath}"
        if not full_path.exists():
            return f"Error: File not found: {filepath}"

        try:
            # Source files are decoded as UTF-8 regardless of the platform
            # locale; undecodable bytes are replaced instead of raising.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        if total == 0:
            # An empty file is not an out-of-range request; say so explicitly.
            return f"File is empty (0 lines): {filepath}"

        start = max(0, start_line)
        end = total if end_line < 0 else min(end_line, total)

        if start >= total:
            return f"Error: start_line {start_line} exceeds file length ({total} lines)"
        if end <= start:
            # Previously this produced an empty string with no explanation.
            return f"Error: end_line {end_line} must be greater than start_line {start}"

        return "\n".join(
            f"{number:4d} | {text}"
            for number, text in enumerate(lines[start:end], start=start + 1)
        )


class PatternSearchInput(BaseModel):
    # Argument schema for PatternSearchTool (referenced by its args_schema attribute).
    # Passed verbatim as the pattern argument of `grit apply --dry-run`.
    pattern: str = Field(description="GritQL pattern to search for (e.g., 'eval($_)' or 'TODO: $_')")
    # When truthy, the tool appends `--language <value>` to the grit command line.
    language: str | None = Field(default=None, description="Language filter: python, javascript, etc.")


class PatternSearchTool(BaseTool):
    """Search for code patterns using GritQL syntax."""

    name: str = "pattern_search"
    description: str = (
        "Search for code patterns using GritQL syntax. Use this to find specific code constructs "
        "like function calls, variable assignments, security patterns, or code smells that you "
        "want to investigate further. Examples: 'eval($_)' to find eval usage, "
        "'$PASS = $_' to find password assignments."
    )
    args_schema: Type[BaseModel] = PatternSearchInput

    def _run(self, pattern: str, language: str | None = None) -> str:
        """Run ``grit apply --dry-run`` for ``pattern`` over the target directory.

        Args:
            pattern: GritQL pattern, passed to grit as a single argument.
            language: Optional language name forwarded via ``--language``.

        Returns:
            grit's standard output with surrounding whitespace stripped, a
            fixed "no matches" message when there is nothing to report, or a
            message starting with ``"Error:"`` when grit could not be run or
            failed without producing output.
        """
        # Argument list (no shell), so the pattern is never shell-interpreted.
        cmd = ["grit", "apply", "--dry-run", pattern, _target_dir]
        if language:
            cmd += ["--language", language]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                errors="replace",  # matched code may contain undecodable bytes
                timeout=30,
            )
        except FileNotFoundError:
            return "Error: grit CLI not found."
        except subprocess.TimeoutExpired:
            return "Error: Pattern search timed out after 30 seconds."
        except OSError as e:
            # e.g. the grit binary exists but is not executable
            return f"Error: Could not run grit: {e}"

        output = result.stdout.strip()

        # A failed run (such as an invalid pattern) must not be reported as
        # "no matches": surface grit's own diagnostic instead. Only done when
        # stdout is empty so successful output is never masked.
        if result.returncode != 0 and not output:
            detail = (result.stderr or "").strip()
            if detail:
                return f"Error: grit exited with status {result.returncode}: {detail}"

        if not output or "found 0 matches" in output:
            return "No matches found for this pattern."

        return output


class CodeGraphQueryInput(BaseModel):
    # Argument schema for CodeGraphQueryTool (referenced by its args_schema attribute).
    # The tool compares this value by exact string equality against the five
    # option names listed below; anything else yields an "Unknown query type" reply.
    query_type: str = Field(
        description=(
            "Type of query to run. Options: "
            "'trace' — trace function call chain, "
            "'callers' — find who calls a function, "
            "'imports' — list imports for a file, "
            "'summary' — get file summary, "
            "'source' — get function source code"
        )
    )
    # Forwarded unchanged to the code graph method selected by query_type.
    target: str = Field(description="Function name, file path, or node ID to query")


class CodeGraphQueryTool(BaseTool):
    """Query the code dependency graph to understand code structure."""

    name: str = "code_graph_query"
    description: str = (
        "Query the code dependency graph to understand code structure and relationships. "
        "Can trace function call chains, find callers, list imports, get file summaries, "
        "or retrieve function source code. Use this to understand how vulnerable code connects "
        "to the rest of the application."
    )
    args_schema: Type[BaseModel] = CodeGraphQueryInput

    def _run(self, query_type: str, target: str) -> str:
        """Answer one query against the configured code graph.

        Args:
            query_type: One of ``trace``, ``callers``, ``imports``,
                ``summary`` or ``source`` (matched exactly).
            target: Name or path handed to the selected graph operation.

        Returns:
            The graph's answer as text, or a message explaining that the
            graph is missing, nothing was found, or the query type is unknown.
        """
        graph = _code_graph
        if graph is None:
            return "Error: Code graph not built yet."

        if query_type == "trace":
            return graph.trace_calls(target, depth=3)

        if query_type == "callers":
            found = graph.get_callers(target)
            if found:
                bullets = "\n".join(f"  - {caller}" for caller in found)
                return f"Callers of '{target}':\n" + bullets
            return f"No callers found for '{target}'."

        if query_type == "imports":
            found = graph.get_imports(target)
            if found:
                bullets = "\n".join(f"  - {imported}" for imported in found)
                return f"Imports in '{target}':\n" + bullets
            return f"No imports found in '{target}'."

        if query_type == "summary":
            return graph.get_file_summary(target)

        if query_type == "source":
            # First function node carrying the requested name, if any.
            hit = next(
                (
                    candidate
                    for candidate in graph.nodes.values()
                    if candidate.kind == "function" and candidate.name == target
                ),
                None,
            )
            if hit is None:
                return f"Function '{target}' not found in code graph."
            return graph.get_function_source(hit.file, target)

        return f"Unknown query type: '{query_type}'. Use: trace, callers, imports, summary, source"


class FindingContextInput(BaseModel):
    # Argument schema for FindingContextTool (referenced by its args_schema attribute).
    # Joined onto the module-level _target_dir; if that path does not exist the
    # tool falls back to searching the target directory for the same file name.
    filepath: str = Field(description="File path of the finding")
    # The tool marks the entry at list index `line - 1` with ">>>".
    line: int = Field(description="Line number of the finding (1-indexed)")
    # Size of the window shown on each side of the flagged line.
    context_lines: int = Field(default=10, description="Number of context lines before and after")


class FindingContextTool(BaseTool):
    """Get surrounding code context for a specific finding."""

    name: str = "finding_context"
    description: str = (
        "Get surrounding code context for a specific finding. Shows code before and after "
        "the flagged line to help understand the full context of a vulnerability or issue. "
        "Use this to assess the real severity and impact of each finding."
    )
    args_schema: Type[BaseModel] = FindingContextInput

    def _run(self, filepath: str, line: int, context_lines: int = 10) -> str:
        """Return the lines around ``line`` in ``filepath`` with the line marked.

        Args:
            filepath: File of the finding, interpreted relative to the
                configured target directory. If it does not name an existing
                path inside that directory, the directory tree is searched for
                a file with the same base name.
            line: 1-indexed line number of the finding.
            context_lines: Number of lines to show before and after ``line``;
                negative values are treated as 0.

        Returns:
            The window of lines formatted as ``"<number><marker> | <text>"``
            (marker is ``" >>>"`` on the flagged line), or a message starting
            with ``"Error:"`` when the file cannot be found or read, or the
            requested line lies outside the file.
        """
        try:
            root = Path(_target_dir).resolve()
            full_path = (root / filepath).resolve()
        except (OSError, RuntimeError, ValueError) as e:
            return f"Error: Invalid path {filepath!r}: {e}"

        # A path that escapes the project root is never read directly; it is
        # handled like a missing path and looked up by file name instead.
        if not (full_path.is_relative_to(root) and full_path.exists()):
            wanted = Path(filepath).name
            # Compare names literally (file names such as "[id].tsx" contain
            # glob metacharacters), skip directories, and sort so the chosen
            # match does not depend on filesystem iteration order.
            matches = sorted(
                candidate
                for candidate in root.rglob("*")
                if candidate.name == wanted and candidate.is_file()
            )
            if not matches:
                return f"Error: File not found: {filepath}"
            full_path = matches[0]

        try:
            # Decode as UTF-8 regardless of the platform locale; undecodable
            # bytes are replaced instead of raising.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        span = max(0, context_lines)
        start = max(0, line - 1 - span)
        end = min(total, line + span)

        if start >= end:
            # Previously this produced an empty string with no explanation.
            return f"Error: line {line} is out of range (file has {total} lines)"

        flagged = line - 1
        return "\n".join(
            f"{index + 1:4d}{' >>>' if index == flagged else '    '} | {lines[index]}"
            for index in range(start, end)
        )