"""
Coding Agent Branch Manager

Manages alternative trajectory branches for coding agent sessions.
Each branch is backed by a git branch, enabling independent file states
and conversation histories.

Branch model:

    main ────○──○──○──○──○──○  (original trajectory)
                │
                └── branch-1 ──○──○──○  (replayed with new instructions)
                        │
                        └── branch-2 ──○──○  (edited action)
"""

import logging
import os
import re
import subprocess
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# Exceptions that mean "git is unavailable or misbehaved". Git backing is
# best-effort, so these are logged and handled with a fallback instead of
# propagating. SubprocessError covers CalledProcessError (non-zero exit,
# raised by _run_git) and TimeoutExpired (from _run_git's timeout); OSError
# covers a missing git executable or a missing working directory.
_GIT_ERRORS = (subprocess.SubprocessError, OSError)


@dataclass
class TrajectoryBranch:
    """A single branch in the trajectory tree.

    Instances are plain data records; ``to_dict`` produces the serialized
    form returned by ``BranchManager.list_branches`` and ``save_all``.
    """

    branch_id: str
    parent_branch_id: Optional[str]  # None for the root "main" branch
    branch_point_step: Optional[int]  # step where this diverges from parent
    turns: List[Dict[str, Any]]
    git_branch: str
    status: str = "active"  # active, completed, abandoned
    created_at: float = 0.0  # time.time() at creation
    instructions: Optional[str] = None
    edited_actions: Optional[List[Dict]] = None

    def to_dict(self) -> dict:
        """Return a serializable snapshot of this branch.

        ``turns`` and ``edited_actions`` are shallow-copied so callers
        cannot append to or truncate the live branch state through the
        returned dict. ``turn_count`` is derived from ``turns``.
        """
        return {
            "branch_id": self.branch_id,
            "parent_branch_id": self.parent_branch_id,
            "branch_point_step": self.branch_point_step,
            "turns": list(self.turns),
            "git_branch": self.git_branch,
            "status": self.status,
            "created_at": self.created_at,
            "instructions": self.instructions,
            "edited_actions": (
                list(self.edited_actions)
                if self.edited_actions is not None
                else None
            ),
            "turn_count": len(self.turns),
        }


