Spaces:
Sleeping
Sleeping
| """ | |
| Author: Alex Punnen | |
| Helper code for a Code Review tool for an MCP server | |
| License: Proprietary | |
| """ | |
| import os, textwrap | |
| from pathlib import Path | |
| from tree_sitter_languages import get_language | |
| from tree_sitter import Parser | |
| import tempfile | |
| from git import Repo | |
| from enum import Enum | |
| import logging as log | |
| import requests | |
| import re | |
| from collections import defaultdict | |
# Configure the root logger once at import time: INFO and above, timestamped,
# through a single StreamHandler. force=True replaces any handlers that were
# installed on the root logger before this module was imported.
log.basicConfig(
    force=True,
    handlers=[log.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=log.INFO,
)
# Shared tree-sitter parser; _run_query re-targets it to the right grammar
# before every parse.
parser = Parser()

# Per-repository index cache, keyed by repo URL:
#   {"classes": [...], "functions": [...]}
all_refs: dict[str, dict] = {}

# Source bytes of every indexed file, keyed by repo URL + relative path.
code_ref: dict[str, bytes] = {}

# Language label ("python", "go", "cpp") of every indexed file, same keys as
# code_ref, so later queries know which grammar to use.
code_languages: dict[str, str] = {}
# Directory names whose contents search_codebase_for_project skips: a file is
# ignored when any of its parent directory names appears here.
DEFAULT_SEARCH_IGNORES: tuple[str, ...] = (
    # version-control metadata
    ".git", ".hg", ".svn",
    # tool caches
    ".mypy_cache", "__pycache__",
    # dependencies, build output and virtual environments
    "node_modules", "dist", "build", ".venv",
)

# Language labels used throughout this module, mapped to the grammar name that
# is handed to get_language().
LANGUAGE_NAME_MAP = {label: label for label in ("python", "go", "cpp")}
def _collect_files(root_dir, extensions=None):
    """Recursively list the files below ``root_dir``.

    @param root_dir: Directory to walk.
    @param extensions: Optional iterable of filename suffixes; when given,
        only files ending in one of them are returned. ``None`` keeps
        every file.
    @return: List of paths, each joined onto the directory it was found in.
    """
    suffixes = None if extensions is None else tuple(extensions)
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if suffixes is None or filename.endswith(suffixes)
    ]
def _decode_text(content_bytes: bytes) -> str:
    """Decode ``content_bytes`` as UTF-8, falling back to Latin-1.

    Latin-1 assigns a character to every possible byte value, so the fallback
    can never raise; this function therefore always returns a string. (The
    previous version carried an unreachable ``return None`` branch behind a
    second ``except UnicodeDecodeError``.)

    @param content_bytes: Raw file content.
    @return: The decoded text.
    """
    try:
        return content_bytes.decode("utf-8")
    except UnicodeDecodeError:
        # Not valid UTF-8: Latin-1 decodes any byte sequence.
        return content_bytes.decode("latin-1")
def _ensure_repo_indexed(github_repo: str):
    """Return ``(classes, functions)`` for ``github_repo``, indexing it on first use.

    A repository already present in ``all_refs`` is served from that cache.
    Otherwise it is shallow-cloned (depth 1) into a temporary directory,
    indexed with ``index_all_files`` and the result is cached under the
    repository URL.

    @param github_repo: URL of the repository to clone; also the cache key.
    @return: Tuple ``(all_classes, all_functions)``.
    """
    try:
        cached = all_refs[github_repo]
    except KeyError:
        pass
    else:
        return cached["classes"], cached["functions"]

    # The checkout only has to live long enough to be indexed.
    with tempfile.TemporaryDirectory() as project_root:
        log.info(f"Cloning repo {github_repo} into {project_root}...")
        Repo.clone_from(github_repo, project_root, depth=1)
        log.info(f"Cloned repo {github_repo} into {project_root}.")

        log.info(f"Indexing files in {project_root}...")
        all_classes, all_functions = index_all_files(project_root, github_repo)
        log.info(f"Indexed {len(all_classes)} classes and {len(all_functions)} functions.")

        all_refs[github_repo] = {
            "classes": all_classes,
            "functions": all_functions,
        }
        return all_classes, all_functions
