Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| TeamForge Baseline Inference | |
| Runs a language-model agent through all TeamForge tasks. | |
| Usage: | |
| export GROQ_API_KEY=gsk_... | |
| export API_BASE_URL=https://api.groq.com/openai/v1 | |
| export MODEL_NAME=llama3-8b-8192 | |
| python baseline_inference.py [--task TASK_ID] [--seed 42] | |
| Outputs structured logs: [START] [STEP] [ACTION] [OBS] [END] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from typing import Any, Dict, List, Optional | |
| from openai import OpenAI | |
| # Local imports | |
| from environment import TeamForgeEnv | |
| from models import ( | |
| Action, | |
| Commit, | |
| EditFile, | |
| GenerateReview, | |
| Observation, | |
| PlanStep, | |
| RequestIteration, | |
| RunLint, | |
| RunTests, | |
| SelfReflect, | |
| ) | |
| from tasks import ALL_TASK_IDS | |
# ─────────────────────────────────────────────
# CONFIGURATION
# ─────────────────────────────────────────────
# Endpoint and model are overridable through the environment; the defaults
# target Groq's OpenAI-compatible API.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-8b-8192")

# GROQ_API_KEY is the documented variable and takes precedence. OPENAI_API_KEY
# is accepted as a fallback so that following the CLI's "Set OPENAI_API_KEY"
# hint actually supplies a key instead of being silently ignored.
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("OPENAI_API_KEY", "")
# Alias kept for code that refers to the key under its OpenAI-client name.
OPENAI_API_KEY = GROQ_API_KEY

