Spaces:
Sleeping
Sleeping
File size: 7,888 Bytes
"""Episode state machine and management."""
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class EpisodeStatus(str, Enum):
    """Lifecycle status of an episode.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (for example ``EpisodeStatus.PENDING == "pending"``).
    """

    # Default status of a newly constructed Episode.
    PENDING = "pending"
    # Set by Episode.start().
    RUNNING = "running"
    # Set by Episode.complete().
    COMPLETED = "completed"
    # Set by Episode.fail().
    FAILED = "failed"
    # Set by Episode.truncate(); its default reason is "max_steps_reached".
    TRUNCATED = "truncated"
    # Set by Episode.cancel().
    CANCELLED = "cancelled"
class EpisodeStep(BaseModel):
    """Record of a single step in the episode.

    Instances are built by ``Episode.add_step``, which fills in
    ``step_number`` and ``timestamp`` itself and passes the remaining
    fields through from its arguments.
    """

    # Position of the step within its episode; Episode.add_step numbers from 1.
    step_number: int
    # When the step was recorded; Episode.add_step stores a UTC ISO-8601 string.
    timestamp: str
    # Name of the action that was taken.
    action_type: str
    # Arguments supplied with the action; schema is not defined in this file.
    action_params: dict[str, Any]
    # Optional free-text rationale for choosing the action.
    action_reasoning: str | None = None
    # Reward for this step; Episode.add_step adds it to Episode.total_reward.
    reward: float
    # Named reward components. NOTE(review): presumably these sum to `reward`,
    # but nothing in this file enforces that - confirm with the producer.
    reward_breakdown: dict[str, float]
    # Summary of the resulting observation; schema is not defined in this file.
    observation_summary: dict[str, Any]
    # Presumably an error message when the step failed, None otherwise - confirm.
    error: str | None = None
    # Step duration in milliseconds (unit taken from the field name).
    duration_ms: float = 0.0
class Episode(BaseModel):
    """
    One run of the RL environment for a single task.

    Holds the episode's identity, timing, status, per-step history and
    running aggregates, together with helpers that move it between
    statuses and summarise what happened. Status changes are not
    validated: every transition method simply overwrites ``status``.
    """

    # --- identity ---
    episode_id: str
    task_id: str

    # --- timing (written by this class as UTC ISO-8601 strings) ---
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    started_at: str | None = None
    ended_at: str | None = None

    # --- state ---
    status: EpisodeStatus = EpisodeStatus.PENDING
    current_step: int = 0
    max_steps: int = 50

    # --- reproducibility seed ---
    seed: int | None = None

    # --- free-form configuration ---
    config: dict[str, Any] = Field(default_factory=dict)

    # --- recorded steps, in the order they were added ---
    steps: list[EpisodeStep] = Field(default_factory=list)

    # --- running aggregates ---
    total_reward: float = 0.0
    tokens_used: int = 0
    api_calls: int = 0
    estimated_cost_usd: float = 0.0

    # --- outcome ---
    extracted_data: dict[str, Any] = Field(default_factory=dict)
    final_accuracy: float | None = None
    success: bool | None = None
    failure_reason: str | None = None

    # --- navigation history ---
    urls_visited: list[str] = Field(default_factory=list)

    def start(self) -> None:
        """Switch the episode to RUNNING and stamp ``started_at``."""
        self.status = EpisodeStatus.RUNNING
        began = datetime.now(timezone.utc)
        self.started_at = began.isoformat()

    def add_step(
        self,
        action_type: str,
        action_params: dict[str, Any],
        reward: float,
        reward_breakdown: dict[str, float],
        observation_summary: dict[str, Any],
        action_reasoning: str | None = None,
        error: str | None = None,
        duration_ms: float = 0.0,
    ) -> EpisodeStep:
        """
        Record one step and fold its reward into ``total_reward``.

        The step counter is advanced first, so the first recorded step is
        numbered 1. Returns the ``EpisodeStep`` that was appended.
        """
        # Advance the counter before building the record so the record
        # carries the new (1-based) step number.
        self.current_step += 1
        recorded_at = datetime.now(timezone.utc).isoformat()
        record = EpisodeStep(
            step_number=self.current_step,
            timestamp=recorded_at,
            action_type=action_type,
            action_params=action_params,
            action_reasoning=action_reasoning,
            reward=reward,
            reward_breakdown=reward_breakdown,
            observation_summary=observation_summary,
            error=error,
            duration_ms=duration_ms,
        )
        self.steps.append(record)
        self.total_reward += reward
        return record

    def complete(
        self,
        success: bool,
        extracted_data: dict[str, Any] | None = None,
        final_accuracy: float | None = None,
    ) -> None:
        """
        Switch the episode to COMPLETED and store its outcome.

        ``extracted_data`` replaces the stored data only when it is
        non-empty; ``final_accuracy`` is always overwritten.
        """
        self.status = EpisodeStatus.COMPLETED
        finished = datetime.now(timezone.utc)
        self.ended_at = finished.isoformat()
        self.success = success
        # A missing or empty mapping leaves previously stored data untouched.
        if extracted_data:
            self.extracted_data = extracted_data
        self.final_accuracy = final_accuracy

    def fail(self, reason: str) -> None:
        """Switch the episode to FAILED, recording ``reason`` and ``success=False``."""
        self.status = EpisodeStatus.FAILED
        finished = datetime.now(timezone.utc)
        self.ended_at = finished.isoformat()
        self.success = False
        self.failure_reason = reason

    def truncate(self, reason: str = "max_steps_reached") -> None:
        """Switch the episode to TRUNCATED and record why it was cut short."""
        self.status = EpisodeStatus.TRUNCATED
        finished = datetime.now(timezone.utc)
        self.ended_at = finished.isoformat()
        self.failure_reason = reason

    def cancel(self) -> None:
        """Switch the episode to CANCELLED and stamp ``ended_at``."""
        self.status = EpisodeStatus.CANCELLED
        finished = datetime.now(timezone.utc)
        self.ended_at = finished.isoformat()

    @property
    def is_terminal(self) -> bool:
        """Whether the status is COMPLETED, FAILED, TRUNCATED or CANCELLED."""
        terminal_statuses = (
            EpisodeStatus.COMPLETED,
            EpisodeStatus.FAILED,
            EpisodeStatus.TRUNCATED,
            EpisodeStatus.CANCELLED,
        )
        return self.status in terminal_statuses

    @property
    def duration_seconds(self) -> float | None:
        """
        Seconds between start and end, or ``None`` if never started.

        While ``ended_at`` is unset the current UTC time is used as the end.
        """
        began_text = self.started_at
        if not began_text:
            return None
        ended_text = self.ended_at
        if not ended_text:
            ended_text = datetime.now(timezone.utc).isoformat()
        # A trailing "Z" is rewritten to an explicit UTC offset before parsing.
        began_dt = datetime.fromisoformat(began_text.replace("Z", "+00:00"))
        ended_dt = datetime.fromisoformat(ended_text.replace("Z", "+00:00"))
        elapsed = ended_dt - began_dt
        return elapsed.total_seconds()

    @property
    def average_reward(self) -> float:
        """Mean reward per recorded step; ``0.0`` when there are no steps."""
        return self.total_reward / len(self.steps) if self.steps else 0.0

    def get_summary(self) -> dict[str, Any]:
        """Return a flat dictionary of headline figures for the episode."""
        summary: dict[str, Any] = {
            "episode_id": self.episode_id,
            "task_id": self.task_id,
            "status": self.status.value,
            "steps": self.current_step,
            "total_reward": self.total_reward,
            "average_reward": self.average_reward,
            "duration_seconds": self.duration_seconds,
            "tokens_used": self.tokens_used,
            "estimated_cost_usd": self.estimated_cost_usd,
            "success": self.success,
            "fields_extracted": len(self.extracted_data),
        }
        return summary

    def get_step_history(
        self,
        start: int = 0,
        end: int | None = None,
    ) -> list[EpisodeStep]:
        """Return the recorded steps in ``[start:end]`` using list-slice rules."""
        window = slice(start, end)
        return self.steps[window]

    def get_action_sequence(self) -> list[str]:
        """Return the action type of every recorded step, in order."""
        return [record.action_type for record in self.steps]

    def get_reward_history(self) -> list[float]:
        """Return the reward of every recorded step, in order."""
        return [record.reward for record in self.steps]