class BranchManager:
    """Manages trajectory branches for a coding agent session.

    A root branch with id ``"main"`` is created on construction and starts
    out active. ``create_branch`` makes the new branch active, and
    ``switch_branch`` changes the active branch. Git commands run in
    ``working_dir`` and are best-effort: failures are logged and never
    abort the in-memory branch bookkeeping.
    """

    def __init__(self, session_id: str, working_dir: str):
        """
        Args:
            session_id: Session identifier; its prefix is used in git
                branch names.
            working_dir: Directory in which git commands are executed
                (stored as an absolute path).
        """
        self._session_id = session_id
        self._working_dir = os.path.abspath(working_dir)
        self._branches: Dict[str, TrajectoryBranch] = {}
        self._active_branch_id: Optional[str] = None

        # Create the main branch. NOTE: its git name uses a 12-char session
        # prefix while child branches use 8 chars; left unchanged because
        # existing git branch names depend on it.
        # NOTE(review): this git branch is not created here; presumably
        # whatever records the "step=N" commits creates it - confirm.
        main = TrajectoryBranch(
            branch_id="main",
            parent_branch_id=None,
            branch_point_step=None,
            turns=[],
            git_branch=f"potato-agent-{session_id[:12]}",
            created_at=time.time(),
        )
        self._branches["main"] = main
        self._active_branch_id = "main"

    @property
    def active_branch(self) -> TrajectoryBranch:
        """The currently active branch object."""
        return self._branches[self._active_branch_id]

    @property
    def active_branch_id(self) -> str:
        """The id of the currently active branch."""
        return self._active_branch_id

    def create_branch(self, parent_branch_id: str, branch_point_step: int,
                      instructions: Optional[str] = None,
                      edited_actions: Optional[List[Dict]] = None) -> TrajectoryBranch:
        """Create a new branch from a parent at a given step and activate it.

        The git side checks out the parent's git branch and then creates a
        new git branch at the parent's commit tagged ``step=<branch_point_step>``.
        If no such commit exists, the new branch starts at the parent's HEAD
        instead. If git fails at any point, the new branch reuses the
        parent's git branch name, i.e. it has no independent git backing.

        Args:
            parent_branch_id: ID of the parent branch
            branch_point_step: Step index where the branch diverges
            instructions: Optional user instructions for the new branch
            edited_actions: Optional modified tool calls to execute

        Returns:
            The new TrajectoryBranch

        Raises:
            ValueError: If ``parent_branch_id`` is unknown.
        """
        parent = self._branches.get(parent_branch_id)
        if parent is None:
            raise ValueError(f"Parent branch '{parent_branch_id}' not found")

        branch_id = self._next_branch_id()
        git_branch = f"potato-agent-{self._session_id[:8]}-{branch_id}"

        # Create git branch from parent's state at branch_point_step
        try:
            # First, ensure we're on the parent branch
            self._run_git("checkout", parent.git_branch)

            target_commit = self._find_step_commit(parent.git_branch, branch_point_step)
            if target_commit:
                self._run_git("checkout", "-b", git_branch, target_commit)
            else:
                # Fallback: branch from current HEAD (the parent's tip)
                self._run_git("checkout", "-b", git_branch)
                logger.warning(
                    "Could not find commit for step %s, branching from HEAD",
                    branch_point_step,
                )
        except _GIT_ERRORS as e:
            logger.error("Failed to create git branch: %s", e)
            # Create branch without git backing
            git_branch = parent.git_branch

        # Conversation history up to and including the branch point. The
        # slice is a new list; the individual turn dicts are shared.
        branch_turns = parent.turns[:branch_point_step + 1]

        branch = TrajectoryBranch(
            branch_id=branch_id,
            parent_branch_id=parent_branch_id,
            branch_point_step=branch_point_step,
            turns=branch_turns,
            git_branch=git_branch,
            created_at=time.time(),
            instructions=instructions,
            edited_actions=edited_actions,
        )
        self._branches[branch_id] = branch
        self._active_branch_id = branch_id

        logger.info(
            "Created branch %s from %s at step %s",
            branch_id, parent_branch_id, branch_point_step,
        )
        return branch

    def _next_branch_id(self) -> str:
        """Return an unused ``branch-N`` id, starting at N = branch count."""
        n = len(self._branches)
        while f"branch-{n}" in self._branches:
            n += 1
        return f"branch-{n}"

    def _find_step_commit(self, git_ref: str, step: int) -> Optional[str]:
        """Return the newest commit reachable from ``git_ref`` tagged ``step=<step>``.

        Only the history of ``git_ref`` is searched, not ``--all``, so a
        sibling branch's commit with the same step number is never picked.
        The tag must match exactly: ``step=1`` does not match ``step=10``.

        Returns:
            The full commit hash, or None if no commit subject matches.

        Raises:
            Whatever ``_run_git`` raises if ``git log`` fails.
        """
        pattern = re.compile(rf"\bstep={re.escape(str(step))}(?!\d)")
        # "--" stops git from interpreting the ref as a path.
        log = self._run_git("log", "--format=%H %s", git_ref, "--")
        for line in log.splitlines():
            commit, _, subject = line.partition(" ")
            if commit and pattern.search(subject):
                return commit
        return None

    def switch_branch(self, branch_id: str) -> bool:
        """Make ``branch_id`` the active branch and check out its git branch.

        A git checkout failure is logged as a warning and does not prevent
        the switch.

        Returns:
            False if the branch id is unknown, otherwise True.
        """
        branch = self._branches.get(branch_id)
        if branch is None:
            return False

        try:
            self._run_git("checkout", branch.git_branch)
        except _GIT_ERRORS as e:
            logger.warning("Failed to switch git branch: %s", e)

        self._active_branch_id = branch_id
        logger.info("Switched to branch %s", branch_id)
        return True

    def add_turn_to_active(self, turn: Dict[str, Any]) -> None:
        """Add a turn to the active branch."""
        self.active_branch.turns.append(turn)

    def get_branch(self, branch_id: str) -> Optional[TrajectoryBranch]:
        """Return the branch with ``branch_id``, or None if it is unknown."""
        return self._branches.get(branch_id)

    def list_branches(self) -> List[dict]:
        """Return ``to_dict()`` snapshots of all branches in creation order."""
        return [b.to_dict() for b in self._branches.values()]

    def get_branch_tree(self) -> dict:
        """Return tree structure for UI rendering.

        Maps each branch id to a summary dict holding its parent id, branch
        point, turn count, status, instructions and whether it is active.
        """
        return {
            bid: {
                "branch_id": bid,
                "parent": branch.parent_branch_id,
                "branch_point": branch.branch_point_step,
                "turns": len(branch.turns),
                "status": branch.status,
                "instructions": branch.instructions,
                "is_active": bid == self._active_branch_id,
            }
            for bid, branch in self._branches.items()
        }

    def save_all(self) -> dict:
        """Serialize all branches for trace export, keyed by branch id."""
        return {
            bid: branch.to_dict()
            for bid, branch in self._branches.items()
        }

    def _run_git(self, *args) -> str:
        """Run ``git <args>`` in the working directory and return its stdout.

        Raises:
            subprocess.CalledProcessError: If git exits non-zero (stdout and
                stderr are attached to the exception).
            subprocess.TimeoutExpired: If git runs longer than 30 seconds.
            OSError: If git cannot be started (e.g. not installed).
        """
        result = subprocess.run(
            ["git"] + list(args),
            cwd=self._working_dir,
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode != 0:
            raise subprocess.CalledProcessError(
                result.returncode, ["git"] + list(args),
                output=result.stdout, stderr=result.stderr,
            )
        return result.stdout