treesitter_mcp / tools /code_indexer.py
alexcpn's picture
Upload 9 files (#1)
f6ee9d2 verified
"""
Author: Alex Punnen
Helper code for a Code Review tool exposed through an MCP server
License: Proprietary
"""
import os, textwrap
from pathlib import Path
from tree_sitter_languages import get_language
from tree_sitter import Parser
import tempfile
from git import Repo
from enum import Enum
import logging as log
import requests
import re
from collections import defaultdict
# Configure the root logger at import time. `force=True` makes basicConfig
# replace any handlers that were already attached to the root logger.
log.basicConfig(
    level=log.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s", #
    # format="[%(levelname)s] %(message)s", # dont need timing
    handlers=[log.StreamHandler()],
    force=True,
)
# Module-wide parser instance; `_run_query` sets its language before each parse.
parser = Parser()
# Index cache: repo URL -> {"classes": [...], "functions": [...]}.
all_refs = {} # store all classes and functions in a dict
# File contents as bytes, keyed by the repo URL concatenated with the
# file's path relative to the checkout root (no separator is inserted).
code_ref ={} # hold the code bytes
# Same keys as `code_ref`; the value is the language name of that file.
code_languages = {} # track language per file for downstream queries
# Directory names that `search_codebase_for_project` skips when they appear
# as a parent directory of an indexed file.
DEFAULT_SEARCH_IGNORES: tuple[str, ...] = (
    ".git",
    ".hg",
    ".svn",
    ".mypy_cache",
    "__pycache__",
    "node_modules",
    "dist",
    "build",
    ".venv",
)
# Language keys accepted by `_run_query` / `find_call_sites`, mapped to the
# grammar names handed to `get_language`.
LANGUAGE_NAME_MAP = {
    "python": "python",
    "go": "go",
    "cpp": "cpp",
}
def _collect_files(root_dir, extensions=None):
all_files = []
for dirpath, _, filenames in os.walk(root_dir):
for fname in filenames:
if extensions is None or any(fname.endswith(ext) for ext in extensions):
all_files.append(os.path.join(dirpath, fname))
return all_files
def _decode_text(content_bytes: bytes):
try:
return content_bytes.decode("utf-8")
except UnicodeDecodeError:
try:
return content_bytes.decode("latin-1")
except UnicodeDecodeError:
return None
def _ensure_repo_indexed(github_repo: str):
    """
    Return `(classes, functions)` for `github_repo`, indexing it on first use.

    A repository that already has an entry in `all_refs` is answered from
    that cache. Otherwise the repository is shallow-cloned (depth 1) into a
    temporary directory, indexed with `index_all_files`, and the result is
    stored in `all_refs` under the repository URL before being returned.

    @param github_repo: URL of the repository to clone and index.
    @return: Tuple of (all_classes, all_functions).
    """
    cached = all_refs.get(github_repo)
    if cached is not None:
        return cached["classes"], cached["functions"]

    # The checkout only needs to live while the files are being indexed.
    with tempfile.TemporaryDirectory() as project_root:
        log.info(f"Cloning repo {github_repo} into {project_root}...")
        Repo.clone_from(github_repo, project_root, depth=1)
        log.info(f"Cloned repo {github_repo} into {project_root}.")
        log.info(f"Indexing files in {project_root}...")
        all_classes, all_functions = index_all_files(project_root, github_repo)
        log.info(f"Indexed {len(all_classes)} classes and {len(all_functions)} functions.")
        all_refs[github_repo] = dict(classes=all_classes, functions=all_functions)
        return all_classes, all_functions
# ---------------------------------------------------------------------------
# Language queries from https://github.com/sankalp1999/code_qa/blob/fe6ce9d852aa1c371c299db22978012df4b354a0/treesitter.py#L16
# ---------------------------------------------------------------------------
class LanguageEnum(Enum):
    """
    Tree-sitter query sources grouped per supported language.

    Every member except UNKNOWN has a dict as its value, mapping a query key
    to tree-sitter query text. The key names are not uniform: the Python
    entry uses "class" / "func" / "doc", while the Go and C++ entries use
    "*_query" keys.
    """
    # Sentinel value for a language that has not been recognised.
    UNKNOWN = "unknown"
    # Python: class names, function names and stand-alone string statements
    # (captured as @docstring).
    PYTHON= {
        "class": """
(class_definition
name: (identifier) @class.name)
""",
        "func": """
(function_definition
name: (identifier) @function.name)
""",
        "doc": """
(expression_statement (string) @docstring)
""",
    }
    GO = {
        # Match struct type declarations:
        # type Foo struct { … }
        'struct_query': r"""
(type_spec
name: (type_identifier) @struct.name
type: (struct_type))
""",
        # Match both top-level functions and methods:
        # func Bar(...) { … }
        # func (r Receiver) Baz(...) { … }
        # Note: the two alternatives use different capture names
        # (@func.name for functions, @method.name for methods).
        'func_query': r"""
[
(function_declaration
name: (identifier) @func.name)
(method_declaration
name: (field_identifier) @method.name)
]
""",
        # Capture all comments (line or block) for docstrings:
        # // comment
        # /* comment */
        'doc_query': r"""
(comment) @comment
"""
    }
    # C++: classes and structs are captured by separate queries; the function
    # query accepts either a plain identifier or a field identifier as the
    # declarator name, both under the same @function.name capture.
    CPP = {
        "class_query": r"""
(class_specifier
name: (type_identifier) @class.name)
""",
        "struct_query": r"""
(struct_specifier
name: (type_identifier) @struct.name)
""",
        "func_query": r"""
(function_definition
declarator: (function_declarator
declarator: [
(identifier) @function.name
(field_identifier) @function.name
]
)
)
""",
        "doc_query": r"""
(comment) @comment
""",
    }
# ---------------------------------------------------------------------------
# run the query and grab the captures
# ---------------------------------------------------------------------------
def _normalize_block(node):
"""
Walk up the tree until we hit a block node representing a function or class.
"""
target_types = {
"function_definition",
"function_declaration",
"method_declaration",
"class_definition",
"class_specifier",
"struct_specifier",
"type_spec",
}
cur = node
while cur and cur.type not in target_types:
cur = cur.parent
return cur or node
def _run_query(code_bytes, q_src, tag, language):
    """
    Run the tree-sitter query `q_src` over `code_bytes` and describe each
    capture whose capture name equals `tag`.

    @param code_bytes: Source code of one file, as bytes.
    @param q_src: Tree-sitter query text.
    @param tag: Capture name to keep (for example "function.name").
    @param language: Key of LANGUAGE_NAME_MAP.
    @return: List of dicts with the keys node, name, start_line, end_line
        (both 1-based) and class (enclosing class name or None).
    @raise ValueError: If `language` is not a key of LANGUAGE_NAME_MAP.
    """
    if language not in LANGUAGE_NAME_MAP:
        raise ValueError(f"Unsupported language: {language}")
    lang = get_language(LANGUAGE_NAME_MAP[language])
    parser.set_language(lang)
    query = lang.query(q_src)
    root = parser.parse(code_bytes).root_node
    # Doc/comment captures describe themselves; every other capture is
    # widened to the function/class block that contains it.
    keeps_own_node = tag in {"docstring", "comment"}

    def _describe(node):
        block = node if keeps_own_node else _normalize_block(node)
        return {
            "node": block,
            "name": code_bytes[node.start_byte: node.end_byte].decode(),
            "start_line": block.start_point[0] + 1,  # 0-based -> 1-based
            "end_line": block.end_point[0] + 1,
            "class": _get_enclosing_class_name(block, code_bytes),  # may be None
        }

    return [
        _describe(node)
        for node, cap_name in query.captures(root)  # [(node, capture_name), ...]
        if cap_name == tag
    ]
# ---------------------------------------------------------------------------
# attach the first doc-string that falls *inside* each block
# ---------------------------------------------------------------------------
def _attach_docstrings(code_bytes,items,docs):
for itm in items:
for d in docs:
if (d["node"].start_byte >= itm["node"].start_byte and
d["node"].end_byte <= itm["node"].end_byte):
itm["doc"] = code_bytes[
d["node"].start_byte : d["node"].end_byte
].decode().strip('"\''" \n")
break
else:
itm["doc"] = None
return items
# ---------------------------------------------------------------------------
# Attach file name to each item
# ---------------------------------------------------------------------------
def _attach_file_name(items, file_path):
for itm in items:
itm["file"] = file_path
return items
# ---------------------------------------------------------------------------
# Build a query that finds all calls to `target_name`
# ---------------------------------------------------------------------------
def _build_call_query(target_name: str, language: str):
if language == "python":
lang = get_language("python")
return lang.query(f"""
(
(call
function: (identifier) @call.name
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
(
(call
function: (attribute
object: (_)
attribute: (identifier) @call.name
)
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
""")
if language == "go":
lang = get_language("go")
return lang.query(f"""
(
(call_expression
function: (identifier) @call.name
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
(
(call_expression
function: (selector_expression
operand: (_)
field: (field_identifier) @call.name
)
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
""")
if language == "cpp":
lang = get_language("cpp")
return lang.query(f"""
(
(call_expression
function: (identifier) @call.name
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
(
(call_expression
function: (field_expression
field: (field_identifier) @call.name
)
arguments: (argument_list)?
) @call.node
(#eq? @call.name "{target_name}")
)
""")
return None
def _extract_identifier(node, code_bytes):
"""
Depth-first search for the first identifier-like child node.
"""
stack = [node]
while stack:
cur = stack.pop()
if cur.type in {"identifier", "field_identifier", "type_identifier"}:
return code_bytes[cur.start_byte:cur.end_byte].decode()
stack.extend(reversed(cur.children or []))
return None
def _get_enclosing_function(node, code_bytes):
"""
Walk up from `node` until we find a function_definition.
Return its name (string) or None if at top-level.
"""
cur = node
while cur:
if cur.type in {"function_definition", "function_declaration", "method_declaration"}:
# child_by_field_name works if the grammar labels the name field
name_node = cur.child_by_field_name("name")
if name_node:
return code_bytes[name_node.start_byte:name_node.end_byte].decode()
declarator = cur.child_by_field_name("declarator")
if declarator:
extracted = _extract_identifier(declarator, code_bytes)
if extracted:
return extracted
cur = cur.parent
return None
def _get_enclosing_class_name(node, code_bytes):
"""
Traverse up the tree to find the enclosing class, if any.
"""
cur = node
while cur:
if cur.type in {"class_definition", "class_specifier", "struct_specifier"}:
name_node = cur.child_by_field_name("name")
if name_node:
return code_bytes[name_node.start_byte:name_node.end_byte].decode()
cur = cur.parent
return None
def find_call_sites(code_bytes: bytes, target_name: str, language: str):
    """
    Locate every call to `target_name` inside one source file.

    @param code_bytes: Source code of the file, as bytes.
    @param target_name: Name of the called function or method.
    @param language: Key of LANGUAGE_NAME_MAP; other values give an empty list.
    @return: List of dicts with the keys caller (enclosing function name or
        "<module>"), start_line, end_line (1-based) and snippet (the call
        text with all whitespace runs collapsed to single spaces).
    """
    if language not in LANGUAGE_NAME_MAP:
        return []
    lang = get_language(LANGUAGE_NAME_MAP[language])
    parser_local = Parser()
    parser_local.set_language(lang)
    query = _build_call_query(target_name, language)
    tree = parser_local.parse(code_bytes)
    if not query:
        return []

    def _site(call_node):
        # Undecodable bytes are dropped so the snippet is always printable.
        call_text = code_bytes[call_node.start_byte:call_node.end_byte].decode(errors="ignore")
        return {
            "caller": _get_enclosing_function(call_node, code_bytes) or "<module>",
            "start_line": call_node.start_point[0] + 1,
            "end_line": call_node.end_point[0] + 1,
            "snippet": " ".join(call_text.split()),
        }

    # The query also captures @call.name; only the whole-call capture is reported.
    return [
        _site(call_node)
        for call_node, capture in query.captures(tree.root_node)
        if capture == "call.node"
    ]
def _read_source_text(path):
    """Return the text of `path`, or None (after logging) when it cannot be read."""
    for encoding in ("utf8", "latin-1"):
        try:
            with open(path, "r", encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            # Not valid UTF-8: retry as Latin-1, which accepts every byte value.
            continue
        except OSError as exc:
            # For example a dangling symlink inside the checkout.
            log.warning(f"Skipping {path}, cannot read file: {exc}")
            return None
    return None


def _language_plan(path):
    """Return the query plan for `path` chosen by file extension, or None if unsupported."""
    if path.endswith(".py"):
        return {
            "language": "python",
            "queries": LanguageEnum.PYTHON.value,
            "class_like": [("class", ("class.name",))],
            "functions": ("func", ("function.name",)),
            "docs": ("doc", "docstring"),
            "label": "classes",
        }
    if path.endswith(".go"):
        return {
            "language": "go",
            "queries": LanguageEnum.GO.value,
            "class_like": [("struct_query", ("struct.name",))],
            # The Go query tags plain functions and methods with different
            # capture names; both must be requested to index methods too.
            "functions": ("func_query", ("func.name", "method.name")),
            "docs": ("doc_query", "comment"),
            "label": "structs",
        }
    if path.endswith((".cpp", ".cc", ".cxx", ".hpp", ".hh", ".h", ".hxx", ".ipp")):
        return {
            "language": "cpp",
            "queries": LanguageEnum.CPP.value,
            "class_like": [
                ("class_query", ("class.name",)),
                ("struct_query", ("struct.name",)),
            ],
            "functions": ("func_query", ("function.name",)),
            "docs": ("doc_query", "comment"),
            "label": "class/struct",
        }
    return None


def _collect_tagged(code_bytes, q_src, tags, language):
    """Run `q_src` once per capture tag and return the merged hits."""
    found = []
    for tag in tags:
        found.extend(_run_query(code_bytes, q_src, tag, language))
    if len(tags) > 1:
        # Restore source order after concatenating several capture lists.
        found.sort(key=lambda itm: itm["node"].start_byte)
    return found


def index_all_files(project_root,git_repo_url):
    """
    Index every supported source file below `project_root`.

    For each Python, Go and C/C++ file the classes/structs and functions are
    extracted with tree-sitter, annotated with a docstring (if one lies
    inside the block) and with the file path relative to `project_root`.
    The file's bytes and language are stored in `code_ref` and
    `code_languages` under the key `git_repo_url + rel_path`.

    Files that are not valid UTF-8 are read as Latin-1 instead of aborting
    the whole run; files that cannot be opened are logged and skipped.

    @param project_root: Root directory of the checked-out project.
    @param git_repo_url: Repository URL, used as the key prefix for the caches.
    @return: Tuple (all_classes, all_functions) of item lists.
    """
    all_classes = []
    all_functions = []
    all_files = _collect_files(project_root, [".py",".go",".cpp",".cc",".cxx",".hpp",".hh",".h",".hxx",".ipp"])
    for path in all_files:
        plan = _language_plan(path)
        if plan is None:
            log.info(f"Skipping {path}, unsupported file type.")
            continue
        code = _read_source_text(path)
        if code is None:
            continue
        code_bytes = code.encode()
        log.info(f"Processing {path}")

        language = plan["language"]
        queries = plan["queries"]
        doc_key, doc_tag = plan["docs"]
        docs = _run_query(code_bytes, queries[doc_key], doc_tag, language)

        class_like = []
        for query_key, tags in plan["class_like"]:
            found = _collect_tagged(code_bytes, queries[query_key], tags, language)
            class_like.extend(_attach_docstrings(code_bytes, found, docs))
        func_key, func_tags = plan["functions"]
        functions = _attach_docstrings(
            code_bytes,
            _collect_tagged(code_bytes, queries[func_key], func_tags, language),
            docs,
        )

        file_name = os.path.basename(path)
        rel_path = os.path.relpath(path, project_root)
        label = plan["label"]
        log.info(f"Processing {file_name} ({len(class_like)} {label}, {len(functions)} functions), {rel_path})")
        class_like = _attach_file_name(class_like, rel_path)
        functions = _attach_file_name(functions, rel_path)
        code_ref[git_repo_url+rel_path] = code_bytes
        code_languages[git_repo_url+rel_path] = language
        all_classes.extend(class_like)
        all_functions.extend(functions)
    return all_classes, all_functions
def get_function_context(target_name,all_functions,github_url):
    """
    Find all functions with the same name as `target_name` and return their
    context (location, qualified name, docstring and source code).

    One section is produced per matching definition; sections are separated
    by a blank line. When nothing matches, a short "not found" message is
    returned instead of failing.

    @param target_name: The name of the function to find.
    @param all_functions: The list of all functions in the project.
    @param github_url: Repository URL used as key prefix in `code_ref`.
    @return: The formatted context as a single string.
    """
    matches = [fn for fn in all_functions if fn["name"] == target_name]
    log.info(f"\n\nFound {len(matches)} matches for '{target_name}':")
    if not matches:
        return f"No definition found for '{target_name}'."

    sections = []
    for fn in matches:
        start, end = fn["node"].start_byte, fn["node"].end_byte
        file_name = fn["file"]
        code_bytes = code_ref[github_url+file_name]
        # Start at the beginning of the definition's line when only
        # indentation precedes it; otherwise the first line has no indent
        # and textwrap.dedent cannot remove the indentation of the rest.
        line_start = code_bytes.rfind(b"\n", 0, start) + 1
        if code_bytes[line_start:start].strip():
            line_start = start
        raw_src = code_bytes[line_start:end].decode()
        src = textwrap.dedent(raw_src).rstrip()

        line_span = f"L{fn['start_line']}–L{fn['end_line']}"
        qualified = f"{fn['class']}.{fn['name']}" if fn.get("class") else fn["name"]
        if fn.get("doc"):
            doc_line = f"\n docstring: {fn['doc']}"
        else:
            doc_line = "\nNo docstring found"
        sections.append(
            f"Definition in {file_name} ({line_span}):\n"
            f"{qualified} ({line_span}){doc_line}\n{src}"
        )
    return "\n\n".join(sections)
def get_code_bytes(github_repo, file_name, start_bytes, end_bytes):
    """
    Return the bytes `start_bytes:end_bytes` of an indexed file.

    The file is looked up in `code_ref` under `github_repo + file_name`.
    If it is not indexed, a "not found" message string is returned instead
    of bytes.
    """
    try:
        whole_file = code_ref[github_repo+file_name]
    except KeyError:
        return (f"File {file_name} not found in code_ref.")
    return whole_file[start_bytes:end_bytes]
# Find all calls to a specific function in the indexed project
def find_function_calls_within_project(function_name,github_repo):
    """
    Report every call to `function_name` in the indexed files of `github_repo`.

    The repository is indexed first if necessary; an indexing failure is
    returned as an "Error: ..." string. Every cached file whose key starts
    with `github_repo` and has a recorded language is scanned with
    `find_call_sites`.

    @return: One text block per file that contains calls, or a message
        saying that no calls were found.
    """
    try:
        _ensure_repo_indexed(github_repo)
    except Exception as e:
        return f"Error: {e}"

    report = []
    for name, code_bytes in code_ref.items():
        if not name.startswith(github_repo):
            continue
        language = code_languages.get(name)
        if not language:
            continue
        calls = find_call_sites(code_bytes, function_name, language)
        if not calls:
            continue
        rel_path = name
        report.append(f"\nFound {len(calls)} call(s) to `{function_name}` in {rel_path}:")
        report.extend(
            f"\n ─ in `{c['caller']}` (L{c['start_line']}–L{c['end_line']}): {c['snippet']}"
            for c in calls
        )

    if not report:
        return f"\nNo calls to `{function_name}` found in the project."
    # The report keeps the single leading space the accumulator started with.
    return " " + "".join(report)
def search_codebase_for_project(
    term: str,
    github_repo: str,
    file_patterns=None,
    ignore_names=None,
    max_results: int = 200,
) -> str:
    """
    Case-insensitively search the indexed files of `github_repo` for `term`.

    @param term: Text to look for; must not be empty.
    @param github_repo: Repository URL; indexed on demand.
    @param file_patterns: Optional pattern or list of patterns; a file is
        searched only if its path matches one of them (`Path.match`).
    @param ignore_names: Optional directory name or names to skip, in
        addition to DEFAULT_SEARCH_IGNORES.
    @param max_results: The search stops once this many lines were collected.
    @return: Newline-joined "path:line: text" entries, "No matches found.",
        or an "Error: ..." message.
    """
    if not term:
        return "Error: Search term must not be empty."
    try:
        _ensure_repo_indexed(github_repo)
    except Exception as e:
        return f"Error: {e}"

    needle = term.lower()
    ignore_set = set(DEFAULT_SEARCH_IGNORES)
    if ignore_names:
        ignore_set.update([ignore_names] if isinstance(ignore_names, str) else ignore_names)
    if isinstance(file_patterns, str):
        file_patterns = [file_patterns]

    matches = []
    prefix_length = len(github_repo)
    for key, content_bytes in code_ref.items():
        if not key.startswith(github_repo):
            continue
        rel_path = key[prefix_length:].lstrip("/\\")
        display_path = rel_path or key
        path_obj = Path(display_path)
        # Only parent directories are checked against the ignore list.
        if not ignore_set.isdisjoint(path_obj.parts[:-1]):
            continue
        if file_patterns and not any(path_obj.match(pattern) for pattern in file_patterns):
            continue
        text = _decode_text(content_bytes)
        if text is None:
            continue
        for line_number, line in enumerate(text.splitlines(), start=1):
            if needle not in line.lower():
                continue
            matches.append(f"{display_path}:{line_number}: {line}")
            if len(matches) >= max_results:
                return "\n".join(matches)
    return "\n".join(matches) if matches else "No matches found."
def get_function_context_for_project(function_name:str, github_repo:str,)-> str:
    """
    Return the definition context of `function_name` within `github_repo`.

    The repository is indexed on demand and the lookup is delegated to
    `get_function_context`. Any exception raised along the way is turned
    into an "Error: ..." string.

    @param function_name: The name of the function to find.
    @param github_repo: The URL of the GitHub repo.
    """
    try:
        _classes, indexed_functions = _ensure_repo_indexed(github_repo)
        return get_function_context(function_name, indexed_functions, github_repo)
    except Exception as e:
        return f"Error: {e}"
def _split_diff_by_file(diff_text):
    """Split a unified git diff into one chunk per file, keyed by file path."""
    # First alternative: unchanged path (a/x b/x). The back-reference keeps
    # paths that themselves contain " b/" unambiguous. Second alternative:
    # renamed/copied files, where the two paths differ; the new path is used.
    header_pattern = re.compile(
        r'^diff --git a/(?:(?P<same>.*?) b/(?P=same)|.*? b/(?P<renamed>.*))$',
        re.MULTILINE,
    )
    file_diffs = defaultdict(str)
    split_points = list(header_pattern.finditer(diff_text))
    for i, match in enumerate(split_points):
        file_path = match.group("same")
        if file_path is None:
            file_path = match.group("renamed")
        start = match.start()
        end = split_points[i + 1].start() if i + 1 < len(split_points) else len(diff_text)
        file_diffs[file_path] = diff_text[start:end]
    return file_diffs


def get_pr_diff_url(repo_url, pr_number) -> dict[str, str]:
    """
    Fetch per-file diffs for a given repo URL and PR number.

    Downloads the combined diff of all commits of the pull request and
    splits it at every "diff --git" header.

    Args:
        repo_url (str): The URL of the GitHub repository. A trailing slash
            or ".git" suffix is tolerated.
        pr_number (int): The pull request number.
    Returns:
        Mapping of file path to that file's diff text. For renamed files the
        new path is used as key.
    Raises:
        ValueError: If owner and repository cannot be derived from `repo_url`.
        requests.HTTPError: If the diff request does not return HTTP 200.
    """
    parts = repo_url.rstrip("/").split("/")
    if len(parts) < 2 or not parts[-1] or not parts[-2]:
        raise ValueError(f"Cannot derive owner/repository from URL: {repo_url}")
    owner, repo_name = parts[-2], parts[-1]
    if repo_name.endswith(".git"):
        repo_name = repo_name[: -len(".git")]

    pr_diff_url = f"https://patch-diff.githubusercontent.com/raw/{owner}/{repo_name}/pull/{pr_number}.diff"
    # TLS verification stays enabled (the previous verify=False allowed
    # man-in-the-middle tampering of the diff). Environments with a custom
    # CA can point REQUESTS_CA_BUNDLE at it. The timeout prevents a stalled
    # connection from blocking the caller forever.
    response = requests.get(pr_diff_url, timeout=30)
    if response.status_code != 200:
        log.error(f"Failed to fetch diff: {response.status_code}")
        # Raise instead of exit(): this runs inside a long-lived server and
        # must not terminate the whole process.
        raise requests.HTTPError(
            f"Failed to fetch diff from {pr_diff_url}: HTTP {response.status_code}",
            response=response,
        )
    return _split_diff_by_file(response.text)
if __name__ == "__main__":
# ---------------------------------------------------------------------------
# For testing purposes, we can use a local directory or a GitHub repo URL.
# ---------------------------------------------------------------------------
# Test with a GitHub repo URL - Python repo
log.info("-----------------Python Repo---------------------------------------")
repo_url = 'https://github.com/huggingface/accelerate'
# find a specific function
target_name = "get_max_layer_size"
contex =get_function_context_for_project(target_name,repo_url)
log.info(contex)
target_name = "get_max_layer_size"
contex =get_function_context_for_project(target_name,repo_url)
log.info(contex)
log.info("------------------End Test Python Repo--------------------------------------")
log.info("-----------------Go Repo---------------------------------------")
repo_url = 'https://github.com/ngrok/ngrok-operator'
# find a specific function
target_name = "createKubernetesOperator"
contex =get_function_context_for_project(target_name,repo_url)
log.info(contex)
target_name = "createKubernetesOperator"
contex =get_function_context_for_project(target_name,repo_url)
log.info(contex)
# Test with a GitHub repo URL - Python repo