Spaces:
Sleeping
Sleeping
File size: 9,509 Bytes
d1221ff f8a321a d1221ff 999c3ec d1221ff f8a321a d1221ff f8a321a d1221ff 999c3ec d1221ff f8a321a d1221ff 999c3ec d1221ff 999c3ec d1221ff 999c3ec d1221ff 999c3ec d1221ff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """Main SummarizationEnvironment — implements the OpenEnv Environment interface.
Episode flow per task:
easy (2 steps): truncated_context → summarize → question → answer
medium (2 steps): longer truncated_context → summarize → question → answer
hard (3 steps): chunk1 → summarize → chunk2 → update_summary → question → answer
Reward: token-level F1 score, with a small conciseness bonus for compact summaries.
"""
import logging
import os
import random
import sys
import uuid
from typing import Optional, List, Dict, Any

# Allow imports from project root when running from server/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from openenv.core.env_server import Environment
from models import SummarizationAction, SummarizationObservation, SummarizationState
from tasks import get_task
from tasks.hard import HardTask
from graders import compute_reward
# Module-level logger, named after this module so it inherits the standard
# logging configuration of its package hierarchy.
logger = logging.getLogger(__name__)
class SummarizationEnvironment(Environment):
    """RL environment for evaluating long-context summarization.

    The agent must condense a truncated document into a summary, then use
    that summary to answer a question about the original content. The reward
    signal trains the model to write summaries that preserve answer-critical
    information.

    Step types advance ``summarize -> [update_summary] -> answer -> done``;
    ``update_summary`` only occurs on the hard task when the sample provides
    a second chunk.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Valid task names; their order also defines the seed -> task mapping.
    _TASK_NAMES = ("easy", "medium", "hard")

    # Step types during which the agent is still expected to act.
    _ACTIVE_STEP_TYPES = ("summarize", "update_summary", "answer")

    def __init__(self):
        logger.info("Initialising SummarizationEnvironment...")
        # Task objects are created lazily by _get_task and cached by name.
        self._tasks: Dict[str, Any] = {}
        self._reset_episode_state()
        logger.info("Environment ready.")

    # ------------------------------------------------------------------
    # Internal episode state
    # ------------------------------------------------------------------

    def _reset_episode_state(self):
        """Restore every per-episode field to its pre-``reset`` default."""
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._task_name: str = "easy"
        # One of "summarize", "update_summary", "answer", "done".
        self._step_type: str = "summarize"
        # Chat transcript (role/content dicts) returned in every observation.
        self._messages: List[Dict[str, str]] = []
        # Reference answers handed to compute_reward on the answer step.
        self._ground_truth_list: List[str] = []
        # Latest summary written by the agent (overwritten by update_summary).
        self._summary: Optional[str] = None
        self._question: Optional[str] = None
        # len() of the sample's full "context" field.
        self._context_length: int = 0
        self._truncation_ratio: float = 0.7
        self._category: Optional[str] = None
        self._source_type: Optional[str] = None
        # Hard task only: second chunk shown after the first summary.
        self._hard_chunk2: Optional[str] = None

    def _get_task(self, task_name: str):
        """Return the task for *task_name*, creating and caching it on first use.

        Lazy initialisation keeps app startup fast on Spaces.
        """
        task = self._tasks.get(task_name)
        if task is None:
            logger.info("Loading task '%s'...", task_name)
            task = get_task(task_name)
            self._tasks[task_name] = task
        return task

    # ------------------------------------------------------------------
    # OpenEnv API
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs,
    ) -> SummarizationObservation:
        """Start a new episode and return the first (summarize) observation.

        Task selection priority:
          1. ``task_name`` kwarg (passed as extra field in ResetRequest)
          2. ``seed`` — ``seed % 3`` maps to easy/medium/hard
          3. random choice

        Args:
            seed: Optional seed. Selects the task when ``task_name`` is
                omitted and is forwarded to ``task.get_sample``.
            episode_id: Optional caller-supplied id. A unique ``ep_<hex>`` id
                is generated when it is omitted or empty.
            task_name: Optional explicit task name.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Observation containing the system prompt and the first
            summarize prompt, with ``done=False`` and ``reward=None``.
        """
        self._reset_episode_state()

        if task_name is None:
            if seed is not None:
                task_name = self._TASK_NAMES[seed % len(self._TASK_NAMES)]
            else:
                task_name = random.choice(self._TASK_NAMES)

        # Load the task and draw a sample before committing episode state, so
        # a failing lookup does not leave a half-initialised episode behind.
        task = self._get_task(task_name)
        sample = task.get_sample(seed=seed)

        self._task_name = task_name
        # uuid4 rather than a 5-digit random int: 90k possible ids collide
        # quickly over long training runs and also consumed the global RNG.
        self._episode_id = episode_id or f"ep_{uuid.uuid4().hex}"

        self._question = sample["question"]
        self._ground_truth_list = sample["answer_list"]
        self._context_length = len(sample["context"])
        self._truncation_ratio = sample["truncation_ratio"]
        self._category = sample.get("category")
        self._source_type = sample.get("source_type")

        # Hard task: keep the second chunk for the update_summary step.
        if task_name == "hard" and "chunk2" in sample:
            self._hard_chunk2 = sample["chunk2"]
            first_chunk = sample["chunk1"]
        else:
            self._hard_chunk2 = None
            first_chunk = sample["truncated_context"]

        self._messages = [
            {"role": "system", "content": task.get_system_prompt()},
            {
                "role": "user",
                "content": task.get_summarize_prompt(
                    first_chunk, self._truncation_ratio
                ),
            },
        ]
        self._step_type = "summarize"
        return self._make_observation(done=False, reward=None)

    def step(self, action: SummarizationAction) -> SummarizationObservation:
        """Process one agent action and return the next observation.

        Intermediate steps (summarize, update_summary) return ``reward=None``;
        only the answer step is scored via ``compute_reward``. Stepping an
        episode that has already finished returns ``done=True, reward=0.0``
        without touching the transcript or the step count.

        Raises:
            TypeError: if the hard task's object is not a ``HardTask`` when a
                second chunk must be presented.
        """
        if self._step_type not in self._ACTIVE_STEP_TYPES:
            # Episode already done: do not record the extra action.
            return self._make_observation(done=True, reward=0.0)

        self._step_count += 1
        response = (action.response or "").strip()
        self._messages.append({"role": "assistant", "content": response})
        task = self._get_task(self._task_name)

        # --- Summarize step ---------------------------------------------
        if self._step_type == "summarize":
            self._summary = response
            if self._task_name == "hard" and self._hard_chunk2 is not None:
                # Explicit check instead of `assert`, which `python -O` strips.
                if not isinstance(task, HardTask):
                    raise TypeError(
                        "task 'hard' must be a HardTask, got "
                        f"{type(task).__name__}"
                    )
                self._messages.append(
                    {
                        "role": "user",
                        "content": task.get_update_summary_prompt(
                            self._hard_chunk2
                        ),
                    }
                )
                self._step_type = "update_summary"
                self._hard_chunk2 = None  # consumed
                return self._make_observation(done=False, reward=None)
            # Easy / medium: go straight to the answer step.
            self._ask_question(task)
            return self._make_observation(done=False, reward=None)

        # --- Update-summary step (hard task only) -----------------------
        if self._step_type == "update_summary":
            self._summary = response  # updated combined summary
            self._ask_question(task)
            return self._make_observation(done=False, reward=None)

        # --- Answer step ------------------------------------------------
        reward = compute_reward(
            predicted=response,
            ground_truth_list=self._ground_truth_list,
            summary=self._summary,
            task_name=self._task_name,
            question=self._question,
        )
        self._step_type = "done"
        return self._make_observation(done=True, reward=reward)

    @property
    def state(self) -> SummarizationState:
        """Snapshot of the current episode's bookkeeping fields."""
        return SummarizationState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task_name,
            step_type=self._step_type,
            context_length=self._context_length,
            question=self._question,
            category=self._category,
            source_type=self._source_type,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _ask_question(self, task) -> None:
        """Switch to the answer step and append the question prompt."""
        self._step_type = "answer"
        self._messages.append(
            {"role": "user", "content": task.get_answer_prompt(self._question)}
        )

    def _make_observation(
        self, done: bool, reward: Optional[float]
    ) -> SummarizationObservation:
        """Build an observation from the current episode state.

        The transcript is copied so callers cannot mutate internal history.
        """
        return SummarizationObservation(
            done=done,
            reward=reward,
            messages=list(self._messages),  # copy
            step_type=self._step_type,
            task_name=self._task_name,
            context_length=self._context_length,
            truncation_ratio=self._truncation_ratio,
            category=self._category,
            source_type=self._source_type,
        )

    def metadata(self) -> Dict[str, Any]:
        """Return static descriptive metadata about this environment."""
        return {
            "name": "Long-Context Summarization",
            "description": (
                "An RL environment that trains models to compress long documents into "
                "compact summaries, evaluated by their ability to answer questions from "
                "those summaries. Inspired by Cursor's self-summarization approach."
            ),
            "version": "1.0.0",
            "tasks": list(self._TASK_NAMES),
            "action_space": "Text (summary or answer)",
            "reward_range": [0.0, 1.0],
            "content_metadata": ["category", "source_type"],
        }
|