| # --------------------------------------------------------------------------- | |
| # Language queries from https://github.com/sankalp1999/code_qa/blob/fe6ce9d852aa1c371c299db22978012df4b354a0/treesitter.py#L16 | |
| # --------------------------------------------------------------------------- | |
class LanguageEnum(Enum):
    """Tree-sitter query sources for each supported language.

    Every member except ``UNKNOWN`` has a dict value mapping a query key to
    the query text. ``index_all_files`` selects captures by name, so the
    capture names used here (``@class.name``, ``@function.name``,
    ``@struct.name``, ``@func.name``, ``@docstring``, ``@comment``) must
    match the tags it asks for.
    """

    UNKNOWN = "unknown"

    PYTHON = {
        "class": """
            (class_definition
                name: (identifier) @class.name)
        """,
        "func": """
            (function_definition
                name: (identifier) @function.name)
        """,
        "doc": """
            (expression_statement (string) @docstring)
        """,
    }

    GO = {
        # Match struct type declarations:
        #   type Foo struct { … }
        'struct_query': r"""
            (type_spec
                name: (type_identifier) @struct.name
                type: (struct_type))
        """,
        # Match both top-level functions and methods:
        #   func Bar(...) { … }
        #   func (r Receiver) Baz(...) { … }
        # Both alternatives share the @func.name capture: the indexer only
        # collects captures tagged "func.name", so a separate @method.name
        # capture caused every Go method to be dropped.
        'func_query': r"""
            [
                (function_declaration
                    name: (identifier) @func.name)
                (method_declaration
                    name: (field_identifier) @func.name)
            ]
        """,
        # Capture all comments (line or block) for docstrings:
        #   // comment
        #   /* comment */
        'doc_query': r"""
            (comment) @comment
        """,
    }

    CPP = {
        "class_query": r"""
            (class_specifier
                name: (type_identifier) @class.name)
        """,
        "struct_query": r"""
            (struct_specifier
                name: (type_identifier) @struct.name)
        """,
        # Free functions declare an identifier; in-class member functions
        # declare a field_identifier. Both are captured under one name.
        "func_query": r"""
            (function_definition
                declarator: (function_declarator
                    declarator: [
                        (identifier) @function.name
                        (field_identifier) @function.name
                    ]
                )
            )
        """,
        "doc_query": r"""
            (comment) @comment
        """,
    }
| # --------------------------------------------------------------------------- | |
| # run the query and grab the captures | |
| # --------------------------------------------------------------------------- | |
def _normalize_block(node):
    """Return the nearest function/class-like node at or above ``node``.

    Starting with ``node`` itself, follow ``.parent`` links until a node
    whose type is one of the definition kinds below is reached. If the chain
    runs out without a hit, ``node`` is returned unchanged.
    """
    block_kinds = frozenset((
        "function_definition",
        "function_declaration",
        "method_declaration",
        "class_definition",
        "class_specifier",
        "struct_specifier",
        "type_spec",
    ))
    candidate = node
    while candidate:
        if candidate.type in block_kinds:
            return candidate
        candidate = candidate.parent
    return node
def _run_query(code_bytes, q_src, tag, language):
    """Run the tree-sitter query ``q_src`` over ``code_bytes``.

    Only captures named exactly ``tag`` are kept. For the documentation tags
    ("docstring", "comment") the captured node itself is recorded; for every
    other tag the capture is widened to its enclosing definition via
    ``_normalize_block``.

    @param code_bytes: Source of one file, as bytes.
    @param q_src: Query text.
    @param tag: Capture name to keep.
    @param language: Language label; must be a key of LANGUAGE_NAME_MAP.
    @return: List of dicts with keys ``node``, ``name``, ``start_line``,
        ``end_line`` (both 1-based) and ``class`` (enclosing class name or None).
    @raise ValueError: If ``language`` is not supported.
    """
    if language not in LANGUAGE_NAME_MAP:
        raise ValueError(f"Unsupported language: {language}")

    grammar = get_language(LANGUAGE_NAME_MAP[language])
    # The module-level parser is shared, so point it at this grammar first.
    parser.set_language(grammar)
    compiled = grammar.query(q_src)
    tree_root = parser.parse(code_bytes).root_node

    doc_tags = {"docstring", "comment"}
    results = []
    for captured, label in compiled.captures(tree_root):  # (node, capture_name) pairs
        if label != tag:
            continue
        name = code_bytes[captured.start_byte:captured.end_byte].decode()
        block = captured if tag in doc_tags else _normalize_block(captured)
        results.append({
            "node": block,
            "name": name,
            # tree-sitter rows are 0-based; report 1-based line numbers
            "start_line": block.start_point[0] + 1,
            "end_line": block.end_point[0] + 1,
            "class": _get_enclosing_class_name(block, code_bytes),
        })
    return results
