# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ Plain tool functions adapted from timemachine-bench-main/agents/src/utils/tools.py. All LangChain decorators removed; every parameter is an explicit keyword-only argument. Logic is preserved exactly from the original. """ from __future__ import annotations import re import subprocess import tempfile from pathlib import Path from typing import Dict, List, Optional, Tuple, Union def list_dir( *, host_repo_dir: str, repo_name: str, dir_path: Optional[str] = "/work", ) -> str: """List files and subdirectories under *dir_path* (default ``/work``).""" try: p = Path(dir_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_dir = Path(host_repo_dir) / rel_path files: list[str] = [] subdirectories: list[str] = [] if not host_dir.exists(): raise FileNotFoundError(f"The directory does not exist: {dir_path}") if not host_dir.is_dir(): raise NotADirectoryError(f"The path is not a directory: {dir_path}") for entry in host_dir.iterdir(): if entry.name.startswith("."): continue if entry.is_dir(): subdirectories.append(entry.name) elif entry.is_file(): files.append(entry.name) subdirectories = list(sorted([d + "/" for d in subdirectories])) files = list(sorted(files)) if dir_path == "/work": files = [ f for f in files if f not in {f"setup_{repo_name}.sh", f"test_{repo_name}.sh"} ] content_lst = subdirectories + files return ( f"Found {len(files)} file(s) and {len(subdirectories)} subdirectory(s) under {dir_path}:\n\n" f"\n" + "\n".join(content_lst) + "\n" ) except Exception as e: return f"Error listing the directory: {dir_path}. 
Here is the message: {str(e)}" def search_dir( *, regex_pattern: str, host_repo_dir: str, dir_path: Optional[str] = "/work", ) -> str: """Search for *regex_pattern* in all ``.py`` files under *dir_path*.""" try: ptn = re.compile(regex_pattern) p = Path(dir_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_dir = Path(host_repo_dir) / rel_path if not host_dir.exists(): raise FileNotFoundError(f"The directory does not exist: {dir_path}") if not host_dir.is_dir(): raise NotADirectoryError(f"The path is not a directory: {dir_path}") except Exception as e: return f'Search failed for "{regex_pattern}" in the directory. Here is the message: {str(e)}' matches: list[tuple[str, int]] = [] for fpath in Path(host_dir).rglob("*.py"): try: match_in_file = 0 with open(fpath, "r", newline="") as f: for line in f: line = line.rstrip() if ptn.search(line): match_in_file += 1 if match_in_file > 0: rel = fpath.relative_to(host_repo_dir) container_path = Path("/work") / rel matches.append((str(container_path), match_in_file)) except Exception: continue num_matches = len(matches) if not matches: return f'No matches found for "{regex_pattern}" in the directory.' if num_matches > 50: return ( f'More than 50 matches found for "{regex_pattern}" in the directory.' " Please write a more specific query to narrow down the results." 
) result_str = f'Found {len(matches)} match(es) for "{regex_pattern}" in the directory:\n\n' result_str += "\n".join(f"{fpath} ({count} match(es))" for fpath, count in matches) return result_str def search_file( *, regex_pattern: str, file_path: str, host_repo_dir: str, ) -> str: """Search for *regex_pattern* in the file at *file_path*.""" matches: list[tuple[int, str]] = [] try: ptn = re.compile(regex_pattern) p = Path(file_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_path = Path(host_repo_dir) / rel_path if not host_path.exists(): raise FileNotFoundError(f"The file does not exist: {file_path}") if not host_path.is_file(): raise IsADirectoryError(f"The path is not a file: {file_path}") with open(host_path, "r", newline="") as f: for i, line in enumerate(f, 1): line = line.rstrip() if ptn.search(line): matches.append((i, line)) if not matches: return f'No matches found for "{regex_pattern}" in {file_path}.' num_matches = len(matches) if num_matches > 50: return ( f'More than 50 matches found for "{regex_pattern}" in {file_path}.' " Please write a more specific query to narrow down the results." ) result_str = f'Found {len(matches)} match(es) for "{regex_pattern}" in {file_path}:\n\n' result_str += "\n".join(f"{i}: {line}" for i, line in matches) return result_str except Exception as e: return f'Search failed for "{regex_pattern}" in {file_path}. Here is the message: {str(e)}' def view_file( *, file_path: str, line_no: int, host_repo_dir: str, ) -> str: """Show ±50 lines around *line_no* in *file_path*.""" try: p = Path(file_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." 
in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_path = Path(host_repo_dir) / rel_path if not host_path.exists(): raise FileNotFoundError(f"The file does not exist: {file_path}") if not host_path.is_file(): raise IsADirectoryError(f"The path is not a file: {file_path}") target_lines: list[str] = [] start_line, end_line = max(1, line_no - 50), line_no + 50 with open(host_path, "r", newline="") as f: for i, line in enumerate(f, start=1): if start_line <= i <= end_line: target_lines.append(line) total_lines = i # noqa: F821 — set by the loop content = "".join( f"{i}: {line}" for i, line in enumerate(target_lines, start=start_line) ).rstrip() lines_before = start_line - 1 lines_after = total_lines - end_line if lines_before > 0: content = f"({lines_before} lines above)\n" + content if lines_after > 0: content = content + f"\n({lines_after} lines below)" return ( f"Here is the content of the file around line #{line_no}:\n\n" f"\n{content}\n" ) except Exception as e: return f"Error opening the file: {file_path}. Here is the message: {str(e)}" def edit_file( *, file_path: str, start_line: int, end_line: int, replacement_text: str, host_repo_dir: str, test_files: List[str], ) -> Dict[str, Optional[str]]: """Replace lines *start_line*–*end_line* (inclusive) with *replacement_text*.""" try: p = Path(file_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." 
in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_path = Path(host_repo_dir) / rel_path if not host_path.exists(): raise FileNotFoundError(f"The file does not exist: {file_path}") if not host_path.is_file(): raise IsADirectoryError(f"The path is not a file: {file_path}") rel_path_str = str(rel_path) if rel_path_str in test_files: return { "patch": None, "message": "Editing test files is not allowed. The file was kept unchanged.", } with open(host_path, "r", newline="") as f: original_lines = f.readlines() to_replace_content = "".join(original_lines[start_line - 1 : end_line]) newline = "" if to_replace_content.endswith("\r\n"): newline = "\r\n" elif to_replace_content.endswith("\n"): newline = "\n" replacement_text = replacement_text.rstrip(newline) + newline updated_lines = ( original_lines[: start_line - 1] + replacement_text.splitlines(keepends=True) + original_lines[end_line:] ) updated_content = "".join(updated_lines) with tempfile.NamedTemporaryFile("w", newline="") as tmpf: tmp_path = tmpf.name tmpf.write(updated_content) tmpf.flush() diff_proc = subprocess.run( [ "diff", "-u", "--label", str(rel_path), str(host_path), "--label", str(rel_path), str(tmp_path), ], stdout=subprocess.PIPE, check=False, ) patch_content = diff_proc.stdout.decode("utf-8") subprocess.run( ["patch", "-p0", "-d", str(host_repo_dir)], input=patch_content, text=True, check=True, capture_output=True, ) new_end_line = start_line + len(replacement_text.splitlines(keepends=True)) - 1 to_show_start_line = max(1, start_line - 50) to_show_end_line = new_end_line + 50 diff_lines = updated_lines[to_show_start_line - 1 : to_show_end_line] diff_content = "".join( f"{i}: {line}" for i, line in enumerate(diff_lines, start=to_show_start_line) ).rstrip() return { "patch": patch_content, "message": ( "Edit succeeded. 
Here is the updated part of the file.\n\n" f"\n{diff_content}\n" ), } except Exception as e: return { "patch": None, "message": f"Edit failed. The file was kept unchanged. Here is the message: {str(e)}", } def replace_all_in_file( *, file_path: str, regex_pattern: str, replacement_string: str, host_repo_dir: str, test_files: List[str], ) -> Dict[str, Optional[str]]: """Find-and-replace all occurrences of *regex_pattern* in *file_path*.""" try: p = Path(file_path) if p.is_absolute() and not p.is_relative_to("/work"): raise ValueError("You cannot access directories outside /work") if not p.is_absolute() and ".." in p.parts: raise ValueError("You cannot access directories outside /work") rel_path = p.relative_to("/work") if p.is_absolute() else p host_path = Path(host_repo_dir) / rel_path if not host_path.exists(): raise FileNotFoundError(f"The file does not exist: {file_path}") if not host_path.is_file(): raise IsADirectoryError(f"The path is not a file: {file_path}") rel_path_str = str(rel_path) if rel_path_str in test_files: return { "patch": None, "message": "Editing test files is not allowed. The file was kept unchanged.", } with open(host_path, "r", newline="") as f: original_content = f.read() updated_content, num_replacements = re.subn( regex_pattern, replacement_string, original_content ) if num_replacements == 0: return { "patch": None, "message": f'No matches found for "{regex_pattern}" in {file_path}. 
The file was kept unchanged.', } with tempfile.NamedTemporaryFile("w", newline="") as tmpf: tmp_path = tmpf.name tmpf.write(updated_content) tmpf.flush() diff_proc = subprocess.run( [ "diff", "-u", "--label", str(rel_path), str(host_path), "--label", str(rel_path), str(tmp_path), ], stdout=subprocess.PIPE, check=False, ) patch_content = diff_proc.stdout.decode("utf-8") subprocess.run( ["patch", "-p0", "-d", str(host_repo_dir)], input=patch_content, text=True, check=True, capture_output=True, ) return { "patch": patch_content, "message": ( f'Replaced {num_replacements} occurrences of "{regex_pattern}" ' f'with "{replacement_string}" in {file_path}.' ), } except Exception as e: return { "patch": None, "message": f"Replacement failed. The file was kept unchanged. Here is the message: {str(e)}", } def revert_last( *, last_patch: Optional[Tuple[str, str]], host_repo_dir: str, ) -> str: """Revert the last edit using the stored patch.""" def _extract_diff_lines_to_show(patch: str) -> List[Tuple[int, int]]: lines_to_show: list[tuple[int, int]] = [] for diff_header in re.finditer( r"^@@ \-(\d+)(?:,(\d+))? \+\d+(?:,\d+)? 
@@", patch, re.MULTILINE ): diff_start = int(diff_header.group(1)) num_affected_lines = int(diff_header.group(2) or 1) diff_end = diff_start + num_affected_lines - 1 line_start = max(1, diff_start - 50) line_end = diff_end + 50 lines_to_show.append((line_start, line_end)) if not lines_to_show: return [] sorted_lines = list(sorted(lines_to_show, key=lambda x: x[0])) merged: list[tuple[int, int]] = [sorted_lines[0]] current_start, current_end = merged[0] for start, end in sorted_lines[1:]: if start <= current_end: current_end = end merged[-1] = (current_start, current_end) else: merged.append((start, end)) current_start, current_end = start, end return merged def _format_diff_blocks( file_path: str, file_lines: List[str], diff_lines: List[Tuple[int, int]], ) -> str: parts = [f"{file_path}", ""] for start, end in diff_lines: target_lines = file_lines[start - 1 : end] part_content = "".join( f"{i}: {line}" for i, line in enumerate(target_lines, start=start) ).rstrip() parts.append(f" \n{part_content}\n ") parts.append("") return "\n".join(parts) if last_patch is None: return "No edit to revert. The file was kept unchanged." try: last_modified_path, patch_content = last_patch rel_path = Path(last_modified_path).relative_to("/work") host_path = Path(host_repo_dir) / rel_path if not host_path.exists(): raise FileNotFoundError(f"The file does not exist: {last_modified_path}") if not host_path.is_file(): raise IsADirectoryError(f"The path is not a file: {last_modified_path}") subprocess.run( ["patch", "-R", "-p0", "-d", str(host_repo_dir)], input=patch_content, text=True, check=True, capture_output=True, ) with open(host_path, "r", newline="") as f: updated_lines = f.readlines() diff_lines = _extract_diff_lines_to_show(patch_content) diff_block_str = _format_diff_blocks(last_modified_path, updated_lines, diff_lines) return "Revert succeeded. Here are the updated parts of the file.\n\n" + diff_block_str except Exception as e: return f"Revert failed. 
The file was kept unchanged. Here is the message: {str(e)}" def execute_tests( *, host_repo_dir: str, image_name: str, sec_timeout: int, mem_limit: str, ) -> Dict[str, Union[str, Optional[int]]]: """Run tests. Uses Docker if available, otherwise runs via venv directly.""" import shutil as _shutil # Check if Docker is available docker_available = False if _shutil.which("docker"): try: subprocess.run(["docker", "info"], capture_output=True, timeout=5, check=True) docker_available = True except Exception: pass if docker_available: return _execute_tests_docker( host_repo_dir=host_repo_dir, image_name=image_name, sec_timeout=sec_timeout, mem_limit=mem_limit, ) else: return _execute_tests_venv( host_repo_dir=host_repo_dir, image_name=image_name, sec_timeout=sec_timeout, ) def _execute_tests_venv( *, host_repo_dir: str, image_name: str, sec_timeout: int, ) -> Dict[str, Union[str, Optional[int]]]: """Run tests directly using the venv created during setup.""" try: venv_python = str(Path(host_repo_dir) / ".venv" / "bin" / "python") # Find the test script # image_name is like "brmc__shortbus_new", strip the "_new" suffix base_name = image_name.replace("_new", "") test_script = Path(host_repo_dir) / f"test_{base_name}.sh" if not test_script.exists(): # Try to find any test script candidates = list(Path(host_repo_dir).glob("test_*.sh")) if candidates: test_script = candidates[0] else: return { "test_result": "No test script found.", "full_log": None, "container_status": 1, } # Parse the test script to extract the test command test_cmd = None with open(test_script) as f: for line in f: line = line.strip() if not line or line.startswith("#") or line.startswith("set "): continue test_cmd = line break if not test_cmd: return { "test_result": "Test script is empty.", "full_log": None, "container_status": 1, } # Replace python/python3 with venv python test_cmd = test_cmd.replace("python3 ", f"{venv_python} ").replace("python ", f"{venv_python} ") print(f"[NO-DOCKER] Running tests: 
{test_cmd}", flush=True) result = subprocess.run( test_cmd, shell=True, cwd=host_repo_dir, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, timeout=sec_timeout, ) test_log = result.stdout container_status = result.returncode # Truncate to last 100 lines with line numbers lines = test_log.splitlines(keepends=True) num_lines = len(lines) lines_to_show = lines[-100:] start_line = num_lines - len(lines_to_show) + 1 truncated_log = "".join( f"{i}: {line}" for i, line in enumerate(lines_to_show, start=start_line) ) lines_before = max(0, num_lines - 100) if lines_before > 0: truncated_log = f"({lines_before} lines above)\n" + truncated_log truncated_log = truncated_log.rstrip() return { "test_result": ( "Test execution completed. Here is the test log.\n\n" f"\n{truncated_log}\n" ), "full_log": test_log, "container_status": container_status, } except subprocess.TimeoutExpired: return { "test_result": f"Test execution timed out after {sec_timeout}s.", "full_log": None, "container_status": 1, } except Exception as e: return { "test_result": f"Test execution failed: {e}", "full_log": None, "container_status": None, } def _execute_tests_docker( *, host_repo_dir: str, image_name: str, sec_timeout: int, mem_limit: str, ) -> Dict[str, Union[str, Optional[int]]]: """Run the Docker container and return truncated log + exit code.""" try: container_name = image_name command = [ "timeout", "--foreground", "-s", "KILL", str(sec_timeout), "docker", "run", "-v", f"{host_repo_dir}:/work", f"--memory={mem_limit}", "--name", image_name, container_name, ] result = subprocess.run( command, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False, ) test_log = result.stdout container_status = result.returncode # truncate: last 100 lines lines = test_log.splitlines(keepends=True) num_lines = len(lines) lines_to_show = lines[-100:] start_line = num_lines - len(lines_to_show) + 1 truncated_log = "".join( f"{i}: {line}" for i, line in enumerate(lines_to_show, 
start=start_line) ) lines_before = max(0, num_lines - 100) if lines_before > 0: truncated_log = f"({lines_before} lines above)\n" + truncated_log truncated_log = truncated_log.rstrip() return { "test_result": ( "Test execution completed. Here is the test log.\n\n" f"\n{truncated_log}\n" ), "full_log": test_log, "container_status": container_status, } except Exception: return { "test_result": "Test execution failed.", "full_log": None, "container_status": None, } finally: try: subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) except Exception: pass def search_last_log( *, regex_pattern: str, last_log_path: str, ) -> str: """Search the last test log for *regex_pattern*.""" matches: list[tuple[int, str]] = [] try: ptn = re.compile(regex_pattern) with open(last_log_path, "r", newline="") as f: for i, line in enumerate(f, start=1): line = line.rstrip() if ptn.search(line): matches.append((i, line)) if not matches: return f'No matches found for "{regex_pattern}" in the last test log.' num_matches = len(matches) if num_matches > 50: return ( f'More than 50 matches found for "{regex_pattern}" in the last test log.' " Please write a more specific query to narrow down the results." ) result_str = f'Found {len(matches)} match(es) for "{regex_pattern}" in the last test log:\n\n' result_str += "\n".join(f"{i}: {line}" for i, line in matches[:50]) return result_str except Exception: return f'Search failed for "{regex_pattern}" in the last test log.' 
def view_last_log(
    *,
    line_no: int,
    last_log_path: str,
) -> str:
    """Return the last test log's lines within 50 of *line_no*, numbered.

    Every shown line is prefixed with its 1-based position in the log. If the
    log cannot be read, a fixed error string is returned instead of raising.
    """
    first = max(1, line_no - 50)
    last = line_no + 50
    try:
        with open(last_log_path, "r", newline="") as log_file:
            window = [
                f"{number}: {text}"
                for number, text in enumerate(log_file, start=1)
                if first <= number <= last
            ]
    except Exception:
        return "Error opening the last test log."
    numbered = "".join(window).rstrip()
    return (
        f"Here is the content of the last test log around line #{line_no}.\n\n"
        f"\n{numbered}\n"
    )