File size: 10,008 Bytes
2f78834
788f2cb
2f78834
788f2cb
 
 
 
2f78834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
app.py  —  HuggingFace Space entry point  (Gradio 5 compatible)

Serves a Gradio demo with three tabs:
  1. Live Episode   — deterministic 2-session run, no GPU needed
  2. Training Results — gallery of 5 evaluation plots
  3. Environment Info — architecture + reward table
"""

import json
import os
import re
import sys
import textwrap

import gradio as gr

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from server.env import CrossSessionContinuityEnv, Action
from server.mcp_tools import build_tool_registry

# ── Shared env instance (reset per demo run) ─────────────────────────────────

# Most recently constructed environment (overwritten by every _get_env call).
_ENV: CrossSessionContinuityEnv | None = None


def _get_env(difficulty: str) -> CrossSessionContinuityEnv:
    """Construct a new environment for *difficulty* and record it in ``_ENV``.

    Every call builds a brand-new ``CrossSessionContinuityEnv``; no previous
    instance is reused.

    Args:
        difficulty: Difficulty level forwarded to the environment constructor.

    Returns:
        The freshly created environment (also stored in the module-level
        ``_ENV``).
    """
    global _ENV
    fresh_env = CrossSessionContinuityEnv(difficulty=difficulty)
    _ENV = fresh_env
    return fresh_env


# ── Demo logic ────────────────────────────────────────────────────────────────

def _detect_function_name(source: str) -> str | None:
    """Return the name of the first top-level ``def`` in *source*, or None.

    Used to align the demo oracle with whatever (possibly randomized)
    function name the episode's starter code declares.
    """
    match = re.search(r"^def\s+([A-Za-z_]\w*)\s*\(", source, flags=re.MULTILINE)
    return match.group(1) if match else None


def _merge_intervals_oracle(func_name: str) -> str:
    """Return source for a reference interval-merge solution named *func_name*.

    The generated function sorts a copy of its input, so the caller's list and
    its inner pairs are never mutated. It accepts pairs given as lists or
    tuples and returns merged intervals as ``[start, end]`` lists.
    """
    return (
        f"def {func_name}(intervals):\n"
        "    if not intervals: return []\n"
        "    ordered = sorted(intervals, key=lambda x: x[0])\n"
        "    merged = [list(ordered[0])]\n"
        "    for start, end in ordered[1:]:\n"
        "        if start <= merged[-1][1]:\n"
        "            merged[-1][1] = max(merged[-1][1], end)\n"
        "        else:\n"
        "            merged.append([start, end])\n"
        "    return merged\n"
    )


def run_demo(difficulty: str, seed: int):
    """
    Run a deterministic 2-session episode with a rule-based stub agent.

    No model or GPU is involved.

    Session 1 does the following:
    - reads the first starter file;
    - writes a partial stub and runs the tests;
    - writes a structured handoff note.

    Session 2 does the following:
    - parses the handoff and re-reads the file;
    - writes an implementation, runs the tests and submits.

    Args:
        difficulty: Difficulty level passed to the environment.
        seed: Episode seed; coerced to ``int`` before ``env.reset``.

    Returns:
        A Markdown transcript, one paragraph per logged event, for the Gradio
        UI. The transcript is returned early if the episode exposes no starter
        file or if ``write_handoff`` reports an error.
    """
    env = _get_env(difficulty)
    tools = build_tool_registry(env)
    obs = env.reset(seed=int(seed))

    log: list[str] = []

    def _log(tag, msg):
        log.append(f"**[{tag}]** {msg}")

    _log("RESET", f"Task: {obs['task'][:200]}...")
    _log("INFO", f"Step limit: {obs['step_limit']} | Difficulty: {difficulty}")

    starter_files = list(obs["starter_code"])
    if not starter_files:
        _log("ERROR", "Episode provided no starter code; nothing to run.")
        return "\n\n".join(log)

    fname = starter_files[0]
    starter = obs["starter_code"][fname]
    # The env may randomize function names per episode, so use the name the
    # starter actually declares; fall back to the file stem.
    func_name = _detect_function_name(starter)
    func_label = func_name or fname.replace(".py", "")

    # ── Session 1: stub agent writes a valid handoff ──────────────────────────
    _log("SESSION 1", "Agent begins working on the task")

    # Step 1 — read starter file
    r = tools["read_file"](path=fname)
    _log("read_file", f"`{fname}` → {str(r.get('output', ''))[:120]}")

    # Step 2 — write partial implementation
    partial = starter.replace(
        "# TODO: implement", "# Partial implementation from Session 1\n    return []"
    )
    tools["write_file"](path=fname, content=partial)
    _log("write_file", f"Partial impl written to `{fname}`")

    # Step 3 — run tests (partial — likely fails)
    r = tools["run_tests"]()
    _log("run_tests", f"Passed: {r.get('passed', 0)}/{r.get('total', 1)}")

    # Step 4 — write handoff.
    # The task text is collapsed onto one line before it is interpolated. An
    # embedded newline would drop the template's common indentation to zero,
    # and textwrap.dedent would then leave every line indented.
    task_summary = " ".join(obs["task"].split())[:80]
    handoff = textwrap.dedent(f"""\
        TASK: {task_summary}
        COMPLETED:
        - Starter code loaded and read
        - Partial stub written (returns [])
        REMAINING:
        - Full logic implementation
        - Edge case handling (empty input, single element)
        KEY FUNCTIONS:
        - {func_label}: main function, see starter_code
        EDGE CASES:
        - Empty list must return []
        - Single element list must return as-is
        NEXT STEPS:
        1. Read {fname} to see partial stub
        2. Implement the full algorithm
        3. Run tests and fix failures
        4. Call submit()
    """)
    r = tools["write_handoff"](content=handoff)
    if r.get("error"):
        _log("ERROR", r["error"])
        return "\n\n".join(log)
    _log("write_handoff", "Handoff written. Session 2 starting.")

    # ── Session 2: cold start, parse handoff, implement, submit ──────────────
    _log("SESSION 2", "Agent starts with ONLY the handoff note")

    r = tools["parse_handoff"]()
    note = r.get("output") or ""
    _log("parse_handoff", f"Note retrieved ({len(note.split())} words)")

    # Show handoff note nicely
    _log("HANDOFF NOTE", f"\n```\n{note}\n```")

    # Read file (now allowed after parse_handoff)
    tools["read_file"](path=fname)
    _log("read_file", f"Current state of `{fname}` retrieved")

    # Write an implementation: a reference oracle for the interval-merge task,
    # otherwise a placeholder (the demo has no general solver).
    task_lower = obs["task"].lower()
    if "merge_intervals" in task_lower or "combine_ranges" in task_lower:
        impl = _merge_intervals_oracle(func_name or "merge_intervals")
        _log_msg = "Full implementation written"
    else:
        impl = partial.replace("return []", "pass  # TODO: implement in real training")
        _log_msg = "Placeholder implementation written (no demo oracle for this task)"
    tools["write_file"](path=fname, content=impl)
    _log("write_file", _log_msg)

    r = tools["run_tests"]()
    _log("run_tests", f"Passed: {r.get('passed', 0)}/{r.get('total', 1)}")

    r = tools["submit"]()
    _log("SUBMIT", f"**Reward: {r.get('reward', 0):.4f}**")
    if "breakdown" in r:
        bd = r["breakdown"]
        _log("BREAKDOWN",
             f"test={bd['test_score']:.3f}  "
             f"quality={bd['quality_score']:.3f}  "
             f"linearity={bd['linearity_score']:.3f}  "
             f"rewrite_pen={bd['rewrite_penalty']:.3f}")

    return "\n\n".join(log)


def show_plots():
    """Return paths of the evaluation plots that currently exist on disk.

    Checks a fixed list of PNG filenames inside the ``plots`` directory that
    sits next to this file. It returns the full path of each file that
    exists, in display order. Missing files are skipped without error.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    plot_root = os.path.join(here, "plots")
    expected_names = (
        "baseline_vs_trained.png",
        "reward_curve.png",
        "ablation_comparison.png",
        "difficulty_breakdown.png",
        "handoff_diff_over_epochs.png",
    )
    available = []
    for name in expected_names:
        candidate = os.path.join(plot_root, name)
        if os.path.exists(candidate):
            available.append(candidate)
    return available


