File size: 7,078 Bytes
e3d3c78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
client/agent.py

Agent loop for the Cross-Session Continuity environment.

Key constraints (enforced):
  - NO imports from server/ — client talks via MCP protocol only in production.
    In local dev/training, the env is passed in directly.
  - Retry logic for invalid actions (retry_budget = 3).
  - Session-aware system prompts.
  - Graceful noop fallback on retry exhaustion.
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


# ---------------------------------------------------------------------------
# Action (mirrored from server — no import)
# ---------------------------------------------------------------------------

@dataclass
class Action:
    """
    A single tool invocation produced by the agent.

    Kept as a local copy of the server-side action shape so that the client
    never needs to import anything from server/.
    """

    tool: str                                           # tool name, e.g. "write_file"
    path: str = field(default="")                       # target path; "" when the tool takes none
    content: str = field(default="")                    # payload text; "" when the tool takes none
    args: Dict[str, Any] = field(default_factory=dict)  # extra arguments; a fresh dict per instance


# ---------------------------------------------------------------------------
# System prompts
# ---------------------------------------------------------------------------

# System prompt for Session 1: do the work, then leave a structured handoff note.
# (The em dashes below were previously mis-encoded as "β€”" and sent to the model.)
S1_SYSTEM_PROMPT = """\
You are working on a coding task in Session 1.
Complete as much as possible within your step limit.
When approaching the limit, call write_handoff() with a structured note:

Required sections (all mandatory):
  TASK:          one sentence — what the overall task is
  COMPLETED:     bullet list — fully implemented + test-verified items
  REMAINING:     bullet list — what Session 2 must still implement
  KEY FUNCTIONS: function/class names, signatures, brief purpose
  EDGE CASES:    constraints or tricky logic discovered in Session 1
  NEXT STEPS:    ordered list — what Session 2 should do first

Constraints:
  - Max 400 tokens in handoff note.
  - Max 5 lines of code in code blocks.
  - All 6 sections must be present.
  - You have a retry budget of 3 for invalid actions — use it wisely.

Available tools: read_file, write_file, run_tests, write_handoff
"""

# System prompt for Session 2: no memory of Session 1 beyond the handoff note.
# (The em dash below was previously mis-encoded as "β€”" and sent to the model.)
S2_SYSTEM_PROMPT = """\
You are in Session 2. You have NO memory of Session 1.
Your ONLY information about what was done is the handoff note.

Start by calling parse_handoff() to retrieve the note.
Then use the note to continue the task.
Do NOT rewrite everything from scratch — the note tells you what to build on.

Available tools: parse_handoff, read_file, write_file, run_tests, submit
"""


# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------