| # --------------------------------------------------------------------------- | |
| # attach the first doc-string that falls *inside* each block | |
| # --------------------------------------------------------------------------- | |
def _attach_docstrings(code_bytes, items, docs):
    """Set ``item["doc"]`` on every item, in place.

    For each item the first entry of ``docs`` whose node lies entirely within
    the item's byte range is used; its text is decoded and stripped of
    surrounding quotes, spaces and newlines. Items with no such entry get
    ``None``.

    @return: The same ``items`` list.
    """
    for entry in items:
        outer = entry["node"]
        enclosed = next(
            (
                doc["node"]
                for doc in docs
                if outer.start_byte <= doc["node"].start_byte
                and doc["node"].end_byte <= outer.end_byte
            ),
            None,
        )
        if enclosed is None:
            entry["doc"] = None
        else:
            raw = code_bytes[enclosed.start_byte:enclosed.end_byte]
            entry["doc"] = raw.decode().strip("\"' \n")
    return items
| # --------------------------------------------------------------------------- | |
| # Attach file name to each item | |
| # --------------------------------------------------------------------------- | |
def _attach_file_name(items, file_path):
    """Record ``file_path`` under the ``"file"`` key of every item, in place.

    @return: The same ``items`` list.
    """
    for entry in items:
        entry.update(file=file_path)
    return items
| # --------------------------------------------------------------------------- | |
| # Build a query that finds all calls to `target_name` | |
| # --------------------------------------------------------------------------- | |
# Per-language query text matching calls to a given name, either as a bare
# identifier (``name(...)``) or as a member access (``obj.name(...)``).
# ``{target}`` is filled in by _build_call_query.
_CALL_QUERY_TEMPLATES = {
    "python": """
        (
          (call
            function: (identifier) @call.name
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
        (
          (call
            function: (attribute
              object: (_)
              attribute: (identifier) @call.name
            )
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
    """,
    "go": """
        (
          (call_expression
            function: (identifier) @call.name
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
        (
          (call_expression
            function: (selector_expression
              operand: (_)
              field: (field_identifier) @call.name
            )
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
    """,
    "cpp": """
        (
          (call_expression
            function: (identifier) @call.name
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
        (
          (call_expression
            function: (field_expression
              field: (field_identifier) @call.name
            )
            arguments: (argument_list)?
          ) @call.node
          (#eq? @call.name "{target}")
        )
    """,
}


def _build_call_query(target_name: str, language: str):
    """Compile a tree-sitter query that finds every call to ``target_name``.

    The query captures the whole call as ``@call.node`` and the called name
    as ``@call.name``.

    @param target_name: Name of the function or method being called.
    @param language: "python", "go" or "cpp".
    @return: The compiled query, or ``None`` for any other language.
    """
    template = _CALL_QUERY_TEMPLATES.get(language)
    if template is None:
        return None
    # target_name ends up inside a double-quoted query string; escape
    # backslashes and quotes so an unusual name cannot break the query syntax.
    escaped = target_name.replace("\\", "\\\\").replace('"', '\\"')
    return get_language(language).query(template.format(target=escaped))