MAX_RETRIES = 3    # attempts per LLM call before the agent gives up
TEMPERATURE = 0.2  # sampling temperature passed to the chat-completions call
# ─────────────────────────────────────────────
# SYSTEM PROMPT
# ─────────────────────────────────────────────
# System message sent with every chat-completions request. It lists the eight
# action schemas the agent may emit; each "type" value corresponds to a key in
# the dispatch table used when parsing the model's reply.
SYSTEM_PROMPT = """
You are TeamForge — an autonomous AI software engineer.
You work in structured phases: PLAN → CODE → TEST → REVIEW → REFLECT.
At each step, you receive an observation (current repo state, test results, lint output)
and must return exactly ONE action as a JSON object.
Available action types and their required fields:
1. plan_step:
{"type": "plan_step", "step_number": <int>, "description": "<str>", "estimated_effort": "low|medium|high"}
2. edit_file:
{"type": "edit_file", "file_path": "<str>", "content": "<full file content>", "reason": "<str>"}
3. run_tests:
{"type": "run_tests", "timeout_seconds": 30}
4. run_lint:
{"type": "run_lint", "fix": false}
5. generate_review:
{"type": "generate_review", "focus_areas": ["correctness", "style", "performance"], "review_text": "<detailed review>"}
6. commit:
{"type": "commit", "message": "<conventional commit message>"}
7. self_reflect:
{"type": "self_reflect", "what_went_well": "<str>", "what_to_improve": "<str>"}
8. request_iteration:
{"type": "request_iteration", "reason": "<str>", "target_issues": ["<issue1>", "<issue2>"]}
Rules:
- NEVER modify test files (files whose path contains "test")
- Always plan first (at least 2 plan_step actions)
- After fixing code, always run_tests before committing
- Always generate_review before final commit
- Return ONLY the JSON object, no markdown, no explanation
""".strip()
# ─────────────────────────────────────────────
# AGENT
# ─────────────────────────────────────────────
class TeamForgeAgent:
    """LLM-powered agent that drives the TeamForge environment.

    The running chat transcript lives in ``self.history`` and is sent,
    prefixed by ``SYSTEM_PROMPT``, on every model call.
    """

    def __init__(self, client: OpenAI):
        """Store the chat-completions client and start with an empty transcript."""
        self.client = client
        self.history: List[Dict[str, str]] = []

    def reset(self) -> None:
        """Drop the transcript so the next episode starts from a clean slate."""
        self.history = []

    def act(self, obs: Observation) -> Optional[Action]:
        """Ask the LLM for the next action given ``obs``.

        The call is attempted up to ``MAX_RETRIES`` times with exponential
        back-off between attempts. An attempt fails when the request raises,
        the reply cannot be parsed, or the reply names an unknown action type.

        Returns:
            The parsed action, or ``None`` when every attempt failed.
        """
        user_message = self._obs_to_prompt(obs)
        self.history.append({"role": "user", "content": user_message})

        for attempt in range(MAX_RETRIES):
            action = None
            content = ""
            try:
                response = self.client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        *self.history,
                    ],
                    temperature=TEMPERATURE,
                    max_tokens=2000,
                )
                # The API may return no text content; treat that as empty so
                # it fails in parsing with a clear JSON error.
                content = (response.choices[0].message.content or "").strip()
                action = self._parse_action(content)
            except Exception as exc:
                print(f"[WARN] LLM call attempt {attempt+1} failed: {exc}")

            if action is not None:
                # Record the reply only once it parsed, so failed attempts do
                # not leave stray assistant turns in the transcript.
                self.history.append({"role": "assistant", "content": content})
                return action

            # Back off before the next attempt, but not after the last one.
            if attempt < MAX_RETRIES - 1:
                time.sleep(2 ** attempt)

        # No usable reply: remove the unanswered user turn so the transcript
        # keeps alternating user/assistant messages.
        self.history.pop()
        return None

    def _obs_to_prompt(self, obs: Observation) -> str:
        """Render an observation as a compact Markdown-style text prompt."""
        lines = [
            f"## Task: {obs.task_id} ({obs.difficulty.value})",
            f"Step {obs.step_number}/{obs.max_steps} | Phase: {obs.phase.value}",
            f"Cumulative reward: {obs.cumulative_reward:.3f}",
            "",
            f"### Task Description\n{obs.task_description[:600]}",
            "",
        ]

        # Result of the previous action, if there was one.
        if obs.last_action_type:
            # Defensive: tolerate a missing output string.
            last_output = (obs.last_action_output or "")[:800]
            lines += [
                f"### Last Action: {obs.last_action_type} → {obs.last_action_status.value}",
                f"```\n{last_output}\n```",
                "",
            ]

        # Test results
        if obs.test_results:
            tr = obs.test_results
            test_output = (tr.output or "")[:600]
            lines += [
                f"### Tests: {tr.passed} passed / {tr.failed} failed / {tr.errors} errors",
                f"```\n{test_output}\n```",
                "",
            ]

        # Lint
        if obs.lint_results:
            lr = obs.lint_results
            lines += [
                f"### Lint: {lr.violations} violations (score={lr.score:.2f})",
            ]

        # Repo files: at most 8 are listed by name; files under 4000 bytes
        # also get their first 800 characters inlined.
        lines.append("### Repo Files")
        for f in obs.repo_files[:8]:
            lines.append(f"**{f.path}** ({f.size_bytes} bytes)")
            if f.size_bytes < 4000:
                lines.append(f"```python\n{f.content[:800]}\n```")

        # Only the three most recent plan steps are repeated to save tokens.
        if obs.plan:
            lines.append(f"### Plan ({len(obs.plan)} steps)")
            for step in obs.plan[-3:]:
                lines.append(f" {step.step_number}. {step.description}")

        lines.append("\n### What is your next action? Return ONLY a JSON object.")
        return "\n".join(lines)

    @staticmethod
    def _extract_json(text: str) -> Dict[str, Any]:
        """Pull a single JSON object out of raw model output.

        Handles a bare object, an object inside a Markdown code fence, and an
        object surrounded by extra prose (outermost ``{`` to ``}`` is tried).

        Raises:
            json.JSONDecodeError: no parseable JSON was found.
            ValueError: the JSON parsed but is not an object.
        """
        raw = text.strip()
        candidate = raw
        if raw.startswith("```"):
            fence_lines = raw.split("\n")
            closed = fence_lines[-1].strip() == "```"
            inner = fence_lines[1:-1] if closed else fence_lines[1:]
            candidate = "\n".join(inner).strip()

        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            # Fall back to the outermost brace pair in the raw reply.
            start, end = raw.find("{"), raw.rfind("}")
            if start == -1 or end <= start:
                raise
            data = json.loads(raw[start:end + 1])

        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a JSON object, got {type(data).__name__}"
            )
        return data

    def _parse_action(self, text: str) -> Optional[Action]:
        """Parse LLM output into an Action model.

        Returns ``None`` (after printing a warning) when the ``type`` field
        does not name a known action. Malformed JSON and model-construction
        errors propagate to the caller.
        """
        data = self._extract_json(text)
        action_type = data.get("type")
        dispatch = {
            "plan_step": PlanStep,
            "edit_file": EditFile,
            "run_tests": RunTests,
            "run_lint": RunLint,
            "generate_review": GenerateReview,
            "commit": Commit,
            "self_reflect": SelfReflect,
            "request_iteration": RequestIteration,
        }
        cls = dispatch.get(action_type)
        if cls is None:
            print(f"[WARN] Unknown action type: {action_type}")
            return None
        return cls(**data)
