# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Plain tool functions adapted from timemachine-bench-main/agents/src/utils/tools.py.
All LangChain decorators removed; every parameter is an explicit keyword-only argument.
Logic is preserved exactly from the original.
"""
from __future__ import annotations
import re
import subprocess
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
def list_dir(
    *,
    host_repo_dir: str,
    repo_name: str,
    dir_path: Optional[str] = "/work",
) -> str:
    """List files and subdirectories under *dir_path* (default ``/work``).

    *dir_path* is a container path, either absolute under ``/work`` or relative
    to it, and is mapped onto *host_repo_dir* on the host. ``None`` is treated
    as ``/work``. Entries whose name starts with ``.`` are skipped. At the
    repository root the ``setup_<repo_name>.sh`` / ``test_<repo_name>.sh``
    harness scripts are hidden, however the root is spelled.

    Returns a listing with subdirectories first (each with a trailing ``/``),
    then files, both sorted. On any failure it returns an error message string
    instead of raising.
    """
    if dir_path is None:
        dir_path = "/work"
    try:
        p = Path(dir_path)
        # PurePath.is_relative_to is purely lexical, so "/work/../etc" would pass
        # the prefix test; reject ".." for absolute and relative paths alike.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_dir = Path(host_repo_dir) / rel_path
        if not host_dir.exists():
            raise FileNotFoundError(f"The directory does not exist: {dir_path}")
        if not host_dir.is_dir():
            raise NotADirectoryError(f"The path is not a directory: {dir_path}")
        files: list[str] = []
        subdirectories: list[str] = []
        for entry in host_dir.iterdir():
            if entry.name.startswith("."):
                continue
            if entry.is_dir():
                subdirectories.append(entry.name)
            elif entry.is_file():
                files.append(entry.name)
        subdirectories = sorted(d + "/" for d in subdirectories)
        files = sorted(files)
        # "/work", "/work/", "." and "" all map to an empty relative path; hide
        # the harness scripts whenever the repository root itself is listed.
        if not rel_path.parts:
            hidden = {f"setup_{repo_name}.sh", f"test_{repo_name}.sh"}
            files = [f for f in files if f not in hidden]
        content_lst = subdirectories + files
        return (
            f"Found {len(files)} file(s) and {len(subdirectories)} subdirectory(s) under {dir_path}:\n\n"
            f"\n" + "\n".join(content_lst) + "\n"
        )
    except Exception as e:
        return f"Error listing the directory: {dir_path}. Here is the message: {str(e)}"
def search_dir(
    *,
    regex_pattern: str,
    host_repo_dir: str,
    dir_path: Optional[str] = "/work",
) -> str:
    """Search for *regex_pattern* in all ``.py`` files under *dir_path*.

    *dir_path* is a container path (absolute under ``/work`` or relative to it;
    ``None`` means ``/work``) mapped onto *host_repo_dir*. Hidden files and
    directories (names starting with ``.``) are skipped, matching ``list_dir``.
    This matters in particular for the ``.venv`` used by no-Docker test runs,
    whose site-packages would otherwise dominate the results.

    Each line is matched after stripping trailing whitespace. The result lists
    every matching file as a ``/work/...`` path with its per-file match count,
    in sorted walk order. More than 50 matching files yields a "be more
    specific" message. Errors are returned as a message string, never raised.
    """
    import os

    if dir_path is None:
        dir_path = "/work"
    try:
        ptn = re.compile(regex_pattern)
        p = Path(dir_path)
        # PurePath.is_relative_to is lexical; "/work/../x" must be rejected explicitly.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_dir = Path(host_repo_dir) / rel_path
        if not host_dir.exists():
            raise FileNotFoundError(f"The directory does not exist: {dir_path}")
        if not host_dir.is_dir():
            raise NotADirectoryError(f"The path is not a directory: {dir_path}")
    except Exception as e:
        return f'Search failed for "{regex_pattern}" in the directory. Here is the message: {str(e)}'
    matches: list[tuple[str, int]] = []
    for root, dirnames, filenames in os.walk(host_dir):
        # Prune hidden directories in place so os.walk never descends into them;
        # sorting makes the output order deterministic.
        dirnames[:] = sorted(d for d in dirnames if not d.startswith("."))
        for name in sorted(filenames):
            if name.startswith(".") or not name.endswith(".py"):
                continue
            fpath = Path(root) / name
            try:
                with open(fpath, "r", newline="") as f:
                    match_in_file = sum(1 for line in f if ptn.search(line.rstrip()))
                if match_in_file > 0:
                    container_path = Path("/work") / fpath.relative_to(host_repo_dir)
                    matches.append((str(container_path), match_in_file))
            except Exception:
                # Unreadable / undecodable files are skipped, as before.
                continue
    if not matches:
        return f'No matches found for "{regex_pattern}" in the directory.'
    if len(matches) > 50:
        return (
            f'More than 50 matches found for "{regex_pattern}" in the directory.'
            " Please write a more specific query to narrow down the results."
        )
    result_str = f'Found {len(matches)} match(es) for "{regex_pattern}" in the directory:\n\n'
    result_str += "\n".join(f"{fpath} ({count} match(es))" for fpath, count in matches)
    return result_str
def search_file(
    *,
    regex_pattern: str,
    file_path: str,
    host_repo_dir: str,
) -> str:
    """Search for *regex_pattern* in the file at *file_path*.

    *file_path* is a container path (absolute under ``/work`` or relative to it)
    mapped onto *host_repo_dir*. Lines are matched after stripping trailing
    whitespace and reported as ``<line_no>: <line>``. More than 50 matches
    yields a "be more specific" message. Errors are returned as a message
    string, never raised.
    """
    matches: list[tuple[int, str]] = []
    try:
        ptn = re.compile(regex_pattern)
        p = Path(file_path)
        # PurePath.is_relative_to is lexical; "/work/../x" must be rejected explicitly.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_path = Path(host_repo_dir) / rel_path
        if not host_path.exists():
            raise FileNotFoundError(f"The file does not exist: {file_path}")
        if not host_path.is_file():
            raise IsADirectoryError(f"The path is not a file: {file_path}")
        with open(host_path, "r", newline="") as f:
            for i, line in enumerate(f, 1):
                line = line.rstrip()
                if ptn.search(line):
                    matches.append((i, line))
        if not matches:
            return f'No matches found for "{regex_pattern}" in {file_path}.'
        if len(matches) > 50:
            return (
                f'More than 50 matches found for "{regex_pattern}" in {file_path}.'
                " Please write a more specific query to narrow down the results."
            )
        result_str = f'Found {len(matches)} match(es) for "{regex_pattern}" in {file_path}:\n\n'
        result_str += "\n".join(f"{i}: {line}" for i, line in matches)
        return result_str
    except Exception as e:
        return f'Search failed for "{regex_pattern}" in {file_path}. Here is the message: {str(e)}'
def view_file(
    *,
    file_path: str,
    line_no: int,
    host_repo_dir: str,
) -> str:
    """Show ±50 lines around *line_no* in *file_path*.

    *file_path* is a container path (absolute under ``/work`` or relative to it)
    mapped onto *host_repo_dir*. Lines are numbered from 1. The output notes how
    many lines lie above and below the shown window. An empty file yields an
    empty window rather than an error. Errors are returned as a message string,
    never raised.
    """
    try:
        p = Path(file_path)
        # PurePath.is_relative_to is lexical; "/work/../x" must be rejected explicitly.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_path = Path(host_repo_dir) / rel_path
        if not host_path.exists():
            raise FileNotFoundError(f"The file does not exist: {file_path}")
        if not host_path.is_file():
            raise IsADirectoryError(f"The path is not a file: {file_path}")
        target_lines: list[str] = []
        start_line, end_line = max(1, line_no - 50), line_no + 50
        # Initialised up front: previously an empty file left the loop variable
        # unbound and the NameError surfaced as "Error opening the file".
        total_lines = 0
        with open(host_path, "r", newline="") as f:
            for total_lines, line in enumerate(f, start=1):
                if start_line <= total_lines <= end_line:
                    target_lines.append(line)
        content = "".join(
            f"{i}: {line}" for i, line in enumerate(target_lines, start=start_line)
        ).rstrip()
        lines_before = start_line - 1
        lines_after = total_lines - end_line
        if lines_before > 0:
            content = f"({lines_before} lines above)\n" + content
        if lines_after > 0:
            content = content + f"\n({lines_after} lines below)"
        return (
            f"Here is the content of the file around line #{line_no}:\n\n"
            f"\n{content}\n"
        )
    except Exception as e:
        return f"Error opening the file: {file_path}. Here is the message: {str(e)}"
def edit_file(
    *,
    file_path: str,
    start_line: int,
    end_line: int,
    replacement_text: str,
    host_repo_dir: str,
    test_files: List[str],
) -> Dict[str, Optional[str]]:
    """Replace lines *start_line*–*end_line* (inclusive) with *replacement_text*.

    *file_path* is a container path (absolute under ``/work`` or relative to it)
    mapped onto *host_repo_dir*. Line numbers are 1-based and the range must
    satisfy ``1 <= start_line <= end_line``. Files whose repo-relative path is
    in *test_files* are refused.

    The change is computed as a unified diff (``diff -u``) against a temporary
    copy and applied with ``patch -p0``. Returns ``{"patch": <diff text>,
    "message": ...}`` on success. On failure ``patch`` is ``None`` and
    ``message`` explains why; this function does not raise.
    """
    try:
        p = Path(file_path)
        # PurePath.is_relative_to is lexical; "/work/../x" must be rejected explicitly.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_path = Path(host_repo_dir) / rel_path
        if not host_path.exists():
            raise FileNotFoundError(f"The file does not exist: {file_path}")
        if not host_path.is_file():
            raise IsADirectoryError(f"The path is not a file: {file_path}")
        rel_path_str = str(rel_path)
        if rel_path_str in test_files:
            return {
                "patch": None,
                "message": "Editing test files is not allowed. The file was kept unchanged.",
            }
        # A reversed range makes lines[:start-1] and lines[end:] overlap, which
        # silently duplicates lines; 0/negative numbers would wrap around.
        if start_line < 1 or end_line < start_line:
            raise ValueError(
                f"Invalid line range {start_line}-{end_line}: line numbers are "
                "1-based and start_line must not exceed end_line."
            )
        with open(host_path, "r", newline="") as f:
            original_lines = f.readlines()
        # Keep the line ending of the replaced span: the replacement ends with
        # exactly that terminator (or none, if the span was the unterminated
        # last line of the file).
        to_replace_content = "".join(original_lines[start_line - 1 : end_line])
        if to_replace_content.endswith("\r\n"):
            newline = "\r\n"
        elif to_replace_content.endswith("\n"):
            newline = "\n"
        else:
            newline = ""
        replacement_text = replacement_text.rstrip(newline) + newline
        replacement_lines = replacement_text.splitlines(keepends=True)
        updated_lines = (
            original_lines[: start_line - 1]
            + replacement_lines
            + original_lines[end_line:]
        )
        updated_content = "".join(updated_lines)
        # An unchanged file yields an empty diff, which `patch` rejects with an
        # opaque "Only garbage was found" error; report it clearly instead.
        if updated_content == "".join(original_lines):
            raise ValueError("The replacement text is identical to the original lines.")
        with tempfile.NamedTemporaryFile("w", newline="") as tmpf:
            tmp_path = tmpf.name
            tmpf.write(updated_content)
            tmpf.flush()
            # diff exits 1 when the files differ, so check=False is required.
            diff_proc = subprocess.run(
                [
                    "diff", "-u",
                    "--label", str(rel_path),
                    str(host_path),
                    "--label", str(rel_path),
                    str(tmp_path),
                ],
                stdout=subprocess.PIPE,
                check=False,
            )
        patch_content = diff_proc.stdout.decode("utf-8")
        subprocess.run(
            ["patch", "-p0", "-d", str(host_repo_dir)],
            input=patch_content,
            text=True,
            check=True,
            capture_output=True,
        )
        new_end_line = start_line + len(replacement_lines) - 1
        to_show_start_line = max(1, start_line - 50)
        to_show_end_line = new_end_line + 50
        diff_lines = updated_lines[to_show_start_line - 1 : to_show_end_line]
        diff_content = "".join(
            f"{i}: {line}" for i, line in enumerate(diff_lines, start=to_show_start_line)
        ).rstrip()
        return {
            "patch": patch_content,
            "message": (
                "Edit succeeded. Here is the updated part of the file.\n\n"
                f"\n{diff_content}\n"
            ),
        }
    except Exception as e:
        return {
            "patch": None,
            "message": f"Edit failed. The file was kept unchanged. Here is the message: {str(e)}",
        }
def replace_all_in_file(
    *,
    file_path: str,
    regex_pattern: str,
    replacement_string: str,
    host_repo_dir: str,
    test_files: List[str],
) -> Dict[str, Optional[str]]:
    """Find-and-replace all occurrences of *regex_pattern* in *file_path*.

    *file_path* is a container path (absolute under ``/work`` or relative to it)
    mapped onto *host_repo_dir*. Files whose repo-relative path is in
    *test_files* are refused. The substitution uses ``re.subn``, so
    *replacement_string* may contain group references.

    The change is computed with ``diff -u`` and applied with ``patch -p0``.
    Returns ``{"patch": <diff text>, "message": ...}`` on success. Otherwise
    ``patch`` is ``None`` and ``message`` explains why; this function does not
    raise.
    """
    try:
        p = Path(file_path)
        # PurePath.is_relative_to is lexical; "/work/../x" must be rejected explicitly.
        if ".." in p.parts or (p.is_absolute() and not p.is_relative_to("/work")):
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work") if p.is_absolute() else p
        host_path = Path(host_repo_dir) / rel_path
        if not host_path.exists():
            raise FileNotFoundError(f"The file does not exist: {file_path}")
        if not host_path.is_file():
            raise IsADirectoryError(f"The path is not a file: {file_path}")
        rel_path_str = str(rel_path)
        if rel_path_str in test_files:
            return {
                "patch": None,
                "message": "Editing test files is not allowed. The file was kept unchanged.",
            }
        with open(host_path, "r", newline="") as f:
            original_content = f.read()
        updated_content, num_replacements = re.subn(
            regex_pattern, replacement_string, original_content
        )
        if num_replacements == 0:
            return {
                "patch": None,
                "message": f'No matches found for "{regex_pattern}" in {file_path}. The file was kept unchanged.',
            }
        # Matches that are replaced by identical text produce an empty diff,
        # which `patch` rejects with an opaque error; report it clearly instead.
        if updated_content == original_content:
            raise ValueError("Applying the replacement leaves the file identical.")
        with tempfile.NamedTemporaryFile("w", newline="") as tmpf:
            tmp_path = tmpf.name
            tmpf.write(updated_content)
            tmpf.flush()
            # diff exits 1 when the files differ, so check=False is required.
            diff_proc = subprocess.run(
                [
                    "diff", "-u",
                    "--label", str(rel_path),
                    str(host_path),
                    "--label", str(rel_path),
                    str(tmp_path),
                ],
                stdout=subprocess.PIPE,
                check=False,
            )
        patch_content = diff_proc.stdout.decode("utf-8")
        subprocess.run(
            ["patch", "-p0", "-d", str(host_repo_dir)],
            input=patch_content,
            text=True,
            check=True,
            capture_output=True,
        )
        return {
            "patch": patch_content,
            "message": (
                f'Replaced {num_replacements} occurrences of "{regex_pattern}" '
                f'with "{replacement_string}" in {file_path}.'
            ),
        }
    except Exception as e:
        return {
            "patch": None,
            "message": f"Replacement failed. The file was kept unchanged. Here is the message: {str(e)}",
        }
def revert_last(
    *,
    last_patch: Optional[Tuple[str, str]],
    host_repo_dir: str,
) -> str:
    """Revert the last edit using the stored patch.

    *last_patch* is a ``(container_path, unified_diff)`` pair, or ``None`` when
    there is nothing to revert. The diff is applied in reverse with
    ``patch -R -p0`` inside *host_repo_dir*. On success the reverted regions
    (±50 lines around each hunk, overlapping windows merged) are shown. Errors
    are returned as a message string, never raised.
    """

    def _extract_diff_lines_to_show(patch: str) -> List[Tuple[int, int]]:
        """Return sorted, merged 1-based (start, end) windows around each hunk."""
        lines_to_show: list[tuple[int, int]] = []
        # The "-" side of each hunk header describes the pre-edit file, which is
        # what the file looks like again after the reverse patch.
        for diff_header in re.finditer(
            r"^@@ \-(\d+)(?:,(\d+))? \+\d+(?:,\d+)? @@", patch, re.MULTILINE
        ):
            diff_start = int(diff_header.group(1))
            num_affected_lines = int(diff_header.group(2) or 1)
            diff_end = diff_start + num_affected_lines - 1
            lines_to_show.append((max(1, diff_start - 50), diff_end + 50))
        lines_to_show.sort(key=lambda x: x[0])
        merged: list[tuple[int, int]] = []
        for start, end in lines_to_show:
            if merged and start <= merged[-1][1]:
                # Use max so a window nested inside the previous one cannot
                # shrink the merged range.
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    def _format_diff_blocks(
        file_path: str,
        file_lines: List[str],
        diff_lines: List[Tuple[int, int]],
    ) -> str:
        """Render each (start, end) window of *file_lines* with line numbers."""
        parts = [f"{file_path}", ""]
        for start, end in diff_lines:
            target_lines = file_lines[start - 1 : end]
            part_content = "".join(
                f"{i}: {line}" for i, line in enumerate(target_lines, start=start)
            ).rstrip()
            parts.append(f" \n{part_content}\n ")
            parts.append("")
        return "\n".join(parts)

    if last_patch is None:
        return "No edit to revert. The file was kept unchanged."
    try:
        last_modified_path, patch_content = last_patch
        p = Path(last_modified_path)
        # relative_to is lexical, so guard against "/work/../x" explicitly.
        if ".." in p.parts:
            raise ValueError("You cannot access directories outside /work")
        rel_path = p.relative_to("/work")
        host_path = Path(host_repo_dir) / rel_path
        if not host_path.exists():
            raise FileNotFoundError(f"The file does not exist: {last_modified_path}")
        if not host_path.is_file():
            raise IsADirectoryError(f"The path is not a file: {last_modified_path}")
        subprocess.run(
            ["patch", "-R", "-p0", "-d", str(host_repo_dir)],
            input=patch_content,
            text=True,
            check=True,
            capture_output=True,
        )
        with open(host_path, "r", newline="") as f:
            updated_lines = f.readlines()
        diff_lines = _extract_diff_lines_to_show(patch_content)
        diff_block_str = _format_diff_blocks(last_modified_path, updated_lines, diff_lines)
        return "Revert succeeded. Here are the updated parts of the file.\n\n" + diff_block_str
    except Exception as e:
        return f"Revert failed. The file was kept unchanged. Here is the message: {str(e)}"
def execute_tests(
    *,
    host_repo_dir: str,
    image_name: str,
    sec_timeout: int,
    mem_limit: str,
) -> Dict[str, Union[str, Optional[int]]]:
    """Run tests. Uses Docker if available, otherwise runs via venv directly.

    Docker counts as available only when a ``docker`` executable is on PATH
    *and* ``docker info`` exits successfully within 5 seconds. Otherwise the
    tests run through ``_execute_tests_venv``, which does not apply
    *mem_limit*.
    """
    import shutil as _shutil

    def _docker_usable() -> bool:
        # Cheap PATH lookup first; only then ask the daemon whether it responds.
        if not _shutil.which("docker"):
            return False
        try:
            subprocess.run(["docker", "info"], capture_output=True, timeout=5, check=True)
        except Exception:
            return False
        return True

    if not _docker_usable():
        return _execute_tests_venv(
            host_repo_dir=host_repo_dir,
            image_name=image_name,
            sec_timeout=sec_timeout,
        )
    return _execute_tests_docker(
        host_repo_dir=host_repo_dir,
        image_name=image_name,
        sec_timeout=sec_timeout,
        mem_limit=mem_limit,
    )
def _execute_tests_venv(
*,
host_repo_dir: str,
image_name: str,
sec_timeout: int,
) -> Dict[str, Union[str, Optional[int]]]:
"""Run tests directly using the venv created during setup."""
try:
venv_python = str(Path(host_repo_dir) / ".venv" / "bin" / "python")
# Find the test script
# image_name is like "brmc__shortbus_new", strip the "_new" suffix
base_name = image_name.replace("_new", "")
test_script = Path(host_repo_dir) / f"test_{base_name}.sh"
if not test_script.exists():
# Try to find any test script
candidates = list(Path(host_repo_dir).glob("test_*.sh"))
if candidates:
test_script = candidates[0]
else:
return {
"test_result": "No test script found.",
"full_log": None,
"container_status": 1,
}
# Parse the test script to extract the test command
test_cmd = None
with open(test_script) as f:
for line in f:
line = line.strip()
if not line or line.startswith("#") or line.startswith("set "):
continue
test_cmd = line
break
if not test_cmd:
return {
"test_result": "Test script is empty.",
"full_log": None,
"container_status": 1,
}
# Replace python/python3 with venv python
test_cmd = test_cmd.replace("python3 ", f"{venv_python} ").replace("python ", f"{venv_python} ")
print(f"[NO-DOCKER] Running tests: {test_cmd}", flush=True)
result = subprocess.run(
test_cmd,
shell=True,
cwd=host_repo_dir,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=sec_timeout,
)
test_log = result.stdout
container_status = result.returncode
# Truncate to last 100 lines with line numbers
lines = test_log.splitlines(keepends=True)
num_lines = len(lines)
lines_to_show = lines[-100:]
start_line = num_lines - len(lines_to_show) + 1
truncated_log = "".join(
f"{i}: {line}" for i, line in enumerate(lines_to_show, start=start_line)
)
lines_before = max(0, num_lines - 100)
if lines_before > 0:
truncated_log = f"({lines_before} lines above)\n" + truncated_log
truncated_log = truncated_log.rstrip()
return {
"test_result": (
"Test execution completed. Here is the test log.\n\n"
f"\n{truncated_log}\n"
),
"full_log": test_log,
"container_status": container_status,
}
except subprocess.TimeoutExpired:
return {
"test_result": f"Test execution timed out after {sec_timeout}s.",
"full_log": None,
"container_status": 1,
}
except Exception as e:
return {
"test_result": f"Test execution failed: {e}",
"full_log": None,
"container_status": None,
}
def _execute_tests_docker(
    *,
    host_repo_dir: str,
    image_name: str,
    sec_timeout: int,
    mem_limit: str,
) -> Dict[str, Union[str, Optional[int]]]:
    """Run the Docker container and return truncated log + exit code.

    *host_repo_dir* is bind-mounted at ``/work`` in a container named
    *image_name*, started from the image of the same name with
    ``--memory=<mem_limit>``. The whole ``docker run`` is wrapped in
    ``timeout --foreground -s KILL <sec_timeout>``.

    Returns a dict with:

    * ``test_result``: the last 100 numbered log lines.
    * ``full_log``: combined stdout/stderr, or ``None`` on failure.
    * ``container_status``: exit code of the wrapped command, or ``None`` on
      failure.

    The named container is force-removed before the run (best effort) and
    again afterwards.
    """
    container_name = image_name

    def _remove_container() -> None:
        # Best-effort cleanup: a missing container or docker binary is not an error.
        try:
            subprocess.run(["docker", "rm", "-f", container_name], capture_output=True)
        except Exception:
            pass

    # A container left behind by an earlier run whose cleanup failed would make
    # `docker run --name` fail with a name conflict, so clear it first.
    _remove_container()
    try:
        command = [
            "timeout",
            "--foreground",
            "-s",
            "KILL",
            str(sec_timeout),
            "docker",
            "run",
            "-v",
            f"{host_repo_dir}:/work",
            f"--memory={mem_limit}",
            "--name",
            image_name,
            container_name,
        ]
        result = subprocess.run(
            command,
            text=True,
            errors="replace",  # undecodable test output must not abort the run
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            check=False,
        )
        test_log = result.stdout
        container_status = result.returncode
        # Truncate to the last 100 lines, numbered by their position in the full log.
        lines = test_log.splitlines(keepends=True)
        num_lines = len(lines)
        lines_to_show = lines[-100:]
        start_line = num_lines - len(lines_to_show) + 1
        truncated_log = "".join(
            f"{i}: {line}" for i, line in enumerate(lines_to_show, start=start_line)
        )
        lines_before = max(0, num_lines - 100)
        if lines_before > 0:
            truncated_log = f"({lines_before} lines above)\n" + truncated_log
        truncated_log = truncated_log.rstrip()
        return {
            "test_result": (
                "Test execution completed. Here is the test log.\n\n"
                f"\n{truncated_log}\n"
            ),
            "full_log": test_log,
            "container_status": container_status,
        }
    except Exception:
        return {
            "test_result": "Test execution failed.",
            "full_log": None,
            "container_status": None,
        }
    finally:
        _remove_container()
def search_last_log(
    *,
    regex_pattern: str,
    last_log_path: str,
) -> str:
    """Search the last test log for *regex_pattern*.

    Each log line is matched after stripping trailing whitespace. Hits are
    reported as ``<line_no>: <line>``, and more than 50 hits produce a
    "be more specific" message. An invalid pattern or unreadable log gives a
    generic failure message instead of an exception.
    """
    try:
        compiled = re.compile(regex_pattern)
        with open(last_log_path, "r", newline="") as log_file:
            stripped_lines = ((number, raw.rstrip()) for number, raw in enumerate(log_file, start=1))
            hits = [(number, text) for number, text in stripped_lines if compiled.search(text)]
    except Exception:
        return f'Search failed for "{regex_pattern}" in the last test log.'

    if not hits:
        return f'No matches found for "{regex_pattern}" in the last test log.'
    if len(hits) > 50:
        return (
            f'More than 50 matches found for "{regex_pattern}" in the last test log.'
            " Please write a more specific query to narrow down the results."
        )
    header = f'Found {len(hits)} match(es) for "{regex_pattern}" in the last test log:\n\n'
    body = "\n".join(f"{number}: {text}" for number, text in hits)
    return header + body
def view_last_log(
    *,
    line_no: int,
    last_log_path: str,
) -> str:
    """Show ±50 lines around *line_no* in the last test log.

    Lines are numbered from 1 and the window is clamped at the start of the
    log. An unreadable log yields a generic error message instead of an
    exception.
    """
    first = max(1, line_no - 50)
    last = line_no + 50
    try:
        with open(last_log_path, "r", newline="") as log_file:
            window = [
                text
                for number, text in enumerate(log_file, start=1)
                if first <= number <= last
            ]
    except Exception:
        return "Error opening the last test log."
    numbered = "".join(
        f"{number}: {text}" for number, text in enumerate(window, start=first)
    )
    return (
        f"Here is the content of the last test log around line #{line_no}.\n\n"
        f"\n{numbered.rstrip()}\n"
    )