def _extract_identifier(node, code_bytes):
    """Return the text of the first identifier-like node at or below ``node``.

    The subtree is walked depth-first, children left to right, using an
    explicit stack. "Identifier-like" means a node of type ``identifier``,
    ``field_identifier`` or ``type_identifier``.

    @return: The identifier text, or ``None`` when the subtree has none.
    """
    identifier_kinds = ("identifier", "field_identifier", "type_identifier")
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type in identifier_kinds:
            return code_bytes[current.start_byte:current.end_byte].decode()
        # Push right-to-left so the leftmost child is popped first.
        for child in reversed(current.children or []):
            pending.append(child)
    return None
def _get_enclosing_function(node, code_bytes):
    """Name of the innermost function or method containing ``node``.

    Walks from ``node`` up through its parents. At each
    ``function_definition`` / ``function_declaration`` / ``method_declaration``
    the ``name`` field is tried first; when it is absent the name is searched
    for inside the ``declarator`` field. A function node that yields no name
    is skipped and the walk continues upwards.

    @return: The function name, or ``None`` when no named function encloses
        ``node``.
    """
    function_kinds = (
        "function_definition",
        "function_declaration",
        "method_declaration",
    )
    ancestor = node
    while ancestor:
        if ancestor.type in function_kinds:
            direct = ancestor.child_by_field_name("name")
            if direct:
                return code_bytes[direct.start_byte:direct.end_byte].decode()
            # No "name" field on this node: look inside the declarator.
            declarator = ancestor.child_by_field_name("declarator")
            nested = _extract_identifier(declarator, code_bytes) if declarator else None
            if nested:
                return nested
        ancestor = ancestor.parent
    return None
def _get_enclosing_class_name(node, code_bytes):
    """Name of the innermost class or struct at or above ``node``.

    ``node`` itself is examined first, then each parent in turn. Only
    ``class_definition``, ``class_specifier`` and ``struct_specifier`` nodes
    with a ``name`` field count; unnamed ones are skipped.

    @return: The class name, or ``None`` when there is no enclosing class.
    """
    class_kinds = ("class_definition", "class_specifier", "struct_specifier")
    ancestor = node
    while ancestor:
        label = (
            ancestor.child_by_field_name("name")
            if ancestor.type in class_kinds
            else None
        )
        if label:
            return code_bytes[label.start_byte:label.end_byte].decode()
        ancestor = ancestor.parent
    return None
def find_call_sites(code_bytes: bytes, target_name: str, language: str):
    """List every call to ``target_name`` found in ``code_bytes``.

    @param code_bytes: Source of one file, as bytes.
    @param target_name: Name of the called function or method.
    @param language: Language label; unsupported labels yield ``[]``.
    @return: List of dicts with keys ``caller`` (enclosing function name or
        "<module>"), ``start_line``, ``end_line`` (1-based) and ``snippet``
        (the call text with all whitespace runs collapsed to single spaces).
    """
    if language not in LANGUAGE_NAME_MAP:
        return []

    grammar = get_language(LANGUAGE_NAME_MAP[language])
    # A private parser keeps this function independent of the shared one.
    local_parser = Parser()
    local_parser.set_language(grammar)

    call_query = _build_call_query(target_name, language)
    syntax_tree = local_parser.parse(code_bytes)
    if not call_query:
        return []

    found = []
    for call_node, label in call_query.captures(syntax_tree.root_node):
        # The query also captures @call.name; only whole calls are reported.
        if label == "call.node":
            text = code_bytes[call_node.start_byte:call_node.end_byte].decode(errors="ignore")
            found.append({
                "caller": _get_enclosing_function(call_node, code_bytes) or "<module>",
                "start_line": call_node.start_point[0] + 1,
                "end_line": call_node.end_point[0] + 1,
                "snippet": " ".join(text.split()),
            })
    return found
def index_all_files(project_root, git_repo_url):
    """Index classes/structs and functions of every supported file under ``project_root``.

    Supported files are Python (.py), Go (.go) and C++ sources/headers. For
    each file the source bytes are stored in ``code_ref`` and the language
    label in ``code_languages``, both keyed by ``git_repo_url + relative path``.

    @param project_root: Root directory of the checked-out project.
    @param git_repo_url: Repository URL, used as the key prefix.
    @return: Tuple ``(all_classes, all_functions)`` of item dicts as produced
        by ``_run_query``, each extended with ``doc`` and ``file``.
    """
    # One entry per language. "types" and "func"/"doc" hold
    # (query key in the LanguageEnum value, capture tag) pairs; "label" is
    # only used in the log line.
    language_specs = (
        {
            "suffixes": (".py",),
            "language": "python",
            "queries": LanguageEnum.PYTHON.value,
            "types": (("class", "class.name"),),
            "func": ("func", "function.name"),
            "doc": ("doc", "docstring"),
            "label": "classes",
        },
        {
            "suffixes": (".go",),
            "language": "go",
            "queries": LanguageEnum.GO.value,
            "types": (("struct_query", "struct.name"),),
            "func": ("func_query", "func.name"),
            "doc": ("doc_query", "comment"),
            "label": "structs",
        },
        {
            "suffixes": (".cpp", ".cc", ".cxx", ".hpp", ".hh", ".h", ".hxx", ".ipp"),
            "language": "cpp",
            "queries": LanguageEnum.CPP.value,
            "types": (("class_query", "class.name"), ("struct_query", "struct.name")),
            "func": ("func_query", "function.name"),
            "doc": ("doc_query", "comment"),
            "label": "class/struct",
        },
    )
    wanted_suffixes = [s for spec in language_specs for s in spec["suffixes"]]

    all_classes = []
    all_functions = []
    for path in _collect_files(project_root, wanted_suffixes):
        spec = next((s for s in language_specs if path.endswith(s["suffixes"])), None)
        if spec is None:
            log.info(f"Skipping {path}, unsupported file type.")
            continue

        # errors="replace": a single file that is not valid UTF-8 must not
        # abort indexing of the whole repository.
        with open(path, "r", encoding="utf8", errors="replace") as f:
            code_bytes = f.read().encode()
        log.info(f"Processing {path}")

        language = spec["language"]
        queries = spec["queries"]
        type_groups = [
            _run_query(code_bytes, queries[key], tag, language)
            for key, tag in spec["types"]
        ]
        func_key, func_tag = spec["func"]
        functions = _run_query(code_bytes, queries[func_key], func_tag, language)
        doc_key, doc_tag = spec["doc"]
        docs = _run_query(code_bytes, queries[doc_key], doc_tag, language)

        file_name = os.path.basename(path)
        rel_path = os.path.relpath(path, project_root)
        type_count = sum(len(group) for group in type_groups)
        log.info(f"Processing {file_name} ({type_count} {spec['label']}, {len(functions)} functions), {rel_path})")

        for group in type_groups:
            _attach_docstrings(code_bytes, group, docs)
            _attach_file_name(group, rel_path)
            all_classes.extend(group)
        _attach_docstrings(code_bytes, functions, docs)
        _attach_file_name(functions, rel_path)
        all_functions.extend(functions)

        code_ref[git_repo_url + rel_path] = code_bytes
        code_languages[git_repo_url + rel_path] = language
    return all_classes, all_functions
def get_function_context(target_name, all_functions, github_url):
    """
    Find all functions with the same name as `target_name`.
    Return their context (location, docstring, source code).

    Every match contributes one section; sections are separated by a blank
    line. When nothing matches, a short "not found" message is returned
    instead of failing on an unbound variable.

    @param target_name: The name of the function to find.
    @param all_functions: The list of all functions in the project.
    @param github_url: Repository URL; with the item's file path it forms the
        key into ``code_ref``.
    @return: The formatted context text.
    """
    matches = [fn for fn in all_functions if fn["name"] == target_name]
    log.info(f"\n\nFound {len(matches)} matches for '{target_name}':")
    if not matches:
        return f"No function named '{target_name}' found in the indexed project."

    sections = []
    for fn in matches:
        start, end = fn["node"].start_byte, fn["node"].end_byte
        rel_path = fn["file"]
        code_bytes = code_ref[github_url + rel_path]
        # Methods are indented in their file; dedent so the snippet starts at column 0.
        src = textwrap.dedent(code_bytes[start:end].decode()).rstrip()

        span = f"(L{fn['start_line']}–{fn['end_line']})"
        qualified = f"{fn['class']}.{fn['name']}" if fn.get("class") else fn["name"]
        contex = f"Definition in {rel_path} {span}:\n{qualified} {span}"
        if fn.get("doc"):
            contex += f"\n docstring: {fn['doc']}"
        else:
            contex += "\nNo docstring found"
        contex += "\n" + src
        sections.append(contex)
    return "\n\n".join(sections)
