teamforge / environment.py
Your Name
fix(OpenEnv): implement system-wide [0.1, 0.9] boundary scrub for Phase 2 compliance
efa2d2a
"""
TeamForge Environment
Full OpenEnv-compliant environment simulating an autonomous software team.
Interface:
env = TeamForgeEnv()
obs = env.reset(task_id)
obs = env.step(action)
state = env.state()
"""
from __future__ import annotations
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from models import (
Action,
ActionStatus,
Commit,
EditFile,
EpisodeResult,
FileSnapshot,
GenerateReview,
LintResult,
Observation,
PhaseState,
PlanStep,
ReflectionArtifact,
RequestIteration,
ReviewArtifact,
RunLint,
RunTests,
SelfReflect,
TaskDifficulty,
TestResult,
)
from sandbox.git_sandbox import GitSandbox
from tasks.task_registry import get_task
from reward import RewardCalculator
from grader import grade_episode
class TeamForgeEnv:
    """
    OpenEnv-compliant environment for autonomous software team simulation.

    An episode represents one attempt to complete a software engineering task.
    The agent issues structured actions; the environment executes them against
    a real Git repository and returns observations with dense rewards.

    Typical use::

        env = TeamForgeEnv()
        obs = env.reset(task_id)
        obs = env.step(action)
        state = env.state()
        result = env.grade()
    """

    # Matches one Ruff diagnostic location per line of ``ruff check`` output.
    # First alternative: the classic ``path:line:col: CODE message`` header.
    # Second alternative: a `` --> path:line:col`` location line, the form
    # newer Ruff releases print in their default ("full") output.
    # NOTE(review): confirm against the Ruff version installed in the sandbox.
    _LINT_LOCATION_RE = re.compile(r".+:\d+:\d+:|\s*-->\s+.+:\d+:\d+")

    def __init__(self, log_dir: Optional[str] = None):
        """
        Create an environment with no active episode; call reset() next.

        Args:
            log_dir: Stored on the instance. NOTE(review): it is not read
                anywhere in this class; log lines are only kept in memory
                (self._logs) and printed.
        """
        self._sandbox = GitSandbox()
        self._reward_calc = RewardCalculator()
        self._obs: Optional[Observation] = None
        self._task_module: Any = None
        self._log_dir = log_dir
        self._logs: List[str] = []
        # Episode state (re-initialised by reset()).
        self._step_number = 0
        # NOTE(review): reset() sets this to 0.1. This starting value is never
        # observable because step(), state() and grade() all require reset().
        self._cumulative_reward = 0.001
        self._plan: List[PlanStep] = []
        self._reviews: List[ReviewArtifact] = []
        self._reflections: List[ReflectionArtifact] = []
        self._last_test_result: Optional[TestResult] = None
        self._last_lint_result: Optional[LintResult] = None

    # ─────────────────────────────────────────────
    # OpenEnv INTERFACE
    # ─────────────────────────────────────────────

    def reset(self, task_id: str) -> Observation:
        """
        Start a new episode for the given task.

        Tears down any previous sandbox and initialises a fresh git repo
        populated with the task's INITIAL_FILES.

        Args:
            task_id: One of easy_bugfix_chunk_list | medium_refactor_stats |
                hard_lru_cache_performance

        Returns:
            Initial observation (step 0, reward 0.1) with a repo snapshot.
        """
        self._log(f"[START] task={task_id}")

        # Clean up previous episode
        self._sandbox.teardown()
        self._sandbox = GitSandbox()

        # Load task and start a fresh reward calculator
        self._task_module = get_task(task_id)
        self._reward_calc = RewardCalculator()

        # Any initial file whose path contains "test" (case-insensitive) is
        # registered as a test file with the reward calculator.
        test_files = [
            p for p in self._task_module.INITIAL_FILES
            if "test" in p.lower()
        ]
        self._reward_calc.set_test_files(test_files)

        # Reset episode state
        self._step_number = 0
        self._cumulative_reward = 0.1
        self._plan = []
        self._reviews = []
        self._reflections = []
        self._last_test_result = None
        self._last_lint_result = None
        self._logs = [f"[START] task={task_id}"]

        # Initialise git sandbox with task files
        self._sandbox.init(self._task_module.INITIAL_FILES)

        # Build initial observation
        self._obs = self._build_observation(
            action_type=None,
            status=ActionStatus.SUCCESS,
            output="Environment initialized.",
            reward=0.1,
            done=False,
        )
        return self._obs

    def step(self, action: Action) -> Observation:
        """
        Execute one action and return the resulting observation.

        The returned observation has ``done=True`` once the task is complete
        (see _check_done) or once the final permitted step (MAX_STEPS) has
        been executed. Any action submitted beyond MAX_STEPS is not executed;
        it yields a terminal FAILURE observation instead.

        Args:
            action: A typed Action model (PlanStep, EditFile, RunTests, …)

        Returns:
            Updated Observation with reward, done flag, and all state.

        Raises:
            RuntimeError: If called before reset().
        """
        if self._obs is None:
            raise RuntimeError("Call reset() before step()")
        self._step_number += 1
        action_type = action.type
        self._log(f"[STEP {self._step_number}] action={action_type}")

        # ── Max steps guard ──
        max_steps = self._task_module.MAX_STEPS
        if self._step_number > max_steps:
            return self._finalize(reason="Max steps exceeded")

        # ── Dispatch action ──
        status = ActionStatus.SUCCESS
        output = ""
        edited_file: Optional[str] = None
        tests_passed: Optional[int] = None
        lint_violations: Optional[int] = None
        try:
            if isinstance(action, PlanStep):
                output = self._handle_plan_step(action)
            elif isinstance(action, EditFile):
                output, edited_file = self._handle_edit_file(action)
            elif isinstance(action, RunTests):
                output = self._handle_run_tests(action)
                tests_passed = (self._last_test_result.passed
                                if self._last_test_result else 0)
            elif isinstance(action, RunLint):
                output = self._handle_run_lint(action)
                lint_violations = (self._last_lint_result.violations
                                   if self._last_lint_result else 0)
            elif isinstance(action, GenerateReview):
                output = self._handle_generate_review(action)
            elif isinstance(action, Commit):
                output = self._handle_commit(action)
            elif isinstance(action, SelfReflect):
                output = self._handle_self_reflect(action)
            elif isinstance(action, RequestIteration):
                output = self._handle_request_iteration(action)
            else:
                status = ActionStatus.FAILURE
                output = f"Unknown action type: {action_type}"
        except Exception as exc:
            # Handler failures, including subprocess.TimeoutExpired from a
            # test run, are reported to the agent as FAILURE, not raised.
            status = ActionStatus.FAILURE
            output = f"Action failed with exception: {exc}"
            self._log(f"[ERROR] {exc}")

        # ── Compute reward ──
        reward = self._reward_calc.compute(
            action_type=action_type,
            action_success=(status == ActionStatus.SUCCESS),
            action_output=output,
            tests_passed=tests_passed,
            lint_violations=lint_violations,
            edited_file=edited_file,
        )
        self._cumulative_reward += reward

        # ── Check done conditions ──
        done = self._check_done()
        budget_exhausted = not done and self._step_number >= max_steps
        if budget_exhausted:
            # No further action can execute, so end the episode now rather
            # than forcing the agent to spend one extra (never executed)
            # action just to receive the "Max steps exceeded" observation.
            done = True
        self._log(f"[STEP {self._step_number}] reward={reward:.4f} done={done}")
        if budget_exhausted:
            self._log(f"[END] Max steps reached ({max_steps})")

        self._obs = self._build_observation(
            action_type=action_type,
            status=status,
            output=output,
            reward=reward,
            done=done,
        )
        return self._obs

    def state(self) -> Dict[str, Any]:
        """
        Return current environment state as a plain dict.

        Useful for serialisation and logging. Before reset() this is just
        ``{"status": "not_started"}``.
        """
        if self._obs is None:
            return {"status": "not_started"}
        return {
            "task_id": self._obs.task_id,
            "step": self._step_number,
            "phase": self._obs.phase.value,
            "cumulative_reward": self._cumulative_reward,
            "tests_passed": (self._last_test_result.passed
                             if self._last_test_result else 0),
            "tests_failed": (self._last_test_result.failed
                             if self._last_test_result else 0),
            "lint_violations": (self._last_lint_result.violations
                                if self._last_lint_result else 0),
            "commits": len(self._sandbox.get_log()),
            "plan_steps": len(self._plan),
            "reviews": len(self._reviews),
            "reflections": len(self._reflections),
            "done": self._obs.done,
        }

    def grade(self) -> EpisodeResult:
        """
        Run the deterministic grader and return an EpisodeResult.

        Raises:
            RuntimeError: If called before reset() (no task is loaded).
        """
        if self._task_module is None:
            raise RuntimeError("Call reset() before grade()")
        required_kw = getattr(
            self._task_module, "REQUIRED_KEYWORDS_IN_REVIEW", []
        )
        return grade_episode(
            repo_path=str(self._sandbox.repo_path),
            task_id=self._task_module.TASK_ID,
            total_steps=self._step_number,
            max_steps=self._task_module.MAX_STEPS,
            reviews=self._reviews,
            reflections=self._reflections,
            required_keywords=required_kw,
        )

    # ─────────────────────────────────────────────
    # ACTION HANDLERS
    # ─────────────────────────────────────────────

    def _handle_plan_step(self, action: PlanStep) -> str:
        """Append the plan step to the episode plan and describe it."""
        self._plan.append(action)
        return (
            f"Plan step {action.step_number} recorded: {action.description} "
            f"[effort={action.estimated_effort}]"
        )

    def _handle_edit_file(self, action: EditFile) -> tuple[str, str]:
        """
        Write the action's full content to file_path in the sandbox.

        Returns:
            (human-readable summary, edited file path).
        """
        self._sandbox.write_file(action.file_path, action.content)
        size = len(action.content.encode())
        return (
            f"Wrote {size} bytes to {action.file_path}. Reason: {action.reason}",
            action.file_path,
        )

    def _handle_run_tests(self, action: RunTests) -> str:
        """
        Run pytest in the sandbox repo and cache the parsed counts.

        Pass/fail/error counts are scraped from the combined stdout+stderr and
        stored in self._last_test_result. The ``--timeout`` flag is provided
        by the pytest-timeout plugin. NOTE(review): confirm it is installed
        for sys.executable, otherwise pytest rejects the argument.

        Returns:
            The first 2000 characters of the combined pytest output.

        Raises:
            subprocess.TimeoutExpired: If pytest runs longer than
                timeout_seconds + 5 seconds.
        """
        cmd = [
            sys.executable, "-m", "pytest",
            "--tb=short", "-q", "--no-header",
            f"--timeout={action.timeout_seconds}",
        ]
        if action.test_path:
            cmd.append(action.test_path)

        start = time.perf_counter()
        result = subprocess.run(
            cmd,
            cwd=str(self._sandbox.repo_path),
            capture_output=True,
            text=True,
            # Outer safety net slightly above pytest's own per-test timeout.
            timeout=action.timeout_seconds + 5,
        )
        elapsed = time.perf_counter() - start
        output = result.stdout + result.stderr

        passed = failed = errors = 0
        m_p = re.search(r"(\d+) passed", output)
        m_f = re.search(r"(\d+) failed", output)
        m_e = re.search(r"(\d+) error", output)  # matches "error" and "errors"
        if m_p:
            passed = int(m_p.group(1))
        if m_f:
            failed = int(m_f.group(1))
        if m_e:
            errors = int(m_e.group(1))

        self._last_test_result = TestResult(
            passed=passed,
            failed=failed,
            errors=errors,
            output=output[:2000],
            duration_seconds=elapsed,
        )
        return output[:2000]

    def _handle_run_lint(self, action: RunLint) -> str:
        """
        Run ``ruff check`` (optionally with ``--fix``) and cache the result.

        Lints action.file_path, or the whole repo when it is empty. Each
        output line matching _LINT_LOCATION_RE counts as one violation. The
        score drops by 0.05 per violation and is clamped to [0.001, 0.999].

        Returns:
            The first 2000 characters of ruff's output, or a fixed message
            when ruff printed nothing.
        """
        cmd = [sys.executable, "-m", "ruff", "check"]
        if action.fix:
            cmd.append("--fix")
        cmd.append(action.file_path if action.file_path else ".")

        result = subprocess.run(
            cmd,
            cwd=str(self._sandbox.repo_path),
            capture_output=True,
            text=True,
        )
        output = result.stdout + result.stderr
        violations = sum(
            1 for ln in output.splitlines()
            if self._LINT_LOCATION_RE.match(ln)
        )
        score = max(0.001, min(0.999, 1.0 - violations * 0.05))
        self._last_lint_result = LintResult(
            violations=violations,
            output=output[:2000],
            score=score,
        )
        return output[:2000] or "No lint violations found."

    def _handle_generate_review(self, action: GenerateReview) -> str:
        """Record an agent-authored review artifact stamped with the current step."""
        review = ReviewArtifact(
            reviewer="agent",
            focus_areas=action.focus_areas,
            text=action.review_text,
            timestamp_step=self._step_number,
        )
        self._reviews.append(review)
        return f"Review recorded ({len(action.review_text)} chars). Focus: {action.focus_areas}"

    def _handle_commit(self, action: Commit) -> str:
        """
        Commit pending changes in the sandbox.

        An empty action.files list is passed to the sandbox as None.
        """
        if not self._sandbox.has_changes():
            return "Nothing to commit. Working tree clean."
        sha = self._sandbox.commit(
            message=action.message,
            files=action.files if action.files else None,
        )
        if sha:
            # U+2014 em dash; the previous literal held its mojibake ("β€”").
            return f"Committed: {sha} \u2014 {action.message}"
        return "Commit failed (possibly nothing to stage)."

    def _handle_self_reflect(self, action: SelfReflect) -> str:
        """Record a reflection artifact; the reply echoes up to 80 chars of it."""
        reflection = ReflectionArtifact(
            step=self._step_number,
            what_went_well=action.what_went_well,
            what_to_improve=action.what_to_improve,
            adjusted_plan=action.adjusted_plan,
        )
        self._reflections.append(reflection)
        return (
            f"Reflection recorded at step {self._step_number}. "
            f"Improving: {action.what_to_improve[:80]}"
        )

    def _handle_request_iteration(self, action: RequestIteration) -> str:
        """Acknowledge an iteration request; no environment state changes."""
        issues = ", ".join(action.target_issues) if action.target_issues else "none specified"
        return f"Iteration requested: {action.reason} | Issues: {issues}"

    # ─────────────────────────────────────────────
    # HELPERS
    # ─────────────────────────────────────────────

    def _check_done(self) -> bool:
        """
        Return True when the task is complete.

        Complete means all of the following hold:
          - the last test run passed at least one test with no failures
            or errors;
          - lint is clean, or lint has never been run;
          - at least one commit exists beyond the initial one.
        """
        if self._last_test_result is None:
            return False
        tests_ok = (
            self._last_test_result.failed == 0
            and self._last_test_result.errors == 0
            and self._last_test_result.passed > 0
        )
        lint_ok = (
            self._last_lint_result is None
            or self._last_lint_result.violations == 0
        )
        committed = len(self._sandbox.get_log()) > 1  # beyond initial commit
        return tests_ok and lint_ok and committed

    def _finalize(self, reason: str) -> Observation:
        """Return a terminal FAILURE observation (reward 0.001) without running an action."""
        self._log(f"[END] {reason}")
        self._obs = self._build_observation(
            action_type=None,
            status=ActionStatus.FAILURE,
            output=reason,
            reward=0.001,
            done=True,
        )
        return self._obs

    def _build_observation(
        self,
        action_type: Optional[str],
        status: ActionStatus,
        output: str,
        reward: float,
        done: bool,
    ) -> Observation:
        """Assemble a full Observation from current environment state."""
        # Snapshot at most the first 12 files the sandbox reports (no
        # extension filtering). Content is truncated to 3000 chars, while
        # size_bytes records the full, untruncated UTF-8 size.
        all_files = self._sandbox.get_all_files()
        snapshots = [
            FileSnapshot(
                path=p,
                content=c[:3000],
                size_bytes=len(c.encode()),
            )
            for p, c in list(all_files.items())[:12]
        ]

        # The phase reflects this observation's own done flag.
        phase = self._infer_phase(done)

        return Observation(
            task_id=self._task_module.TASK_ID,
            task_description=self._task_module.DESCRIPTION,
            difficulty=TaskDifficulty(self._task_module.DIFFICULTY),
            step_number=self._step_number,
            max_steps=self._task_module.MAX_STEPS,
            phase=phase,
            repo_files=snapshots,
            git_log=self._sandbox.get_log(n=5),
            last_action_type=action_type,
            last_action_status=status,
            last_action_output=output,
            test_results=self._last_test_result,
            lint_results=self._last_lint_result,
            plan=self._plan,
            reviews=self._reviews,
            reflections=self._reflections,
            reward=reward,
            cumulative_reward=self._cumulative_reward,
            done=done,
            info={
                "sandbox_path": str(self._sandbox.repo_path),
                "task_difficulty": self._task_module.DIFFICULTY,
            },
        )

    def _infer_phase(self, done: bool = False) -> PhaseState:
        """
        Derive the workflow phase shown in the observation being built.

        Args:
            done: The ``done`` flag of the observation being built. When True
                the phase is DONE. The check uses the current flag, not the
                previous observation's, so the terminal observation itself
                reports DONE.
        """
        if done:
            return PhaseState.DONE
        if self._step_number == 0:
            return PhaseState.PLANNING
        tests = self._last_test_result
        if self._plan and tests is None:
            return PhaseState.CODING
        # Collection/runtime errors count as a failing run, as in _check_done.
        if tests is not None and (tests.failed > 0 or tests.errors > 0):
            return PhaseState.TESTING
        if tests is not None and not self._reviews:
            return PhaseState.REVIEWING
        if self._reviews and not self._reflections:
            return PhaseState.REFLECTING
        return PhaseState.CODING

    def _log(self, msg: str) -> None:
        """Append msg to the in-memory episode log and print it."""
        self._logs.append(msg)
        print(msg)