File size: 9,509 Bytes
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a321a
 
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999c3ec
 
d1221ff
 
 
f8a321a
 
 
 
 
 
 
 
 
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a321a
d1221ff
 
 
 
 
 
 
999c3ec
 
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a321a
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999c3ec
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999c3ec
 
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999c3ec
 
d1221ff
 
 
 
 
 
 
 
 
 
 
 
 
 
999c3ec
d1221ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""Main SummarizationEnvironment — implements the OpenEnv Environment interface.

Episode flow per task:
  easy   (2 steps): truncated_context → summarize → question → answer
  medium (2 steps): longer truncated_context → summarize → question → answer
  hard   (3 steps): chunk1 → summarize → chunk2 → update_summary → question → answer

Reward: token-level F1 score, with a small conciseness bonus for compact summaries.
"""
import random
import sys
import os
import logging
from typing import Optional, List, Dict, Any

# Allow imports from project root when running from server/.
# The parent of this file's directory is prepended to sys.path (position 0,
# so it takes precedence over other entries) before the project-local
# imports below are resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from openenv.core.env_server import Environment

from models import SummarizationAction, SummarizationObservation, SummarizationState
from tasks import get_task
from tasks.hard import HardTask
from graders import compute_reward

logger = logging.getLogger(__name__)


class SummarizationEnvironment(Environment):
    """RL environment for evaluating long-context summarization.

    The agent must condense a truncated document into a summary, then use
    that summary to answer a question about the original content. The reward
    signal trains the model to write summaries that preserve answer-critical
    information.

    During an episode ``step_type`` moves through ``"summarize"``, optionally
    ``"update_summary"`` (hard task with a second chunk), ``"answer"`` and
    finally ``"done"``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Known task names. The order defines the ``seed % len(TASK_NAMES)``
    # mapping used by ``reset`` when no explicit task name is given.
    TASK_NAMES = ("easy", "medium", "hard")

    def __init__(self):
        """Create the environment with an empty task cache and blank episode state."""
        # NOTE(review): assumes the Environment base constructor takes no
        # required arguments — confirm against the installed OpenEnv version.
        super().__init__()
        logger.info("Initialising SummarizationEnvironment...")
        # Cache of task objects keyed by task name, filled lazily by _get_task.
        self._tasks: Dict[str, Any] = {}
        self._reset_episode_state()
        logger.info("Environment ready.")

    # ------------------------------------------------------------------
    # Internal episode state
    # ------------------------------------------------------------------

    def _reset_episode_state(self):
        """Reset every per-episode attribute to its initial value.

        The task cache (``self._tasks``) is deliberately left untouched so
        already-loaded tasks are reused across episodes.
        """
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._task_name: str = "easy"
        self._step_type: str = "summarize"
        self._messages: List[Dict[str, str]] = []
        self._ground_truth_list: List[str] = []
        self._summary: Optional[str] = None
        self._question: Optional[str] = None
        self._context_length: int = 0
        self._truncation_ratio: float = 0.7
        self._category: Optional[str] = None
        self._source_type: Optional[str] = None
        # Hard task only: second chunk shown after first summary
        self._hard_chunk2: Optional[str] = None

    def _get_task(self, task_name: str):
        """Lazily initialize tasks so app startup stays fast on Spaces.

        Args:
            task_name: Name passed through to ``get_task`` on a cache miss.

        Returns:
            The cached task object for ``task_name``.
        """
        task = self._tasks.get(task_name)
        if task is None:
            logger.info("Loading task '%s'...", task_name)
            task = get_task(task_name)
            self._tasks[task_name] = task
        return task

    # ------------------------------------------------------------------
    # OpenEnv API
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs,
    ) -> SummarizationObservation:
        """Start a new episode.

        Task selection priority:
          1. ``task_name`` kwarg (passed as extra field in ResetRequest)
          2. ``seed`` — seed % 3 maps to easy/medium/hard
          3. random choice

        Args:
            seed: Optional seed; selects the task when ``task_name`` is not
                given and is forwarded to ``task.get_sample``.
            episode_id: Optional identifier; a random ``ep_NNNNN`` id is
                generated when it is missing or empty.
            task_name: Optional explicit task name.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The first observation (``done=False``, ``reward=None``) holding
            the system prompt and the summarize prompt.
        """
        self._reset_episode_state()

        # Determine task
        if task_name is None:
            if seed is not None:
                task_name = self.TASK_NAMES[seed % len(self.TASK_NAMES)]
            else:
                task_name = random.choice(self.TASK_NAMES)

        self._task_name = task_name
        self._episode_id = episode_id or f"ep_{random.randint(10000, 99999)}"

        task = self._get_task(task_name)
        sample = task.get_sample(seed=seed)

        # Store episode data
        self._question = sample["question"]
        self._ground_truth_list = sample["answer_list"]
        self._context_length = len(sample["context"])
        self._truncation_ratio = sample["truncation_ratio"]
        self._category = sample.get("category")
        self._source_type = sample.get("source_type")

        # Hard task: store second chunk for step 2
        if task_name == "hard" and "chunk2" in sample:
            self._hard_chunk2 = sample["chunk2"]
            first_chunk = sample["chunk1"]
        else:
            self._hard_chunk2 = None
            first_chunk = sample["truncated_context"]

        # Build initial conversation
        system_msg = {"role": "system", "content": task.get_system_prompt()}
        user_msg = {
            "role": "user",
            "content": task.get_summarize_prompt(first_chunk, self._truncation_ratio),
        }
        self._messages = [system_msg, user_msg]
        self._step_type = "summarize"

        return self._make_observation(done=False, reward=None)

    def step(self, action: SummarizationAction) -> SummarizationObservation:
        """Process one agent action and return the next observation.

        Args:
            action: Agent action; ``action.response`` is stripped and used as
                the summary or the answer depending on the current step type.

        Returns:
            A non-terminal observation (``reward=None``) after a summarize or
            update-summary step, or a terminal observation carrying the
            reward from ``compute_reward`` after the answer step. Calling
            ``step`` again once the episode is done returns a terminal
            observation with ``reward=0.0`` and leaves the episode state
            untouched.
        """
        # Episode already finished: answer with the terminal observation
        # without growing the transcript or the step counter.
        if self._step_type == "done":
            return self._make_observation(done=True, reward=0.0)

        self._step_count += 1
        response = action.response.strip()

        # Append model response to conversation history
        self._messages.append({"role": "assistant", "content": response})

        task = self._get_task(self._task_name)

        # -- Summarize step ------------------------------------------------
        if self._step_type == "summarize":
            self._summary = response

            if self._task_name == "hard" and self._hard_chunk2 is not None:
                # Hard task: move to update_summary step with second chunk
                assert isinstance(task, HardTask)
                next_msg = {
                    "role": "user",
                    "content": task.get_update_summary_prompt(self._hard_chunk2),
                }
                self._messages.append(next_msg)
                self._step_type = "update_summary"
                self._hard_chunk2 = None  # consumed
                return self._make_observation(done=False, reward=None)

            # Easy / medium: move directly to answer step
            self._step_type = "answer"
            self._messages.append(
                {"role": "user", "content": task.get_answer_prompt(self._question)}
            )
            return self._make_observation(done=False, reward=None)

        # -- Update-summary step (hard task only) --------------------------
        if self._step_type == "update_summary":
            self._summary = response  # updated combined summary
            self._step_type = "answer"
            assert isinstance(task, HardTask)
            self._messages.append(
                {"role": "user", "content": task.get_answer_prompt(self._question)}
            )
            return self._make_observation(done=False, reward=None)

        # -- Answer step ---------------------------------------------------
        if self._step_type == "answer":
            reward = compute_reward(
                predicted=response,
                ground_truth_list=self._ground_truth_list,
                summary=self._summary,
                task_name=self._task_name,
                question=self._question,
            )
            self._step_type = "done"
            return self._make_observation(done=True, reward=reward)

        # Fallback: unrecognised step type, treat the episode as finished
        return self._make_observation(done=True, reward=0.0)

    @property
    def state(self) -> SummarizationState:
        """Return a ``SummarizationState`` snapshot of the current episode."""
        return SummarizationState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task_name,
            step_type=self._step_type,
            context_length=self._context_length,
            question=self._question,
            category=self._category,
            source_type=self._source_type,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_observation(
        self, done: bool, reward: Optional[float]
    ) -> SummarizationObservation:
        """Build an observation from the current episode attributes.

        Args:
            done: Whether the episode has terminated.
            reward: Reward to report, or ``None`` for intermediate steps.

        Returns:
            An observation holding a shallow copy of the message history, so
            later appends to the history do not change the list handed out.
        """
        return SummarizationObservation(
            done=done,
            reward=reward,
            messages=list(self._messages),  # copy
            step_type=self._step_type,
            task_name=self._task_name,
            context_length=self._context_length,
            truncation_ratio=self._truncation_ratio,
            category=self._category,
            source_type=self._source_type,
        )

    def metadata(self) -> Dict[str, Any]:
        """Return a static description of the environment as a plain dict."""
        return {
            "name": "Long-Context Summarization",
            "description": (
                "An RL environment that trains models to compress long documents into "
                "compact summaries, evaluated by their ability to answer questions from "
                "those summaries. Inspired by Cursor's self-summarization approach."
            ),
            "version": "1.0.0",
            "tasks": list(self.TASK_NAMES),
            "action_space": "Text (summary or answer)",
            "reward_range": [0.0, 1.0],
            "content_metadata": ["category", "source_type"],
        }