def get_code_bytes(github_repo, file_name, start_bytes, end_bytes):
    """Return the slice ``[start_bytes:end_bytes]`` of an indexed file.

    The file is looked up in ``code_ref`` under ``github_repo + file_name``.

    @return: The requested bytes, or a "not found" message string when the
        file is not in ``code_ref``.
    """
    try:
        source = code_ref[github_repo + file_name]
    except KeyError:
        return (f"File {file_name} not found in code_ref.")
    return source[start_bytes:end_bytes]
# Find all calls to a specific function across the indexed project.
def find_function_calls_within_project(function_name, github_repo):
    """
    Report every call to `function_name` in the files indexed for `github_repo`.

    The repository is indexed first if necessary. Each file with at least one
    call contributes a header line followed by one line per call site.

    @param function_name: Name of the called function.
    @param github_repo: Repository URL.
    @return: The report text, a "No calls ..." message when nothing was found,
        or "Error: ..." when indexing failed.
    """
    try:
        _ensure_repo_indexed(github_repo)
    except Exception as e:
        return f"Error: {e}"

    reports = []
    for name in code_ref.keys():
        # code_ref holds every indexed repository; keep this one's files only.
        if not name.startswith(github_repo):
            continue
        language = code_languages.get(name)
        calls = find_call_sites(code_ref[name], function_name, language) if language else []
        if not calls:
            continue
        rel_path = name
        pieces = [f"\nFound {len(calls)} call(s) to `{function_name}` in {rel_path}:"]
        pieces.extend(
            f"\n ─ in `{c['caller']}` (L{c['start_line']}–L{c['end_line']}): {c['snippet']}"
            for c in calls
        )
        reports.append("".join(pieces))

    if not reports:
        return f"\nNo calls to `{function_name}` found in the project."
    return " " + "".join(reports)
def search_codebase_for_project(
    term: str,
    github_repo: str,
    file_patterns=None,
    ignore_names=None,
    max_results: int = 200,
) -> str:
    """
    Case-insensitively search the indexed files of ``github_repo`` for ``term``.

    @param term: Text to look for; must not be empty.
    @param github_repo: Repository URL (indexed on first use).
    @param file_patterns: Optional pattern or list of patterns; a file is
        searched only if its path matches at least one (``Path.match``).
    @param ignore_names: Optional directory name or names to skip, in
        addition to DEFAULT_SEARCH_IGNORES.
    @param max_results: The search stops once this many lines were collected.
    @return: Matching lines as ``path:line_number: line`` joined by newlines,
        "No matches found." or an "Error: ..." message.
    """
    if not term:
        return "Error: Search term must not be empty."
    try:
        _ensure_repo_indexed(github_repo)
    except Exception as e:
        return f"Error: {e}"

    needle = term.lower()

    skipped_dirs = set(DEFAULT_SEARCH_IGNORES)
    if ignore_names:
        skipped_dirs.update([ignore_names] if isinstance(ignore_names, str) else ignore_names)

    patterns = [file_patterns] if isinstance(file_patterns, str) else file_patterns

    hits = []
    prefix_length = len(github_repo)
    for key, content_bytes in code_ref.items():
        if not key.startswith(github_repo):
            continue

        # Keys are repo URL + relative path; drop the URL and any separator.
        rel_path = key[prefix_length:].lstrip("/\\")
        display_path = rel_path or key
        path_obj = Path(display_path)

        # Only parent directories are checked, never the file name itself.
        if not skipped_dirs.isdisjoint(path_obj.parts[:-1]):
            continue
        if patterns and not any(path_obj.match(pattern) for pattern in patterns):
            continue

        text = _decode_text(content_bytes)
        if text is None:
            continue

        for line_number, line in enumerate(text.splitlines(), start=1):
            if needle not in line.lower():
                continue
            hits.append(f"{display_path}:{line_number}: {line}")
            if len(hits) >= max_results:
                return "\n".join(hits)

    return "\n".join(hits) if hits else "No matches found."
