""" TeamForge Environment Full OpenEnv-compliant environment simulating an autonomous software team. Interface: env = TeamForgeEnv() obs = env.reset(task_id) obs = env.step(action) state = env.state() """ from __future__ import annotations import re import subprocess import sys import time from pathlib import Path from typing import Any, Dict, List, Optional from models import ( Action, ActionStatus, Commit, EditFile, EpisodeResult, FileSnapshot, GenerateReview, LintResult, Observation, PhaseState, PlanStep, ReflectionArtifact, RequestIteration, ReviewArtifact, RunLint, RunTests, SelfReflect, TaskDifficulty, TestResult, ) from sandbox.git_sandbox import GitSandbox from tasks.task_registry import get_task from reward import RewardCalculator from grader import grade_episode class TeamForgeEnv: """ OpenEnv-compliant environment for autonomous software team simulation. An episode represents one attempt to complete a software engineering task. The agent issues structured actions; the environment executes them against a real Git repository and returns observations with dense rewards. """ def __init__(self, log_dir: Optional[str] = None): self._sandbox = GitSandbox() self._reward_calc = RewardCalculator() self._obs: Optional[Observation] = None self._task_module: Any = None self._log_dir = log_dir self._logs: List[str] = [] # Episode state self._step_number = 0 self._cumulative_reward = 0.001 self._plan: List[PlanStep] = [] self._reviews: List[ReviewArtifact] = [] self._reflections: List[ReflectionArtifact] = [] self._last_test_result: Optional[TestResult] = None self._last_lint_result: Optional[LintResult] = None # ───────────────────────────────────────────── # OpenEnv INTERFACE # ───────────────────────────────────────────── def reset(self, task_id: str) -> Observation: """ Start a new episode for the given task. Tears down any previous sandbox and initialises a fresh git repo. Args: task_id: One of easy_bugfix_chunk_list | medium_refactor_stats | hard_lru_cache_performance Returns: Initial observation with full repo snapshot. """ self._log(f"[START] task={task_id}") # Clean up previous episode self._sandbox.teardown() self._sandbox = GitSandbox() # Load task self._task_module = get_task(task_id) self._reward_calc = RewardCalculator() # Detect test files and register with reward calculator test_files = [ p for p in self._task_module.INITIAL_FILES if "test" in p.lower() ] self._reward_calc.set_test_files(test_files) # Reset episode state self._step_number = 0 self._cumulative_reward = 0.1 self._plan = [] self._reviews = [] self._reflections = [] self._last_test_result = None self._last_lint_result = None self._logs = [f"[START] task={task_id}"] # Initialise git sandbox with task files self._sandbox.init(self._task_module.INITIAL_FILES) # Build initial observation self._obs = self._build_observation( action_type=None, status=ActionStatus.SUCCESS, output="Environment initialized.", reward=0.1, done=False, ) return self._obs def step(self, action: Action) -> Observation: """ Execute one action and return the resulting observation. Args: action: A typed Action model (PlanStep, EditFile, RunTests, …) Returns: Updated Observation with reward, done flag, and all state. """ if self._obs is None: raise RuntimeError("Call reset() before step()") self._step_number += 1 action_type = action.type self._log(f"[STEP {self._step_number}] action={action_type}") # ── Max steps guard ── max_steps = self._task_module.MAX_STEPS if self._step_number > max_steps: return self._finalize(reason="Max steps exceeded") # ── Dispatch action ── status = ActionStatus.SUCCESS output = "" edited_file: Optional[str] = None tests_passed: Optional[int] = None lint_violations: Optional[int] = None try: if isinstance(action, PlanStep): output = self._handle_plan_step(action) elif isinstance(action, EditFile): output, edited_file = self._handle_edit_file(action) elif isinstance(action, RunTests): output = self._handle_run_tests(action) tests_passed = (self._last_test_result.passed if self._last_test_result else 0) elif isinstance(action, RunLint): output = self._handle_run_lint(action) lint_violations = (self._last_lint_result.violations if self._last_lint_result else 0) elif isinstance(action, GenerateReview): output = self._handle_generate_review(action) elif isinstance(action, Commit): output = self._handle_commit(action) elif isinstance(action, SelfReflect): output = self._handle_self_reflect(action) elif isinstance(action, RequestIteration): output = self._handle_request_iteration(action) else: status = ActionStatus.FAILURE output = f"Unknown action type: {action_type}" except Exception as exc: status = ActionStatus.FAILURE output = f"Action failed with exception: {exc}" self._log(f"[ERROR] {exc}") # ── Compute reward ── reward = self._reward_calc.compute( action_type=action_type, action_success=(status == ActionStatus.SUCCESS), action_output=output, tests_passed=tests_passed, lint_violations=lint_violations, edited_file=edited_file, ) self._cumulative_reward += reward # ── Check done conditions ── done = self._check_done() self._log(f"[STEP {self._step_number}] reward={reward:.4f} done={done}") self._obs = self._build_observation( action_type=action_type, status=status, output=output, reward=reward, done=done, ) return self._obs def state(self) -> Dict[str, Any]: """ Return current environment state as a plain dict. Useful for serialisation and logging. """ if self._obs is None: return {"status": "not_started"} return { "task_id": self._obs.task_id, "step": self._step_number, "phase": self._obs.phase.value, "cumulative_reward": self._cumulative_reward, "tests_passed": (self._last_test_result.passed if self._last_test_result else 0), "tests_failed": (self._last_test_result.failed if self._last_test_result else 0), "lint_violations": (self._last_lint_result.violations if self._last_lint_result else 0), "commits": len(self._sandbox.get_log()), "plan_steps": len(self._plan), "reviews": len(self._reviews), "reflections": len(self._reflections), "done": self._obs.done, } def grade(self) -> EpisodeResult: """Run the deterministic grader and return an EpisodeResult.""" required_kw = getattr( self._task_module, "REQUIRED_KEYWORDS_IN_REVIEW", [] ) return grade_episode( repo_path=str(self._sandbox.repo_path), task_id=self._task_module.TASK_ID, total_steps=self._step_number, max_steps=self._task_module.MAX_STEPS, reviews=self._reviews, reflections=self._reflections, required_keywords=required_kw, ) # ───────────────────────────────────────────── # ACTION HANDLERS # ───────────────────────────────────────────── def _handle_plan_step(self, action: PlanStep) -> str: self._plan.append(action) return ( f"Plan step {action.step_number} recorded: {action.description} " f"[effort={action.estimated_effort}]" ) def _handle_edit_file(self, action: EditFile) -> tuple[str, str]: self._sandbox.write_file(action.file_path, action.content) size = len(action.content.encode()) return ( f"Wrote {size} bytes to {action.file_path}. Reason: {action.reason}", action.file_path, ) def _handle_run_tests(self, action: RunTests) -> str: cmd = [ sys.executable, "-m", "pytest", "--tb=short", "-q", "--no-header", f"--timeout={action.timeout_seconds}", ] if action.test_path: cmd.append(action.test_path) start = time.perf_counter() result = subprocess.run( cmd, cwd=str(self._sandbox.repo_path), capture_output=True, text=True, timeout=action.timeout_seconds + 5, ) elapsed = time.perf_counter() - start output = result.stdout + result.stderr passed = failed = errors = 0 m_p = re.search(r"(\d+) passed", output) m_f = re.search(r"(\d+) failed", output) m_e = re.search(r"(\d+) error", output) if m_p: passed = int(m_p.group(1)) if m_f: failed = int(m_f.group(1)) if m_e: errors = int(m_e.group(1)) self._last_test_result = TestResult( passed=passed, failed=failed, errors=errors, output=output[:2000], duration_seconds=elapsed, ) return output[:2000] def _handle_run_lint(self, action: RunLint) -> str: cmd = [sys.executable, "-m", "ruff", "check"] if action.fix: cmd.append("--fix") if action.file_path: cmd.append(action.file_path) else: cmd.append(".") result = subprocess.run( cmd, cwd=str(self._sandbox.repo_path), capture_output=True, text=True, ) output = result.stdout + result.stderr violations = len([ ln for ln in output.splitlines() if re.match(r".+:\d+:\d+:", ln) ]) score = max(0.001, min(0.999, 1.0 - violations * 0.05)) self._last_lint_result = LintResult( violations=violations, output=output[:2000], score=score, ) return output[:2000] or "No lint violations found." def _handle_generate_review(self, action: GenerateReview) -> str: review = ReviewArtifact( reviewer="agent", focus_areas=action.focus_areas, text=action.review_text, timestamp_step=self._step_number, ) self._reviews.append(review) return f"Review recorded ({len(action.review_text)} chars). Focus: {action.focus_areas}" def _handle_commit(self, action: Commit) -> str: if not self._sandbox.has_changes(): return "Nothing to commit. Working tree clean." sha = self._sandbox.commit( message=action.message, files=action.files if action.files else None, ) if sha: return f"Committed: {sha} — {action.message}" return "Commit failed (possibly nothing to stage)." def _handle_self_reflect(self, action: SelfReflect) -> str: reflection = ReflectionArtifact( step=self._step_number, what_went_well=action.what_went_well, what_to_improve=action.what_to_improve, adjusted_plan=action.adjusted_plan, ) self._reflections.append(reflection) return ( f"Reflection recorded at step {self._step_number}. " f"Improving: {action.what_to_improve[:80]}" ) def _handle_request_iteration(self, action: RequestIteration) -> str: issues = ", ".join(action.target_issues) if action.target_issues else "none specified" return f"Iteration requested: {action.reason} | Issues: {issues}" # ───────────────────────────────────────────── # HELPERS # ───────────────────────────────────────────── def _check_done(self) -> bool: """Episode is done if all tests pass and lint is clean.""" if self._last_test_result is None: return False tests_ok = ( self._last_test_result.failed == 0 and self._last_test_result.errors == 0 and self._last_test_result.passed > 0 ) lint_ok = ( self._last_lint_result is None or self._last_lint_result.violations == 0 ) committed = len(self._sandbox.get_log()) > 1 # beyond initial commit return tests_ok and lint_ok and committed def _finalize(self, reason: str) -> Observation: self._log(f"[END] {reason}") self._obs = self._build_observation( action_type=None, status=ActionStatus.FAILURE, output=reason, reward=0.001, done=True, ) return self._obs def _build_observation( self, action_type: Optional[str], status: ActionStatus, output: str, reward: float, done: bool, ) -> Observation: """Assemble a full Observation from current environment state.""" # Repo files snapshot (only .py, .md, .toml — cap at 8 files) all_files = self._sandbox.get_all_files() snapshots = [ FileSnapshot( path=p, content=c[:3000], # truncate large files size_bytes=len(c.encode()), ) for p, c in list(all_files.items())[:12] ] # Determine phase phase = self._infer_phase() return Observation( task_id=self._task_module.TASK_ID, task_description=self._task_module.DESCRIPTION, difficulty=TaskDifficulty(self._task_module.DIFFICULTY), step_number=self._step_number, max_steps=self._task_module.MAX_STEPS, phase=phase, repo_files=snapshots, git_log=self._sandbox.get_log(n=5), last_action_type=action_type, last_action_status=status, last_action_output=output, test_results=self._last_test_result, lint_results=self._last_lint_result, plan=self._plan, reviews=self._reviews, reflections=self._reflections, reward=reward, cumulative_reward=self._cumulative_reward, done=done, info={ "sandbox_path": str(self._sandbox.repo_path), "task_difficulty": self._task_module.DIFFICULTY, }, ) def _infer_phase(self) -> PhaseState: if self._step_number == 0: return PhaseState.PLANNING if self._plan and not self._last_test_result: return PhaseState.CODING if self._last_test_result and self._last_test_result.failed > 0: return PhaseState.TESTING if self._last_test_result and self._last_test_result.failed == 0 and not self._reviews: return PhaseState.REVIEWING if self._reviews and not self._reflections: return PhaseState.REFLECTING if self._obs and self._obs.done: return PhaseState.DONE return PhaseState.CODING def _log(self, msg: str) -> None: self._logs.append(msg) print(msg)