# ─────────────────────────────────────────────
# EPISODE RUNNER
# ─────────────────────────────────────────────
def run_episode(
    env: TeamForgeEnv,
    agent: TeamForgeAgent,
    task_id: str,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Drive one task from reset to grading.

    The agent and environment are reset, then the agent is stepped until the
    observation reports ``done`` or the agent yields no action. The episode is
    graded either way.

    Args:
        env: Environment to run the task in.
        agent: Agent that chooses each action.
        task_id: Identifier of the task to run.
        verbose: When true, print a line per step in addition to the
            START/END banners, which are always printed.

    Returns:
        A dict with ``task_id``, the dumped grading ``result`` and the list of
        START/STEP/END ``log`` entries.
    """
    rule = "=" * 60

    agent.reset()
    obs = env.reset(task_id)

    print(f"\n{rule}")
    print(f"[START] task={task_id} | model={MODEL_NAME}")
    print(rule)
    episode_log = [{"event": "START", "task_id": task_id, "model": MODEL_NAME}]

    while not obs.done:
        action = agent.act(obs)
        if action is None:
            print("[ERROR] Agent returned no action. Stopping.")
            break

        if verbose:
            print(f"[STEP {obs.step_number + 1}] action={action.type}")

        obs = env.step(action)

        # Missing test results count as zero passed / zero failed.
        tests = obs.test_results
        n_passed = tests.passed if tests else 0
        n_failed = tests.failed if tests else 0

        episode_log.append(
            {
                "event": "STEP",
                "step": obs.step_number,
                "action_type": obs.last_action_type,
                "action_status": obs.last_action_status.value,
                "reward": obs.reward,
                "cumulative_reward": obs.cumulative_reward,
                "tests_passed": n_passed,
                "tests_failed": n_failed,
                "done": obs.done,
            }
        )

        if verbose:
            print(
                f" reward={obs.reward:.4f} cum={obs.cumulative_reward:.4f} "
                f"tests={n_passed}p/{n_failed}f "
                f"done={obs.done}"
            )

    # Grade whatever state the episode ended in.
    result = env.grade()

    print(f"\n{rule}")
    print(f"[END] task={task_id}")
    score_lines = (
        ("final_score", result.final_score),
        ("test_pass_rate", result.test_pass_rate),
        ("lint_score", result.lint_score),
        ("efficiency", result.efficiency_score),
        ("review_quality", result.review_quality),
    )
    for label, value in score_lines:
        print(f" {label} = {value:.4f}")
    print(f" passed = {result.passed}")
    print(f"{rule}\n")

    episode_log.append(
        {
            "event": "END",
            "task_id": task_id,
            "final_score": result.final_score,
            "test_pass_rate": result.test_pass_rate,
            "lint_score": result.lint_score,
            "efficiency_score": result.efficiency_score,
            "review_quality": result.review_quality,
            "passed": result.passed,
            "total_steps": result.total_steps,
        }
    )

    return {
        "task_id": task_id,
        "result": result.model_dump(),
        "log": episode_log,
    }
# ─────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────
def main():
    """Command-line entry point.

    Parses arguments, runs the selected task (or every task) through the
    agent, writes all episode results to the ``--output`` JSON file and prints
    a one-line summary per task. Exits with status 1 when no API key is
    configured.
    """
    parser = argparse.ArgumentParser(description="TeamForge Baseline Inference")
    parser.add_argument(
        "--task",
        # Built with unpacking so the task list may be any sequence type.
        choices=[*ALL_TASK_IDS, "all"],
        default="all",
        help="Task ID to run, or 'all'",
    )
    # NOTE(review): --seed is parsed but nothing in this function consumes it;
    # wire it into the environment/sampler once their seeding API is confirmed.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=str, default="results.json")
    # --verbose is kept for backward compatibility (it is already the
    # default); --quiet is the only way to turn per-step output off.
    parser.add_argument("--verbose", action="store_true", default=True)
    parser.add_argument(
        "--quiet",
        dest="verbose",
        action="store_false",
        help="Suppress per-step output",
    )
    args = parser.parse_args()

    if not GROQ_API_KEY or GROQ_API_KEY.startswith("sk-placeholder"):
        # Name the variable that is actually read.
        print("[ERROR] Set GROQ_API_KEY environment variable.")
        sys.exit(1)

    client = OpenAI(api_key=GROQ_API_KEY, base_url=API_BASE_URL)
    env = TeamForgeEnv(log_dir="logs/")
    agent = TeamForgeAgent(client)

    tasks_to_run = ALL_TASK_IDS if args.task == "all" else [args.task]
    all_results = [
        run_episode(env, agent, task_id, verbose=args.verbose)
        for task_id in tasks_to_run
    ]

    # Save results
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to {args.output}")

    # Summary
    print("\n─── SUMMARY ───────────────────────────────────────────────")
    for r in all_results:
        res = r["result"]
        status = "✓ PASS" if res["passed"] else "✗ FAIL"
        print(
            f"{status} {r['task_id']:40s} "
            f"score={res['final_score']:.4f} "
            f"steps={res['total_steps']}"
        )


if __name__ == "__main__":
    main()