Spaces:
Paused
Paused
| """ | |
| Coding Agent Checkpoint Manager | |
| Git-based checkpointing for coding agent sessions. Creates lightweight | |
| commits after each file-modifying tool call, enabling rollback to any | |
| previous step. | |
| Uses a dedicated git branch (potato-agent-<session_id>) to avoid | |
| interfering with the user's branches. | |
| """ | |
| import logging | |
| import os | |
| import subprocess | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class Checkpoint:
    """A snapshot of the working directory state.

    Each checkpoint points at one git commit and records which agent step
    and tool produced it.
    """

    checkpoint_id: str  # git commit hash
    step_index: int  # agent step that produced this checkpoint
    tool_name: str  # tool whose execution was checkpointed
    description: str  # human-readable label
    timestamp: float  # time.time() at creation
    # Paths reported as changed by git for this checkpoint (may be empty).
    files_changed: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the checkpoint as a plain dict.

        ``files_changed`` is copied so that mutating the returned dict does
        not alter this checkpoint.
        """
        return {
            "checkpoint_id": self.checkpoint_id,
            "step_index": self.step_index,
            "tool_name": self.tool_name,
            "description": self.description,
            "timestamp": self.timestamp,
            "files_changed": list(self.files_changed),
        }
class CheckpointManager:
    """Manages git-based checkpoints for a coding agent session.

    All checkpoints are commits on a dedicated session branch
    (``potato-agent-<first 12 chars of session_id>``). Every git command
    runs with ``working_dir`` as its current directory.
    """

    # Failures a git subprocess call can raise: non-zero exit, timeout, or
    # an OS-level error such as a missing ``git`` binary or working directory.
    _GIT_ERRORS = (
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
        OSError,
    )

    def __init__(self, working_dir: str, session_id: str):
        """Store session settings; no git command runs until ``init()``."""
        self._working_dir = os.path.abspath(working_dir)
        self._session_id = session_id
        self._branch_name = f"potato-agent-{session_id[:12]}"
        self._checkpoints: List["Checkpoint"] = []
        self._initialized = False
        # Branch that was checked out before the session branch was created,
        # so cleanup() can return to it. None when unknown (detached HEAD,
        # or the session was restarted while already on the session branch).
        self._original_branch: Optional[str] = None

    @property
    def checkpoints(self) -> List["Checkpoint"]:
        """Return a shallow copy of the recorded checkpoints, oldest first."""
        return list(self._checkpoints)

    def init(self) -> bool:
        """Initialize git repo and create session branch.

        Creates a repository if ``working_dir`` is not inside one, switches
        to the session branch (creating it if needed) and records an initial
        "Session start" checkpoint with ``step_index=-1``.

        Returns True if initialization succeeded. Failing to record the
        initial checkpoint is logged but does not fail initialization.
        """
        if self._initialized:
            return True

        # Ensure git repo exists
        if not self._is_git_repo():
            try:
                self._run_git("init")
                self._run_git("add", "-A")
                self._run_git("commit", "--allow-empty", "-m", "[potato] init")
            except Exception as e:
                logger.warning("Failed to init git repo: %s", e)
                return False

        # Remember which branch the user was on. "HEAD" means detached, and
        # the session branch itself means we are resuming a session; neither
        # is a branch to return to.
        try:
            current_branch = self._run_git(
                "rev-parse", "--abbrev-ref", "HEAD"
            ).strip()
        except self._GIT_ERRORS:
            current_branch = ""
        if current_branch and current_branch not in ("HEAD", self._branch_name):
            self._original_branch = current_branch

        # Create session branch from current HEAD
        try:
            self._run_git("checkout", "-b", self._branch_name)
        except self._GIT_ERRORS:
            # Branch might already exist (session restart)
            try:
                self._run_git("checkout", self._branch_name)
            except self._GIT_ERRORS as e:
                logger.warning(
                    "Failed to create/checkout session branch: %s", e
                )
                return False

        # Create initial checkpoint
        try:
            self._run_git("add", "-A")
            self._run_git(
                "commit", "--allow-empty", "-m",
                f"[potato] session start {self._session_id[:8]}",
            )
            self._checkpoints.append(Checkpoint(
                checkpoint_id=self._get_head_hash(),
                step_index=-1,
                tool_name="init",
                description="Session start",
                timestamp=time.time(),
            ))
        except Exception as e:
            logger.warning("Failed to create initial checkpoint: %s", e)

        self._initialized = True
        logger.info(
            "CheckpointManager initialized on branch %s", self._branch_name
        )
        return True

    def create_checkpoint(self, step_index: int, tool_name: str,
                          description: str = "") -> Optional[str]:
        """Create a checkpoint after a tool execution.

        Stages everything in the working directory and commits it. When
        nothing changed, no commit is made but a checkpoint pointing at the
        current HEAD is still recorded so the step remains a rollback target.

        Returns the commit hash the checkpoint points at, or None if
        initialization or any git command failed.
        """
        if not self._initialized and not self.init():
            return None

        label = description or f"Step {step_index}: {tool_name}"
        try:
            # Stage all changes
            self._run_git("add", "-A")

            # Check if there are changes to commit
            status = self._run_git("status", "--porcelain")
            has_changes = bool(status.strip())
            changed = self._parse_changed_paths(status) if has_changes else []

            if has_changes:
                msg = f"[potato] step={step_index} tool={tool_name}"
                if description:
                    msg += f" {description}"
                self._run_git("commit", "-m", msg)

            commit_hash = self._get_head_hash()
            self._checkpoints.append(Checkpoint(
                checkpoint_id=commit_hash,
                step_index=step_index,
                tool_name=tool_name,
                description=label,
                timestamp=time.time(),
                files_changed=changed,
            ))
            if has_changes:
                logger.debug(
                    "Created checkpoint %s at step %s",
                    commit_hash[:8], step_index,
                )
            return commit_hash
        except Exception as e:
            logger.warning("Failed to create checkpoint: %s", e)
            return None

    def rollback_to(self, step_index: int) -> bool:
        """Rollback to the checkpoint at the given step index.

        Uses the first checkpoint whose step index equals ``step_index``;
        if there is none, the latest checkpoint recorded before it with a
        smaller step index. The working tree is hard-reset to that commit
        and every checkpoint recorded after it is dropped.

        Returns True if rollback succeeded.
        """
        # Find the checkpoint
        target_pos = None
        for pos, cp in enumerate(self._checkpoints):
            if cp.step_index == step_index:
                target_pos = pos
                break
            if cp.step_index <= step_index:
                # Latest checkpoint at or before step_index seen so far.
                target_pos = pos

        if target_pos is None:
            logger.warning(
                "No checkpoint found at or before step %s", step_index
            )
            return False

        target = self._checkpoints[target_pos]
        try:
            self._run_git("reset", "--hard", target.checkpoint_id)
            # Truncate by position: everything recorded after the target
            # refers to commits that are no longer on the branch.
            self._checkpoints = self._checkpoints[:target_pos + 1]
            logger.info(
                "Rolled back to step %s (commit %s)",
                step_index, target.checkpoint_id[:8],
            )
            return True
        except Exception as e:
            logger.error("Rollback failed: %s", e)
            return False

    def get_diff_between(self, from_step: int, to_step: int) -> str:
        """Get the git diff between two checkpoints.

        Returns an empty string if either step has no checkpoint or the
        diff command fails.
        """
        from_cp = self._find_checkpoint(from_step)
        to_cp = self._find_checkpoint(to_step)
        if from_cp is None or to_cp is None:
            return ""
        try:
            return self._run_git(
                "diff", from_cp.checkpoint_id, to_cp.checkpoint_id
            )
        except Exception:
            return ""

    def get_diff_since(self, step_index: int) -> str:
        """Get the diff from a checkpoint to current HEAD.

        Returns an empty string if the step has no checkpoint or the diff
        command fails.
        """
        cp = self._find_checkpoint(step_index)
        if cp is None:
            return ""
        try:
            return self._run_git("diff", cp.checkpoint_id, "HEAD")
        except Exception:
            return ""

    def get_file_at(self, step_index: int, file_path: str) -> Optional[str]:
        """Get file contents at a specific checkpoint.

        A relative ``file_path`` is passed to git unchanged. An absolute
        path is converted to a ``./``-prefixed path relative to the working
        directory, because ``git show <rev>:<path>`` cannot resolve
        absolute filesystem paths.

        Returns None if the step has no checkpoint or git cannot show the
        file.
        """
        cp = self._find_checkpoint(step_index)
        if cp is None:
            return None

        if os.path.isabs(file_path):
            rel = os.path.relpath(file_path, self._working_dir)
            # git expects forward slashes; "./" anchors the path at the
            # directory git runs in (the working directory).
            file_path = "./" + rel.replace(os.sep, "/")

        try:
            return self._run_git("show", f"{cp.checkpoint_id}:{file_path}")
        except Exception:
            return None

    def list_checkpoints(self) -> List[dict]:
        """Return checkpoint metadata as list of dicts."""
        return [cp.to_dict() for cp in self._checkpoints]

    def cleanup(self) -> None:
        """Clean up the session branch.

        Switches back to the branch that was checked out when the session
        started (or, if that is unknown or gone, to any other local branch)
        and deletes the session branch. Recorded checkpoint metadata is
        kept. Failures are logged, never raised.
        """
        if not self._initialized:
            return
        try:
            others = self._other_branches()
            if self._original_branch in others:
                target = self._original_branch
            else:
                target = others[0] if others else None

            if target is None:
                # The checked-out branch cannot be deleted, so without
                # somewhere to switch to there is nothing we can do.
                logger.warning(
                    "Failed to clean up session branch: %s",
                    "no other branch to switch to",
                )
                return

            self._run_git("checkout", target)
            self._run_git("branch", "-D", self._branch_name)
            # Force re-initialization before any further checkpoint so later
            # commits cannot land on the user's branch.
            self._initialized = False
            logger.info("Cleaned up session branch %s", self._branch_name)
        except Exception as e:
            logger.warning("Failed to clean up session branch: %s", e)

    def _other_branches(self) -> List[str]:
        """Return local branch names other than the session branch."""
        output = self._run_git(
            "for-each-ref", "--format=%(refname:short)", "refs/heads"
        )
        return [
            name
            for name in (line.strip() for line in output.splitlines())
            if name and name != self._branch_name
        ]

    @staticmethod
    def _parse_changed_paths(status: str) -> List[str]:
        """Extract paths from ``git status --porcelain`` (v1) output.

        Each line is two status characters, a space, then the path. Rename
        and copy entries (``old -> new``) yield the new path. Paths that git
        quotes are returned as git printed them.
        """
        paths = []
        for line in status.splitlines():
            if len(line) < 4:
                continue
            code, path = line[:2], line[3:]
            if ("R" in code or "C" in code) and " -> " in path:
                path = path.split(" -> ", 1)[1]
            paths.append(path)
        return paths

    def _find_checkpoint(self, step_index: int) -> Optional["Checkpoint"]:
        """Return the first checkpoint with exactly this step index, or None."""
        return next(
            (cp for cp in self._checkpoints if cp.step_index == step_index),
            None,
        )

    def _is_git_repo(self) -> bool:
        """Return True if the working directory is inside a git repository.

        Also returns False when git cannot be run at all or times out.
        """
        try:
            self._run_git("rev-parse", "--git-dir")
        except self._GIT_ERRORS:
            return False
        return True

    def _get_head_hash(self) -> str:
        """Return the commit hash of the current HEAD."""
        return self._run_git("rev-parse", "HEAD").strip()

    def _run_git(self, *args) -> str:
        """Run ``git <args>`` in the working directory and return its stdout.

        Raises:
            subprocess.CalledProcessError: git exited with a non-zero status
                (stdout and stderr are attached to the exception).
            subprocess.TimeoutExpired: the command ran longer than 30 seconds.
            OSError: the process could not be started.
        """
        cmd = ["git", *args]
        result = subprocess.run(
            cmd,
            cwd=self._working_dir,
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode != 0:
            raise subprocess.CalledProcessError(
                result.returncode, cmd,
                output=result.stdout, stderr=result.stderr,
            )
        return result.stdout