def get_function_context_for_project(function_name:str, github_repo:str,)-> str:
    """
    Look up a function by name in a GitHub repo and return its context.

    The repository is indexed on first use; the text itself is produced by
    ``get_function_context``.

    @param function_name: The name of the function to find.
    @param github_repo: The URL of the GitHub repo.
    @return: The context text, or "Error: ..." if indexing or lookup raised.
    """
    try:
        _classes, functions = _ensure_repo_indexed(github_repo)
        return get_function_context(function_name, functions, github_repo)
    except Exception as e:
        return f"Error: {e}"
def get_pr_diff_url(repo_url, pr_number) -> dict[str, str]:
    """
    Fetch per-file diffs for a given repo URL and PR number.

    Downloads the combined diff of the pull request and splits it at every
    ``diff --git`` header.

    Args:
        repo_url (str): The URL of the GitHub repository. A trailing slash or
            ``.git`` suffix is tolerated.
        pr_number (int): The pull request number.

    Returns:
        Mapping of file path to that file's diff text. For a renamed file the
        key is the new path.

    Raises:
        requests.HTTPError: If the server does not answer with status 200.
        requests.RequestException: On connection problems or timeout.
    """
    parts = repo_url.rstrip("/").split("/")
    owner, repo = parts[-2], parts[-1]
    if repo.endswith(".git"):
        repo = repo[: -len(".git")]
    pr_diff_url = f"https://patch-diff.githubusercontent.com/raw/{owner}/{repo}/pull/{pr_number}.diff"

    # TLS verification stays enabled (the default); a timeout keeps a stalled
    # connection from blocking the caller indefinitely.
    response = requests.get(pr_diff_url, timeout=30)
    if response.status_code != 200:
        log.error(f"Failed to fetch diff: {response.status_code}")
        # Raise instead of exit(): terminating the interpreter from a helper
        # would take the whole hosting process down.
        raise requests.HTTPError(
            f"Failed to fetch diff from {pr_diff_url}: HTTP {response.status_code}",
            response=response,
        )

    diff_text = response.text
    file_diffs = defaultdict(str)
    # First alternative: unchanged path (a/X b/X), captured in group 1.
    # Second alternative: renamed file (a/X b/Y), new path in group 2.
    header_pattern = re.compile(
        r'^diff --git (?:a/(.*?) b/\1|a/.*? b/(.*))$', re.MULTILINE
    )
    headers = list(header_pattern.finditer(diff_text))
    for i, match in enumerate(headers):
        file_path = match.group(1) if match.group(1) is not None else match.group(2)
        start = match.start()
        end = headers[i + 1].start() if i + 1 < len(headers) else len(diff_text)
        file_diffs[file_path] = diff_text[start:end]
    return file_diffs
| if __name__ == "__main__": | |
| # --------------------------------------------------------------------------- | |
| # For testing purposes, we can use a local directory or a GitHub repo URL. | |
| # --------------------------------------------------------------------------- | |
| # Test with a GitHub repo URL - Python repo | |
| log.info("-----------------Python Repo---------------------------------------") | |
| repo_url = 'https://github.com/huggingface/accelerate' | |
| # find a specific function | |
| target_name = "get_max_layer_size" | |
| contex =get_function_context_for_project(target_name,repo_url) | |
| log.info(contex) | |
| target_name = "get_max_layer_size" | |
| contex =get_function_context_for_project(target_name,repo_url) | |
| log.info(contex) | |
| log.info("------------------End Test Python Repo--------------------------------------") | |
| log.info("-----------------Go Repo---------------------------------------") | |
| repo_url = 'https://github.com/ngrok/ngrok-operator' | |
| # find a specific function | |
| target_name = "createKubernetesOperator" | |
| contex =get_function_context_for_project(target_name,repo_url) | |
| log.info(contex) | |
| target_name = "createKubernetesOperator" | |
| contex =get_function_context_for_project(target_name,repo_url) | |
| log.info(contex) | |
| # Test with a GitHub repo URL - Python repo | |