class EpisodeManager:
    """In-memory registry of episodes, keyed by episode ID."""

    def __init__(self) -> None:
        """Create a manager that holds no episodes."""
        self._episodes: dict[str, Episode] = {}

    def create_episode(
        self,
        episode_id: str,
        task_id: str,
        max_steps: int = 50,
        seed: int | None = None,
        config: dict[str, Any] | None = None,
    ) -> Episode:
        """
        Build a new episode, register it and return it.

        If an episode is already registered under ``episode_id`` it is
        replaced by the new one. A missing ``config`` becomes an empty dict.
        """
        episode = Episode(
            episode_id=episode_id,
            task_id=task_id,
            max_steps=max_steps,
            seed=seed,
            config=config or {},
        )
        self._episodes[episode_id] = episode
        return episode

    def get_episode(self, episode_id: str) -> Episode | None:
        """Return the episode registered under ``episode_id``, or ``None``."""
        return self._episodes.get(episode_id)

    def remove_episode(self, episode_id: str) -> bool:
        """
        Unregister an episode.

        Returns ``True`` if an episode was removed and ``False`` if no
        episode was registered under ``episode_id``.
        """
        try:
            del self._episodes[episode_id]
        except KeyError:
            return False
        return True

    def list_episodes(
        self,
        status: EpisodeStatus | None = None,
        task_id: str | None = None,
    ) -> list[Episode]:
        """
        List registered episodes, optionally filtered.

        Args:
            status: Keep only episodes with this status; ``None`` disables
                the filter.
            task_id: Keep only episodes for this task; ``None`` disables
                the filter.

        Returns:
            Matching episodes in registration order. When both filters are
            given an episode must satisfy both.
        """
        # Compare against None rather than relying on truthiness, so a falsy
        # but valid value such as task_id="" still acts as a filter.
        return [
            episode
            for episode in self._episodes.values()
            if (status is None or episode.status == status)
            and (task_id is None or episode.task_id == task_id)
        ]
|