File size: 7,888 Bytes
ab65628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Episode state machine and management."""

from datetime import datetime, timezone
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class EpisodeStatus(str, Enum):
    """Lifecycle state of an episode.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``EpisodeStatus.RUNNING == "running"``),
    which keeps statuses easy to serialize and compare against raw strings.
    """

    # Non-terminal states: created but not started, then in progress.
    PENDING = "pending"
    RUNNING = "running"

    # Terminal states: once reached, the episode has ended.
    COMPLETED = "completed"
    FAILED = "failed"
    TRUNCATED = "truncated"
    CANCELLED = "cancelled"


class EpisodeStep(BaseModel):
    """Immutable-by-convention record of one action taken within an episode.

    Captures what was done (action type, parameters, optional reasoning),
    what it earned (scalar reward plus a named breakdown), what was observed,
    and how long it took.
    """

    # Sequence number of this step within its episode.
    step_number: int
    # When the step was recorded, as an ISO-8601 string.
    timestamp: str

    # The action itself.
    action_type: str
    action_params: dict[str, Any]
    action_reasoning: str | None = Field(default=None)

    # Reward signal: total value plus its per-component contributions.
    reward: float
    reward_breakdown: dict[str, float]

    # Condensed view of the observation produced by this step.
    observation_summary: dict[str, Any]

    # Error message if the action failed, otherwise None.
    error: str | None = Field(default=None)
    # Wall-clock duration of the step in milliseconds.
    duration_ms: float = Field(default=0.0)


class Episode(BaseModel):
    """
    Represents a complete episode in the RL environment.

    An episode is a sequence of steps from reset to termination,
    tracking all actions, rewards, and observations.

    Lifecycle: ``start()`` sets the status to RUNNING; ``complete()``,
    ``fail()``, ``truncate()`` and ``cancel()`` set the matching terminal
    status and stamp ``ended_at``. These methods do not validate the current
    status, and ``max_steps`` is stored but not enforced by ``add_step()``.
    Timestamps written by this class are ISO-8601 strings in UTC.
    """

    # Identification
    episode_id: str
    task_id: str

    # Timing
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    started_at: str | None = None
    ended_at: str | None = None

    # State
    status: EpisodeStatus = EpisodeStatus.PENDING
    current_step: int = 0
    max_steps: int = 50

    # Seed for reproducibility
    seed: int | None = None

    # Configuration
    config: dict[str, Any] = Field(default_factory=dict)

    # Step history
    steps: list[EpisodeStep] = Field(default_factory=list)

    # Aggregates
    total_reward: float = 0.0
    tokens_used: int = 0
    api_calls: int = 0
    estimated_cost_usd: float = 0.0

    # Results
    extracted_data: dict[str, Any] = Field(default_factory=dict)
    final_accuracy: float | None = None
    success: bool | None = None
    failure_reason: str | None = None

    # Navigation history
    urls_visited: list[str] = Field(default_factory=list)

    def start(self) -> None:
        """Set the status to RUNNING and record the UTC start time.

        Calling this again overwrites ``started_at``; the previous status is
        not checked.
        """
        self.status = EpisodeStatus.RUNNING
        self.started_at = datetime.now(timezone.utc).isoformat()

    def add_step(
        self,
        action_type: str,
        action_params: dict[str, Any],
        reward: float,
        reward_breakdown: dict[str, float],
        observation_summary: dict[str, Any],
        action_reasoning: str | None = None,
        error: str | None = None,
        duration_ms: float = 0.0,
    ) -> EpisodeStep:
        """Record one step, append it to ``steps`` and accumulate its reward.

        Args:
            action_type: Name of the action taken.
            action_params: Parameters the action was invoked with.
            reward: Scalar reward for this step; added to ``total_reward``.
            reward_breakdown: Named components of ``reward``.
            observation_summary: Condensed observation after the action.
            action_reasoning: Optional explanation for the action.
            error: Optional error message if the action failed.
            duration_ms: Wall-clock time the step took, in milliseconds.

        Returns:
            The newly created ``EpisodeStep``. Its ``step_number`` is the
            incremented ``current_step``.
        """
        self.current_step += 1

        step = EpisodeStep(
            step_number=self.current_step,
            timestamp=datetime.now(timezone.utc).isoformat(),
            action_type=action_type,
            action_params=action_params,
            action_reasoning=action_reasoning,
            reward=reward,
            reward_breakdown=reward_breakdown,
            observation_summary=observation_summary,
            error=error,
            duration_ms=duration_ms,
        )

        self.steps.append(step)
        self.total_reward += reward

        return step

    def complete(
        self,
        success: bool,
        extracted_data: dict[str, Any] | None = None,
        final_accuracy: float | None = None,
    ) -> None:
        """Mark the episode as COMPLETED and record its results.

        Args:
            success: Whether the task was accomplished.
            extracted_data: Replaces ``self.extracted_data`` when given
                (including an explicit empty dict); ``None`` keeps the
                currently stored data.
            final_accuracy: Stored as-is, so ``None`` clears any prior value.
        """
        self.status = EpisodeStatus.COMPLETED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.success = success
        # Compare against None rather than truthiness so that an explicitly
        # supplied empty result is not silently ignored.
        if extracted_data is not None:
            self.extracted_data = extracted_data
        self.final_accuracy = final_accuracy

    def fail(self, reason: str) -> None:
        """Mark the episode as FAILED, set ``success`` to False and store *reason*."""
        self.status = EpisodeStatus.FAILED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.success = False
        self.failure_reason = reason

    def truncate(self, reason: str = "max_steps_reached") -> None:
        """Mark the episode as TRUNCATED (stopped early) and store *reason*.

        ``success`` is left unchanged.
        """
        self.status = EpisodeStatus.TRUNCATED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.failure_reason = reason

    def cancel(self) -> None:
        """Mark the episode as CANCELLED and record the end time."""
        self.status = EpisodeStatus.CANCELLED
        self.ended_at = datetime.now(timezone.utc).isoformat()

    @property
    def is_terminal(self) -> bool:
        """True once the episode is COMPLETED, FAILED, TRUNCATED or CANCELLED."""
        return self.status in (
            EpisodeStatus.COMPLETED,
            EpisodeStatus.FAILED,
            EpisodeStatus.TRUNCATED,
            EpisodeStatus.CANCELLED,
        )

    @property
    def duration_seconds(self) -> float | None:
        """Elapsed seconds since ``started_at``.

        Measures up to ``ended_at`` if set, otherwise up to the current UTC
        time. Returns None if the episode has not been started.
        """
        if not self.started_at:
            return None

        def _parse_utc(stamp: str) -> datetime:
            # fromisoformat() before Python 3.11 rejects a trailing "Z".
            parsed = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
            # Treat naive timestamps (e.g. set externally) as UTC so they can
            # be subtracted from aware ones instead of raising TypeError.
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            return parsed

        start_dt = _parse_utc(self.started_at)
        end_dt = _parse_utc(self.ended_at) if self.ended_at else datetime.now(timezone.utc)
        return (end_dt - start_dt).total_seconds()

    @property
    def average_reward(self) -> float:
        """Mean reward per recorded step, or 0.0 when no steps exist."""
        if not self.steps:
            return 0.0
        return self.total_reward / len(self.steps)

    def get_summary(self) -> dict[str, Any]:
        """Return a flat, JSON-friendly summary of the episode's key metrics."""
        return {
            "episode_id": self.episode_id,
            "task_id": self.task_id,
            "status": self.status.value,
            "steps": self.current_step,
            "total_reward": self.total_reward,
            "average_reward": self.average_reward,
            "duration_seconds": self.duration_seconds,
            "tokens_used": self.tokens_used,
            "estimated_cost_usd": self.estimated_cost_usd,
            "success": self.success,
            "fields_extracted": len(self.extracted_data),
        }

    def get_step_history(
        self,
        start: int = 0,
        end: int | None = None,
    ) -> list[EpisodeStep]:
        """Return ``steps[start:end]`` using normal Python slice semantics."""
        return self.steps[start:end]

    def get_action_sequence(self) -> list[str]:
        """Return the action types of all steps, in order."""
        return [step.action_type for step in self.steps]

    def get_reward_history(self) -> list[float]:
        """Return the reward of each step, in order."""
        return [step.reward for step in self.steps]