# ── Gradio UI ─────────────────────────────────────────────────────────────────

# Built-in Gradio "Soft" theme with an indigo / blue / slate palette.
THEME = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# The app object; launched by the __main__ guard at the bottom of the file.
with gr.Blocks(theme=THEME, title="Cross-Session Continuity Env") as demo:

    gr.Markdown("""
    # 🧠 Cross-Session Continuity Env
    > *Can RL teach an LLM to write better notes to its future self?*

    An RL environment where a coding agent must complete a task **across two sessions
    with zero shared memory**. Session 1 writes a structured handoff note.
    Session 2 starts completely cold — only the note exists.

    **Reward** = test correctness (visible + hidden) + handoff quality + session 2 linearity
    """)

    with gr.Tabs():

        # ── Tab 1: Live Demo ──────────────────────────────────────────────────
        # run_demo returns a Markdown transcript rendered in `transcript`.
        with gr.Tab("Live Episode"):
            with gr.Row():
                difficulty = gr.Dropdown(
                    ["easy", "medium", "hard"], value="easy",
                    label="Difficulty", scale=1,
                )
                seed = gr.Slider(0, 100, value=42, step=1,
                                 label="Episode Seed (deterministic)", scale=3)
            run_btn = gr.Button("Run Episode", variant="primary")
            transcript = gr.Markdown(label="Episode Transcript")
            run_btn.click(run_demo, inputs=[difficulty, seed], outputs=transcript)

        # ── Tab 2: Training Results ────────────────────────────────────────────
        # The gallery is filled by show_plots on page load and on "Refresh Plots".
        with gr.Tab("Training Results"):
            gr.Markdown("""
            ### Evaluation Plots
            Generated from real GRPO training. If training has not run yet,
            plots are synthetic placeholders (marked **[SYNTHETIC]** in title).
            """)
            refresh_btn = gr.Button("Refresh Plots")
            gallery = gr.Gallery(label="Training Evidence", columns=2, height=600)
            refresh_btn.click(show_plots, outputs=gallery)
            demo.load(show_plots, outputs=gallery)

        # ── Tab 3: Environment Info ────────────────────────────────────────────
        # Static documentation only; no event handlers.
        with gr.Tab("Environment Info"):
            gr.Markdown("""
            ### Architecture

            ```
            Episode = Session 1 + Session 2

            Session 1:
              Agent → reads code, writes code, runs tests
              Agent → calls write_handoff(structured_note)
                            ↓ [handoff.md is the ONLY bridge]
                            ↓ [filesystem wiped]
                            ↓ [function names randomized per episode]
            Session 2:
              Agent → calls parse_handoff() first (enforced)
              Agent → picks up, finishes implementation
              Agent → calls submit() → reward computed
            ```

            ### Reward Components

            | Component | Weight | Anti-gaming |
            |-----------|--------|-------------|
            | Tests (visible) | 33% | Hidden tests at submit |
            | Tests (hidden)  | 22% | Not shown via run_tests |
            | Handoff quality | 20% | Code-dump blocked |
            | Linearity       | 15% | Thrash detection |
            | Penalties       | 10% | Rewrite + invalid action |

            ### Tools
            `read_file` · `write_file` · `run_tests` · `write_handoff` · `parse_handoff` · `submit`

            ### Difficulty
            | Level  | Step limit | Visible tests | Hidden tests |
            |--------|-----------|---------------|--------------|
            | Easy   | 20        | 3             | 1            |
            | Medium | 35        | 5             | 2            |
            | Hard   | 55        | 8             | 3            |
            """)

if __name__ == "__main__":
    # Defaults are the original hard-coded values (all interfaces, port 7860).
    # Passing explicit arguments to launch() would otherwise override Gradio's
    # own GRADIO_SERVER_NAME / GRADIO_SERVER_PORT handling, so honour them here.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )