codebook / potato /coding_agent_checkpoint.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
9.85 kB
"""
Coding Agent Checkpoint Manager
Git-based checkpointing for coding agent sessions. Creates lightweight
commits after each file-modifying tool call, enabling rollback to any
previous step.
Uses a dedicated git branch (potato-agent-<session_id>) to avoid
interfering with the user's branches.
"""
import logging
import os
import subprocess
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class Checkpoint:
    """Metadata for one recorded state of the working directory."""

    checkpoint_id: str  # git commit hash
    step_index: int
    tool_name: str
    description: str
    timestamp: float
    files_changed: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the checkpoint's fields as a plain ``dict``.

        Keys appear in field-declaration order; values are the attribute
        objects themselves (``files_changed`` is not copied).
        """
        exported = (
            "checkpoint_id",
            "step_index",
            "tool_name",
            "description",
            "timestamp",
            "files_changed",
        )
        return {name: getattr(self, name) for name in exported}
class CheckpointManager:
    """Manage git-based checkpoints for a coding agent session.

    Checkpoints are commits on a dedicated branch named
    ``potato-agent-<first 12 chars of session_id>``. Rolling back is a
    ``git reset --hard`` to the checkpoint's commit.

    Git failures are logged and signalled through return values
    (``False`` / ``None`` / ``""``) instead of being raised to the caller.
    """

    # Everything a git subprocess call can fail with: non-zero exit status,
    # the timeout enforced in ``_run_git``, or an OS-level failure such as a
    # missing ``git`` executable or working directory.
    _GIT_ERRORS = (
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
        OSError,
    )

    def __init__(self, working_dir: str, session_id: str):
        """Store session settings; no git command is run until ``init()``.

        Args:
            working_dir: Directory git commands run in (made absolute).
            session_id: Session identifier; its first 12 characters name
                the session branch.
        """
        self._working_dir = os.path.abspath(working_dir)
        self._session_id = session_id
        self._branch_name = f"potato-agent-{session_id[:12]}"
        self._checkpoints: List[Checkpoint] = []
        self._initialized = False
        # Branch (or commit hash when HEAD was detached) that was checked
        # out before the session branch; cleanup() switches back to it.
        self._original_branch: Optional[str] = None

    @property
    def checkpoints(self) -> "List[Checkpoint]":
        """A shallow copy of the recorded checkpoints, oldest first."""
        return list(self._checkpoints)

    def init(self) -> bool:
        """Initialize the git repo if needed and switch to the session branch.

        Returns:
            True if the session branch is checked out (or the manager was
            already initialized), False if git setup failed.
        """
        if self._initialized:
            return True

        # Ensure a git repo exists.
        if not self._is_git_repo():
            try:
                self._run_git("init")
                self._run_git("add", "-A")
                self._run_git("commit", "--allow-empty", "-m", "[potato] init")
            except Exception as e:
                logger.warning("Failed to init git repo: %s", e)
                return False

        # Create the session branch from the current HEAD, remembering where
        # we started so cleanup() can return there.
        try:
            start_ref = self._current_ref()
            if self._original_branch is None and start_ref != self._branch_name:
                self._original_branch = start_ref
            self._run_git("checkout", "-b", self._branch_name)
        except self._GIT_ERRORS:
            # Branch might already exist (session restart).
            try:
                self._run_git("checkout", self._branch_name)
            except self._GIT_ERRORS as e:
                logger.warning("Failed to create/checkout session branch: %s", e)
                return False

        # Record the starting state. Best effort: a failure here is logged
        # but does not prevent the session from being usable.
        try:
            self._run_git("add", "-A")
            self._run_git("commit", "--allow-empty", "-m",
                          f"[potato] session start {self._session_id[:8]}")
            commit_hash = self._get_head_hash()
            self._checkpoints.append(Checkpoint(
                checkpoint_id=commit_hash,
                step_index=-1,
                tool_name="init",
                description="Session start",
                timestamp=time.time(),
            ))
        except Exception as e:
            logger.warning("Failed to create initial checkpoint: %s", e)

        self._initialized = True
        logger.info("CheckpointManager initialized on branch %s", self._branch_name)
        return True

    def create_checkpoint(self, step_index: int, tool_name: str,
                          description: str = "") -> Optional[str]:
        """Create a checkpoint after a tool execution.

        Stages everything and commits if there are changes. When nothing
        changed, no commit is made but a checkpoint pointing at the current
        HEAD is still recorded so the step can be a rollback target.

        Args:
            step_index: Index of the agent step this checkpoint belongs to.
            tool_name: Name of the tool that was just executed.
            description: Optional text appended to the commit message and
                stored on the checkpoint.

        Returns:
            The commit hash of the checkpoint, or None if initialization or
            any git command failed.
        """
        if not self._initialized and not self.init():
            return None

        label = description or f"Step {step_index}: {tool_name}"
        try:
            self._run_git("add", "-A")

            # -z gives NUL-separated, unquoted paths that are safe to parse
            # (renames and paths with spaces/special characters included).
            status = self._run_git("status", "--porcelain", "-z")
            changed: List[str] = []
            if status.strip():
                changed = self._parse_status_z(status)
                msg = f"[potato] step={step_index} tool={tool_name}"
                if description:
                    msg += f" {description}"
                self._run_git("commit", "-m", msg)

            commit_hash = self._get_head_hash()
            self._checkpoints.append(Checkpoint(
                checkpoint_id=commit_hash,
                step_index=step_index,
                tool_name=tool_name,
                description=label,
                timestamp=time.time(),
                files_changed=changed,
            ))
            if changed:
                logger.debug("Created checkpoint %s at step %s",
                             commit_hash[:8], step_index)
            return commit_hash
        except Exception as e:
            logger.warning("Failed to create checkpoint: %s", e)
            return None

    def rollback_to(self, step_index: int) -> bool:
        """Roll the working directory back to a checkpoint.

        The target is the first checkpoint whose step index equals
        ``step_index``; if there is none, the last checkpoint seen with a
        smaller step index is used. Every checkpoint recorded after the
        target is dropped, because its commit is no longer on the branch
        once the reset has run.

        Returns:
            True if the reset succeeded, False if no suitable checkpoint
            exists or git failed.
        """
        target_pos = None
        for pos, cp in enumerate(self._checkpoints):
            if cp.step_index == step_index:
                target_pos = pos
                break
            if cp.step_index <= step_index:
                target_pos = pos  # latest checkpoint at or before step_index

        if target_pos is None:
            logger.warning("No checkpoint found at or before step %s", step_index)
            return False

        target = self._checkpoints[target_pos]
        try:
            self._run_git("reset", "--hard", target.checkpoint_id)
        except Exception as e:
            logger.error("Rollback failed: %s", e)
            return False

        # Truncate by position, not by step index, so the list mirrors the
        # branch history even when several checkpoints share a step index.
        self._checkpoints = self._checkpoints[:target_pos + 1]
        logger.info("Rolled back to step %s (commit %s)",
                    step_index, target.checkpoint_id[:8])
        return True

    def get_diff_between(self, from_step: int, to_step: int) -> str:
        """Return the git diff between the checkpoints of two steps.

        Returns an empty string if either step has no checkpoint or the
        diff command fails.
        """
        from_cp = self._find_checkpoint(from_step)
        to_cp = self._find_checkpoint(to_step)
        if not from_cp or not to_cp:
            return ""
        try:
            return self._run_git("diff", from_cp.checkpoint_id, to_cp.checkpoint_id)
        except Exception as e:
            logger.warning("Failed to diff steps %s..%s: %s", from_step, to_step, e)
            return ""

    def get_diff_since(self, step_index: int) -> str:
        """Return the git diff from a step's checkpoint to the current HEAD.

        Returns an empty string if the step has no checkpoint or the diff
        command fails.
        """
        cp = self._find_checkpoint(step_index)
        if not cp:
            return ""
        try:
            return self._run_git("diff", cp.checkpoint_id, "HEAD")
        except Exception as e:
            logger.warning("Failed to diff since step %s: %s", step_index, e)
            return ""

    def get_file_at(self, step_index: int, file_path: str) -> Optional[str]:
        """Return a file's contents as of a step's checkpoint.

        ``file_path`` is passed to ``git show <commit>:<file_path>``.

        Returns None if the step has no checkpoint or git cannot produce the
        file (for example because it did not exist in that commit).
        """
        cp = self._find_checkpoint(step_index)
        if not cp:
            return None
        try:
            return self._run_git("show", f"{cp.checkpoint_id}:{file_path}")
        except Exception as e:
            # A missing file is an expected outcome, so keep this quiet.
            logger.debug("Could not read %s at step %s: %s", file_path, step_index, e)
            return None

    def list_checkpoints(self) -> List[dict]:
        """Return checkpoint metadata as a list of dicts."""
        return [cp.to_dict() for cp in self._checkpoints]

    def cleanup(self) -> None:
        """Leave and delete the session branch.

        Switches back to the branch recorded by ``init()``. If none was
        recorded (e.g. the session started while already on the session
        branch), falls back to the first other local branch. Failures are
        logged, not raised.
        """
        if not self._initialized:
            return
        try:
            restore_to = self._original_branch or self._first_other_branch()
            if restore_to:
                self._run_git("checkout", restore_to)
            self._run_git("branch", "-D", self._branch_name)
            logger.info("Cleaned up session branch %s", self._branch_name)
        except Exception as e:
            logger.warning("Failed to clean up session branch: %s", e)

    @staticmethod
    def _parse_status_z(raw: str) -> List[str]:
        """Extract changed paths from ``git status --porcelain -z`` output.

        Each entry is ``XY <path>`` terminated by NUL. Rename/copy entries
        are followed by one extra NUL-terminated field holding the source
        path, which is skipped so only the destination is reported.
        """
        paths: List[str] = []
        entries = iter(raw.split("\x00"))
        for entry in entries:
            if len(entry) < 4:  # trailing empty field or malformed entry
                continue
            state, path = entry[:2], entry[3:]
            paths.append(path)
            if "R" in state or "C" in state:
                next(entries, None)  # consume the source path
        return paths

    def _current_ref(self) -> str:
        """Return the checked-out branch name, or the commit hash if detached."""
        name = self._run_git("rev-parse", "--abbrev-ref", "HEAD").strip()
        if name == "HEAD":
            # Detached HEAD: the literal name "HEAD" is useless for switching
            # back later, so remember the commit instead.
            return self._get_head_hash()
        return name

    def _first_other_branch(self) -> Optional[str]:
        """Return the first local branch that is not the session branch."""
        listing = self._run_git(
            "for-each-ref", "--format=%(refname:short)", "refs/heads"
        )
        for line in listing.splitlines():
            name = line.strip()
            if name and name != self._branch_name:
                return name
        return None

    def _find_checkpoint(self, step_index: int) -> "Optional[Checkpoint]":
        """Return the first checkpoint with exactly this step index, if any."""
        for cp in self._checkpoints:
            if cp.step_index == step_index:
                return cp
        return None

    def _is_git_repo(self) -> bool:
        """Return True if the working directory is inside a git repository."""
        try:
            self._run_git("rev-parse", "--git-dir")
            return True
        except self._GIT_ERRORS:
            return False

    def _get_head_hash(self) -> str:
        """Return the full commit hash of HEAD."""
        return self._run_git("rev-parse", "HEAD").strip()

    def _run_git(self, *args: str) -> str:
        """Run ``git <args>`` in the working directory and return its stdout.

        Raises:
            subprocess.CalledProcessError: git exited with a non-zero status
                (stdout/stderr are attached to the exception).
            subprocess.TimeoutExpired: the command ran longer than 30 seconds.
            OSError: the process could not be started.
        """
        cmd = ["git", *args]
        result = subprocess.run(
            cmd,
            cwd=self._working_dir,
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode != 0:
            raise subprocess.CalledProcessError(
                result.returncode, cmd,
                output=result.stdout, stderr=result.stderr,
            )
        return result.stdout