File size: 7,078 Bytes
e3d3c78 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | """
client/agent.py
Agent loop for the Cross-Session Continuity environment.
Key constraints (enforced):
- NO imports from server/ — the client talks via the MCP protocol only in production.
In local dev/training, the env is passed in directly.
- Retry logic for invalid actions (default retry_budget = 3).
- Session-aware system prompts.
- Graceful noop fallback on retry exhaustion.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
# ---------------------------------------------------------------------------
# Action (mirrored from server — no import)
# ---------------------------------------------------------------------------
@dataclass
class Action:
    """One tool invocation chosen by the agent.

    Mirrors the server-side action shape without importing server code.
    """

    # Tool name parsed from the model output, or "noop" on retry exhaustion.
    tool: str
    # Target path for file tools; "" when the model gave none.
    path: str = ""
    # Payload text (e.g. the file body for write_file).
    content: str = ""
    # Extra tool arguments; each instance gets its own fresh dict.
    # Not populated by the parser in this file.
    args: Dict[str, Any] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# System prompts
# ---------------------------------------------------------------------------
# Session 1 system prompt: make progress, then write a structured handoff note.
# The dashes below are real em-dashes (they were mis-encoded as "β").
S1_SYSTEM_PROMPT = """\
You are working on a coding task in Session 1.
Complete as much as possible within your step limit.
When approaching the limit, call write_handoff() with a structured note:
Required sections (all mandatory):
TASK: one sentence — what the overall task is
COMPLETED: bullet list — fully implemented + test-verified items
REMAINING: bullet list — what Session 2 must still implement
KEY FUNCTIONS: function/class names, signatures, brief purpose
EDGE CASES: constraints or tricky logic discovered in Session 1
NEXT STEPS: ordered list — what Session 2 should do first
Constraints:
- Max 400 tokens in handoff note.
- Max 5 lines of code in code blocks.
- All 6 sections must be present.
- You have a retry budget of 3 for invalid actions — use it wisely.
Available tools: read_file, write_file, run_tests, write_handoff
"""
# Session 2 system prompt: no memory of Session 1; rely solely on the handoff note.
# The dash below is a real em-dash (it was mis-encoded as "β").
S2_SYSTEM_PROMPT = """\
You are in Session 2. You have NO memory of Session 1.
Your ONLY information about what was done is the handoff note.
Start by calling parse_handoff() to retrieve the note.
Then use the note to continue the task.
Do NOT rewrite everything from scratch — the note tells you what to build on.
Available tools: parse_handoff, read_file, write_file, run_tests, submit
"""
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class Agent:
    """
    LLM-backed agent for the Cross-Session Continuity environment.
    In training, model and tokenizer are injected (Unsloth Qwen2.5-Coder-7B).
    In eval/demo, model can be replaced with a rule-based stub.

    Every model response must use the line-oriented format parsed by
    ``_parse_action``::

        TOOL: <tool_name>
        PATH: <path (if applicable)>
        CONTENT:
        <content>
    """

    # Tool-call template shared by the initial prompt and the retry prompt.
    # Sharing it means the model sees the exact format _parse_action expects
    # from the first attempt onward.
    _FORMAT_TEMPLATE = (
        "TOOL: <tool_name>\nPATH: <path (if applicable)>\nCONTENT:\n<content>"
    )

    def __init__(
        self,
        model=None,
        tokenizer=None,
        retry_budget: int = 3,
        max_new_tokens: int = 512,
    ):
        """
        Args:
            model: object exposing ``generate(...)`` and ``device``.
                ``None`` puts the agent in stub mode, where ``_generate``
                returns "".
            tokenizer: tokenizer used to encode prompts and decode outputs.
                Generation also requires this; ``None`` means stub mode.
            retry_budget: number of generation attempts per ``act`` call.
            max_new_tokens: generation length cap passed to ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.retry_budget = retry_budget
        self.max_new_tokens = max_new_tokens
        # History of successfully parsed steps (obs, parsed action, raw response).
        self.context: List[Dict] = []

    def act(self, obs: Dict[str, Any]) -> Action:
        """
        Generate an action given the current observation.

        Makes up to ``retry_budget`` generation attempts. After each
        unparseable response, the prompt is extended with the failed output
        and a reminder of the expected format.

        Falls back to a ``noop`` action when all attempts are exhausted, or
        immediately when ``retry_budget <= 0``.
        """
        prompt = self._build_prompt(obs)
        for attempt in range(self.retry_budget):
            response = self._generate(prompt)
            action = self._parse_action(response)
            if action is not None:
                self.context.append({"obs": obs, "action": action, "response": response})
                return action
            # Feed the failure back so the next attempt can self-correct.
            prompt = self._build_retry_prompt(prompt, response, attempt)
        # Graceful fallback — noop so the episode doesn't crash.
        return Action(tool="noop", content="")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_prompt(self, obs: Dict[str, Any]) -> str:
        """Assemble the system prompt, expected output format and observation.

        Uses the Session 1 prompt when ``obs["session"]`` is 1 (or missing);
        any other value selects the Session 2 prompt.
        """
        system = S1_SYSTEM_PROMPT if obs.get("session", 1) == 1 else S2_SYSTEM_PROMPT
        obs_text = self._format_obs(obs)
        return (
            f"{system}\n"
            "Respond with exactly one tool call in the format:\n"
            f"{self._FORMAT_TEMPLATE}\n\n"
            f"Observation:\n{obs_text}\n\nAction:"
        )

    def _build_retry_prompt(self, prev_prompt: str, failed_response: str, attempt: int) -> str:
        """Extend ``prev_prompt`` with the failed output (first 200 chars) and a format reminder."""
        return (
            f"{prev_prompt}\n\n"
            f"[Attempt {attempt + 1} failed. Output was not a valid tool call.]\n"
            f"Failed output: {failed_response[:200]}\n\n"
            "Please output a valid tool call in the format:\n"
            f"{self._FORMAT_TEMPLATE}\n\n"
            "Action:"
        )

    def _generate(self, prompt: str) -> str:
        """
        Generate a response from the model.

        In training: uses self.model + self.tokenizer (Unsloth/HF).
        In stub mode (model or tokenizer is None): returns "" so that an
        override or injected model can supply responses instead.
        """
        if self.model is None or self.tokenizer is None:
            return ""  # stub — caller must override or inject model
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Assumes a decoder-only model (like the Qwen2.5-Coder model named
        # in the class docstring), whose generate() returns prompt tokens
        # followed by the continuation.
        #
        # Drop the prompt by token count rather than by text prefix.
        # Decoding the prompt tokens may not reproduce `prompt` exactly, and
        # a leftover prompt would let its TOOL/PATH/CONTENT template be
        # parsed as the action.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return decoded.strip()

    @staticmethod
    def _parse_action(response: str) -> Optional[Action]:
        """
        Parse LLM output into an Action.

        Expected format:
            TOOL: write_file
            PATH: solution.py
            CONTENT:
            <content lines>

        How fields are read:
        - TOOL and PATH are read only from the header text before the first
          ``CONTENT:`` marker line. This keeps body code such as
          ``def f(path: str)`` from being mistaken for a header.
        - PATH is taken from its own line only, so an empty ``PATH:`` yields "".
        - Without a ``CONTENT:`` marker, the whole response is used as content.

        Returns None when the response is empty or contains no TOOL field.
        """
        if not response:
            return None
        marker = re.search(r"CONTENT:[ \t]*\r?\n", response, re.IGNORECASE)
        if marker:
            header = response[: marker.start()]
            content = response[marker.end():].strip()
        else:
            header = response
            content = response
        tool_match = re.search(r"TOOL:\s*(\w+)", header, re.IGNORECASE)
        if not tool_match:
            return None
        # [ \t]* (not \s*) so the match cannot run onto the following line.
        path_match = re.search(r"PATH:[ \t]*(.*)", header, re.IGNORECASE)
        path = path_match.group(1).strip() if path_match else ""
        return Action(tool=tool_match.group(1).lower(), path=path, content=content)

    @staticmethod
    def _format_obs(obs: Dict[str, Any]) -> str:
        """
        Render an observation as labelled sections separated by blank lines.

        Each key becomes a ``[key]`` header line followed by its value. Dict
        values are flattened one level into ``[key/subkey]`` sections. The
        ``done`` and ``reward`` entries are omitted.
        """
        parts = []
        for key, val in obs.items():
            if key in ("done", "reward"):
                continue
            if isinstance(val, dict):
                for k, v in val.items():
                    parts.append(f"[{key}/{k}]\n{v}")
            else:
                parts.append(f"[{key}]\n{val}")
        return "\n\n".join(parts)
|