class Agent:
    """
    LLM-backed agent for the Cross-Session Continuity environment.

    In training, model and tokenizer are injected (Unsloth Qwen2.5-Coder-7B).
    In eval/demo, model can be replaced with a rule-based stub.

    Every successful ``act`` call appends an ``{"obs", "action", "response"}``
    record to ``self.context``; noop fallbacks are not recorded.
    """

    # Tool-call grammar understood by _parse_fields / _parse_action.
    _TOOL_RE = re.compile(r"TOOL:\s*(\w+)", re.IGNORECASE)
    # [ \t]* rather than \s*: an empty "PATH:" line must not capture the next line.
    _PATH_RE = re.compile(r"PATH:[ \t]*([^\r\n]*)", re.IGNORECASE)
    # Content normally starts on the line after "CONTENT:", but inline content
    # ("CONTENT: text") is accepted too.
    _CONTENT_RE = re.compile(r"CONTENT:[ \t]*(?:\r?\n)?(.*)", re.IGNORECASE | re.DOTALL)

    def __init__(
        self,
        model=None,
        tokenizer=None,
        retry_budget: int = 3,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ):
        """
        Args:
            model: generation model exposing ``.device`` and ``.generate``; None for stub mode.
            tokenizer: tokenizer paired with ``model``; None for stub mode.
            retry_budget: maximum number of generation attempts per ``act`` call.
            max_new_tokens: generation length cap passed to ``model.generate``.
            temperature: sampling temperature passed to ``model.generate``.
        """
        self.model          = model
        self.tokenizer      = tokenizer
        self.retry_budget   = retry_budget
        self.max_new_tokens = max_new_tokens
        self.temperature    = temperature
        self.context: List[Dict] = []

    def act(self, obs: Dict[str, Any]) -> Action:
        """
        Generate an action given the current observation.

        Makes up to ``retry_budget`` generation attempts. After each unparseable
        output, the prompt is extended with a correction notice (so it grows
        with every retry). Falls back to a noop action on exhaustion, so the
        episode does not crash.
        """
        prompt = self._build_prompt(obs)

        for attempt in range(self.retry_budget):
            response = self._generate(prompt)
            action   = self._parse_action(response)

            if action is not None:
                self.context.append({"obs": obs, "action": action, "response": response})
                return action

            # Build a retry prompt with the failed response
            prompt = self._build_retry_prompt(prompt, response, attempt)

        # Graceful fallback — noop so the episode doesn't crash
        return Action(tool="noop", content="")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_prompt(self, obs: Dict[str, Any]) -> str:
        """Return the system prompt for ``obs["session"]`` (default 1) followed by the formatted observation."""
        system = S1_SYSTEM_PROMPT if obs.get("session", 1) == 1 else S2_SYSTEM_PROMPT
        obs_text = self._format_obs(obs)
        return f"{system}\n\nObservation:\n{obs_text}\n\nAction:"

    def _build_retry_prompt(self, prev_prompt: str, failed_response: str, attempt: int) -> str:
        """Append a failure notice (first 200 chars of the bad output) and the expected format."""
        return (
            f"{prev_prompt}\n\n"
            f"[Attempt {attempt + 1} failed. Output was not a valid tool call.]\n"
            f"Failed output: {failed_response[:200]}\n\n"
            "Please output a valid tool call in the format:\n"
            "TOOL: <tool_name>\nPATH: <path (if applicable)>\nCONTENT:\n<content>\n\n"
            "Action:"
        )

    def _generate(self, prompt: str) -> str:
        """
        Generate a response from the model.

        In training: uses self.model + self.tokenizer (Unsloth/HF).
        In stub mode (model=None or tokenizer=None): returns "" so callers can
        override this method or inject a model.

        Returns only the newly generated text, stripped of surrounding whitespace.
        """
        if self.model is None or self.tokenizer is None:
            return ""   # stub — caller must override or inject model

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Drop the echoed prompt by token count rather than by string prefix:
        # decode(encode(prompt)) need not reproduce `prompt` exactly, and a failed
        # startswith() check would leave the entire prompt in the response.
        # NOTE(review): assumes a decoder-only model whose output begins with the
        # input tokens (true for Qwen2.5 via HF generate).
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return decoded.strip()

    @staticmethod
    def _parse_fields(response: str) -> Optional[tuple[str, str, str]]:
        """
        Split a raw tool call into ``(tool, path, content)``.

        Returns None when ``response`` is empty or has no ``TOOL:`` field.
        ``tool`` is lower-cased. ``path`` is "" when absent; it is only looked
        for before the ``CONTENT:`` marker, so "path:" text inside the content
        body is never mistaken for it. ``content`` is the stripped text after
        ``CONTENT:``. When there is no ``CONTENT:`` marker, ``content`` is the
        whole raw response.
        """
        if not response:
            return None

        tool_match = Agent._TOOL_RE.search(response)
        if tool_match is None:
            return None
        tool = tool_match.group(1).lower()

        content_match = Agent._CONTENT_RE.search(response)
        header = response[: content_match.start()] if content_match else response
        path_match = Agent._PATH_RE.search(header)

        path    = path_match.group(1).strip()    if path_match    else ""
        content = content_match.group(1).strip() if content_match else response
        return tool, path, content

    @staticmethod
    def _parse_action(response: str) -> Optional[Action]:
        """
        Parse LLM output into an Action, or None if it is not a tool call.

        Expected format:
          TOOL: write_file
          PATH: solution.py
          CONTENT:
          <content lines>
        """
        fields = Agent._parse_fields(response)
        if fields is None:
            return None
        tool, path, content = fields
        return Action(tool=tool, path=path, content=content)

    @staticmethod
    def _format_obs(obs: Dict[str, Any]) -> str:
        """
        Render an observation as ``[key]`` sections separated by blank lines.

        "done" and "reward" are omitted. Dict values are flattened one level
        into ``[key/subkey]`` sections.
        """
        parts = []
        for key, val in obs.items():
            if key in ("done", "reward"):
                continue
            if isinstance(val, dict):
                for k, v in val.items():
                    parts.append(f"[{key}/{k}]\n{v}")
            else:
                parts.append(f"[{key}]\n{val}")
        return "\n\n".join(parts)