""" CodeWorkspace — read-only filesystem + fake-git layer for Phase 2 exploration. Each scenario points `CodeContext.repo_snapshot_path` at a directory under `snapshots/`. That directory contains: snapshots// tree/ ← actual source files the agent reads /.py ... git_log.json ← list of commits (sha, author, date, message, files[]) diffs/.patch ← unified diff for that commit (any file path) This is a "fake git" by design — it's tighter, deterministic, and trivially serializable for trajectory replay. No subprocess, no real .git directory. CodeWorkspace is constructed at the start of Phase 2 and lives for the rest of the episode. It exposes safe, sandboxed file access (no `..`, no absolute paths, no symlinks). """ from __future__ import annotations import fnmatch import json import os from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional @dataclass class CommitRecord: sha: str author: str date: str # ISO-ish string message: str files: List[str] class CodeWorkspaceError(Exception): """Raised on illegal path access or missing files (returned to agent).""" class CodeWorkspace: """ Sandboxed read-only view over a snapshot. All paths are interpreted as relative to `tree/` inside the snapshot. Access to anything outside `tree/` raises CodeWorkspaceError. """ MAX_FILE_BYTES = 64 * 1024 # truncate large files in read_file() MAX_LIST_ENTRIES = 200 MAX_SEARCH_HITS = 50 MAX_SEARCH_BYTES = 5 * 1024 * 1024 # don't grep monsters def __init__(self, snapshot_root: str, bad_commit_sha: str = ""): root = Path(snapshot_root).resolve() if not root.exists(): raise CodeWorkspaceError(f"Snapshot not found: {snapshot_root}") self.root = root self.tree_root = (root / "tree").resolve() if not self.tree_root.exists(): raise CodeWorkspaceError( f"Snapshot {snapshot_root} missing tree/ subdir") self.bad_commit_sha = bad_commit_sha self._git_log: Optional[List[CommitRecord]] = None self._diffs_root = (root / "diffs").resolve() # ------------------------------------------------------------------ # Public file-system API (1 method per agent action_type) # ------------------------------------------------------------------ def list_dir(self, path: str = ".") -> Dict[str, Any]: """List files + subdirs at a relative path under tree/.""" target = self._resolve_tree(path) if not target.is_dir(): raise CodeWorkspaceError(f"Not a directory: {path}") entries = [] for child in sorted(target.iterdir()): if child.name.startswith("."): continue entries.append({ "name": child.name, "type": "dir" if child.is_dir() else "file", "size": child.stat().st_size if child.is_file() else None, }) if len(entries) >= self.MAX_LIST_ENTRIES: break return { "path": self._rel_to_tree(target), "entries": entries, "count": len(entries), } def read_file(self, path: str) -> Dict[str, Any]: """Read a file under tree/. Truncates if larger than MAX_FILE_BYTES.""" target = self._resolve_tree(path) if not target.is_file(): raise CodeWorkspaceError(f"Not a file: {path}") data = target.read_bytes() truncated = False if len(data) > self.MAX_FILE_BYTES: data = data[: self.MAX_FILE_BYTES] truncated = True try: text = data.decode("utf-8") except UnicodeDecodeError: text = data.decode("utf-8", errors="replace") return { "path": self._rel_to_tree(target), "content": text, "size": target.stat().st_size, "truncated": truncated, } def search_code( self, query: str, file_pattern: str = "*.py", max_hits: Optional[int] = None, ) -> Dict[str, Any]: """ Substring search across files matching `file_pattern` under tree/. Returns up to `max_hits` (or MAX_SEARCH_HITS) hits with line context. """ if not query: return {"query": query, "hits": [], "count": 0} cap = min(max_hits or self.MAX_SEARCH_HITS, self.MAX_SEARCH_HITS) hits: List[Dict[str, Any]] = [] bytes_scanned = 0 for fp in self._iter_tree_files(file_pattern): try: text = fp.read_text("utf-8", errors="replace") except OSError: continue bytes_scanned += len(text) if bytes_scanned > self.MAX_SEARCH_BYTES: break for ln, line in enumerate(text.splitlines(), 1): if query in line: hits.append({ "path": self._rel_to_tree(fp), "line": ln, "match": line.strip()[:240], }) if len(hits) >= cap: return {"query": query, "hits": hits, "count": len(hits), "truncated": True} return {"query": query, "hits": hits, "count": len(hits), "truncated": False} def get_git_log( self, path: str = "", n_commits: int = 10, ) -> Dict[str, Any]: """ Return up to `n_commits` commits from the snapshot's pre-baked git_log. If `path` is provided, filters to commits that touched a file matching that exact path (or that path's directory). """ log = self._load_git_log() if path: target = path.strip("/") log = [c for c in log if any( f == target or f.startswith(target.rstrip("/") + "/") for f in c.files )] log = log[: max(1, n_commits)] return { "path": path or ".", "commits": [ {"sha": c.sha, "author": c.author, "date": c.date, "message": c.message, "files": list(c.files)} for c in log ], "count": len(log), } def get_file_diff( self, commit_sha: str, path: str = "", ) -> Dict[str, Any]: """ Return the unified diff for `commit_sha`, optionally filtered to hunks touching files matching `path`. """ diff_path = (self._diffs_root / f"{commit_sha}.patch") try: diff_path = diff_path.resolve() if not str(diff_path).startswith(str(self._diffs_root)): raise CodeWorkspaceError(f"Illegal diff path: {commit_sha}") except OSError: raise CodeWorkspaceError(f"Diff not found for {commit_sha}") if not diff_path.exists(): raise CodeWorkspaceError(f"Diff not found for {commit_sha}") text = diff_path.read_text("utf-8", errors="replace") if path: text = self._filter_diff_by_path(text, path) return { "commit_sha": commit_sha, "path": path or "*", "diff": text, } # ------------------------------------------------------------------ # Lightweight introspection (used to seed the code agent at handoff) # ------------------------------------------------------------------ def file_tree(self, max_depth: int = 3) -> List[str]: """Flat list of files under tree/, capped to a sane depth.""" out: List[str] = [] for fp in self._iter_tree_files("*", max_depth=max_depth): out.append(self._rel_to_tree(fp)) if len(out) >= self.MAX_LIST_ENTRIES: break return sorted(out) def bad_commit_metadata(self) -> Optional[Dict[str, Any]]: """Return commit metadata for `bad_commit_sha` (without the diff).""" if not self.bad_commit_sha: return None for c in self._load_git_log(): if c.sha.startswith(self.bad_commit_sha) or self.bad_commit_sha.startswith(c.sha): return {"sha": c.sha, "author": c.author, "date": c.date, "message": c.message, "files": list(c.files)} return None # ------------------------------------------------------------------ # Internals # ------------------------------------------------------------------ def _load_git_log(self) -> List[CommitRecord]: if self._git_log is not None: return self._git_log path = self.root / "git_log.json" if not path.exists(): self._git_log = [] return self._git_log raw = json.loads(path.read_text("utf-8")) self._git_log = [ CommitRecord( sha = c["sha"], author = c.get("author", "unknown"), date = c.get("date", ""), message = c.get("message", ""), files = list(c.get("files", [])), ) for c in raw ] return self._git_log def _resolve_tree(self, path: str) -> Path: """Resolve a user-supplied relative path under tree/, blocking escapes.""" cleaned = (path or ".").lstrip("/").lstrip(os.sep) if cleaned in ("", "."): return self.tree_root target = (self.tree_root / cleaned).resolve() if not str(target).startswith(str(self.tree_root)): raise CodeWorkspaceError(f"Illegal path (escapes sandbox): {path}") if not target.exists(): raise CodeWorkspaceError(f"Path not found: {path}") return target def _rel_to_tree(self, p: Path) -> str: try: return str(p.relative_to(self.tree_root)) or "." except ValueError: return str(p) def _iter_tree_files(self, pattern: str, max_depth: int = 16): """Yield Paths under tree/ matching pattern (glob-style).""" for dirpath, dirnames, filenames in os.walk(self.tree_root): depth = Path(dirpath).relative_to(self.tree_root).parts if len(depth) > max_depth: dirnames[:] = [] continue dirnames[:] = [d for d in dirnames if not d.startswith(".")] for fname in filenames: if fname.startswith("."): continue if pattern == "*" or fnmatch.fnmatch(fname, pattern): yield Path(dirpath) / fname @staticmethod def _filter_diff_by_path(diff: str, path: str) -> str: """Return only diff hunks where the +++ b/ matches `path`.""" out: List[str] = [] keep = False target = path.strip("/") for line in diff.split("\n"): if line.startswith("diff --git ") or line.startswith("--- a/") or line.startswith("+++ b/"): if line.startswith("+++ b/"): file_in_hunk = line[6:].strip() keep = (file_in_hunk == target or file_in_hunk.startswith(target.rstrip("/") + "/")) if line.startswith("diff --git "): # carry through; we'll re-evaluate at +++ pass if keep or line.startswith("diff --git "): out.append(line) return "\n".join(out)