class EpisodeManager:
    """In-memory registry that creates, looks up, lists and removes episodes by ID."""

    def __init__(self) -> None:
        """Initialize an empty episode registry."""
        # Keyed by episode_id; iteration follows dict insertion order.
        self._episodes: dict[str, Episode] = {}

    def create_episode(
        self,
        episode_id: str,
        task_id: str,
        max_steps: int = 50,
        seed: int | None = None,
        config: dict[str, Any] | None = None,
    ) -> Episode:
        """Create, register and return a new PENDING episode.

        An existing episode registered under the same ``episode_id`` is
        replaced. A falsy ``config`` is stored as an empty dict.
        """
        episode = Episode(
            episode_id=episode_id,
            task_id=task_id,
            max_steps=max_steps,
            seed=seed,
            config=config or {},
        )
        self._episodes[episode_id] = episode
        return episode

    def get_episode(self, episode_id: str) -> Episode | None:
        """Return the episode registered under *episode_id*, or None."""
        return self._episodes.get(episode_id)

    def remove_episode(self, episode_id: str) -> bool:
        """Unregister an episode.

        Returns:
            True if an episode was removed, False if the ID was unknown.
        """
        try:
            del self._episodes[episode_id]
        except KeyError:
            return False
        return True

    def list_episodes(
        self,
        status: EpisodeStatus | None = None,
        task_id: str | None = None,
    ) -> list[Episode]:
        """List registered episodes, optionally filtered.

        Args:
            status: If not None, keep only episodes with this status.
            task_id: If not None, keep only episodes for this task. An empty
                string is a real filter value, not "no filter".

        Returns:
            Matching episodes in registry order.
        """
        return [
            episode
            for episode in self._episodes.values()
            if (status is None or episode.status == status)
            and (task_id is None or episode.task_id == task_id)
        ]