scrapeRL / backend /app /core /episode.py
NeerajCodz's picture
feat: add core RL environment models (observation, action, reward, env)
ab65628
"""Episode state machine and management."""
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class EpisodeStatus(str, Enum):
    """Lifecycle state of an episode.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values (e.g. ``EpisodeStatus.RUNNING == "running"``).
    """

    # Created but not yet started.
    PENDING = "pending"
    # Started and accepting steps.
    RUNNING = "running"
    # Terminal: finished normally.
    COMPLETED = "completed"
    # Terminal: ended because of an error.
    FAILED = "failed"
    # Terminal: stopped early (e.g. a step limit was reached).
    TRUNCATED = "truncated"
    # Terminal: aborted on request.
    CANCELLED = "cancelled"
class EpisodeStep(BaseModel):
    """Immutable-by-convention record of one action taken in an episode.

    Instances are built by ``Episode.add_step`` and appended to the
    episode's ``steps`` history.
    """

    # Position of the step within its episode (``add_step`` numbers from 1).
    step_number: int
    # ISO-8601 time at which the step was recorded.
    timestamp: str
    # The action that was executed and the parameters it was called with.
    action_type: str
    action_params: dict[str, Any]
    # Optional free-text rationale supplied with the action.
    action_reasoning: str | None = Field(default=None)
    # Scalar reward for this step and its per-component breakdown.
    reward: float
    reward_breakdown: dict[str, float]
    # Summary of the observation associated with this step.
    observation_summary: dict[str, Any]
    # Error message recorded for the step, if any.
    error: str | None = Field(default=None)
    # Duration of the step, in milliseconds.
    duration_ms: float = Field(default=0.0)
class Episode(BaseModel):
    """
    Represents a complete episode in the RL environment.

    An episode is a sequence of steps from reset to termination,
    tracking all actions, rewards, and observations.

    An episode is created ``PENDING``, moved to ``RUNNING`` by :meth:`start`,
    and ended by :meth:`complete`, :meth:`fail`, :meth:`truncate` or
    :meth:`cancel`. These transitions are not validated here: each method
    sets its status unconditionally. Timestamps are stored as ISO-8601
    strings.
    """

    # Identification
    episode_id: str
    task_id: str

    # Timing (ISO-8601 strings; the default and the lifecycle methods write UTC)
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    started_at: str | None = None
    ended_at: str | None = None

    # State
    status: EpisodeStatus = EpisodeStatus.PENDING
    # Number of steps recorded via add_step().
    current_step: int = 0
    # NOTE(review): not enforced by this class; presumably the caller checks it
    # and calls truncate() (whose default reason is "max_steps_reached").
    max_steps: int = 50

    # Seed for reproducibility
    seed: int | None = None

    # Configuration
    config: dict[str, Any] = Field(default_factory=dict)

    # Step history
    steps: list[EpisodeStep] = Field(default_factory=list)

    # Aggregates (only total_reward is maintained by this class)
    total_reward: float = 0.0
    tokens_used: int = 0
    api_calls: int = 0
    estimated_cost_usd: float = 0.0

    # Results
    extracted_data: dict[str, Any] = Field(default_factory=dict)
    final_accuracy: float | None = None
    success: bool | None = None
    failure_reason: str | None = None

    # Navigation history
    urls_visited: list[str] = Field(default_factory=list)

    def start(self) -> None:
        """Mark the episode as running and record the UTC start time."""
        self.status = EpisodeStatus.RUNNING
        self.started_at = datetime.now(timezone.utc).isoformat()

    def add_step(
        self,
        action_type: str,
        action_params: dict[str, Any],
        reward: float,
        reward_breakdown: dict[str, float],
        observation_summary: dict[str, Any],
        action_reasoning: str | None = None,
        error: str | None = None,
        duration_ms: float = 0.0,
    ) -> EpisodeStep:
        """Record a step, advance ``current_step`` and accumulate its reward.

        Neither the status nor ``max_steps`` is checked here.

        Args:
            action_type: Name of the action taken.
            action_params: Parameters the action was invoked with.
            reward: Scalar reward for the step; added to ``total_reward``.
            reward_breakdown: Per-component reward values.
            observation_summary: Summary of the observation for this step.
            action_reasoning: Optional rationale for the action.
            error: Optional error message for the step.
            duration_ms: Duration of the step in milliseconds.

        Returns:
            The newly created :class:`EpisodeStep`, already appended to
            ``steps``.
        """
        # Increment first so the first recorded step is numbered 1.
        self.current_step += 1
        step = EpisodeStep(
            step_number=self.current_step,
            timestamp=datetime.now(timezone.utc).isoformat(),
            action_type=action_type,
            action_params=action_params,
            action_reasoning=action_reasoning,
            reward=reward,
            reward_breakdown=reward_breakdown,
            observation_summary=observation_summary,
            error=error,
            duration_ms=duration_ms,
        )
        self.steps.append(step)
        self.total_reward += reward
        return step

    def complete(
        self,
        success: bool,
        extracted_data: dict[str, Any] | None = None,
        final_accuracy: float | None = None,
    ) -> None:
        """Mark the episode as completed.

        Args:
            success: Whether the task was accomplished.
            extracted_data: Data extracted during the episode. ``None`` keeps
                the current ``extracted_data``. Any dict, including an empty
                one, replaces it.
            final_accuracy: Final accuracy score. This always overwrites the
                current value, so ``None`` clears it.
        """
        self.status = EpisodeStatus.COMPLETED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.success = success
        # Compare against None rather than truthiness: an explicitly passed
        # empty dict means "nothing extracted" and must not be ignored.
        if extracted_data is not None:
            self.extracted_data = extracted_data
        self.final_accuracy = final_accuracy

    def fail(self, reason: str) -> None:
        """Mark the episode as failed, recording ``reason`` and ``success=False``."""
        self.status = EpisodeStatus.FAILED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.success = False
        self.failure_reason = reason

    def truncate(self, reason: str = "max_steps_reached") -> None:
        """Mark the episode as truncated (stopped early).

        ``reason`` is stored in ``failure_reason``. ``success`` is left as is.
        """
        self.status = EpisodeStatus.TRUNCATED
        self.ended_at = datetime.now(timezone.utc).isoformat()
        self.failure_reason = reason

    def cancel(self) -> None:
        """Mark the episode as cancelled and record the end time."""
        self.status = EpisodeStatus.CANCELLED
        self.ended_at = datetime.now(timezone.utc).isoformat()

    @property
    def is_terminal(self) -> bool:
        """Return True if the status is completed, failed, truncated or cancelled."""
        return self.status in (
            EpisodeStatus.COMPLETED,
            EpisodeStatus.FAILED,
            EpisodeStatus.TRUNCATED,
            EpisodeStatus.CANCELLED,
        )

    @property
    def duration_seconds(self) -> float | None:
        """Get episode duration in seconds.

        Returns ``None`` if the episode has no start time. If it has not
        ended yet, the duration is measured up to the current time.
        Timestamps without an offset are interpreted as UTC.
        """
        if not self.started_at:
            return None

        def parse(stamp: str) -> datetime:
            # fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
            parsed = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
            # Treat naive values as UTC so they can be subtracted from aware
            # ones (e.g. the aware "now" used for a still-running episode).
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            return parsed

        start_dt = parse(self.started_at)
        end_dt = parse(self.ended_at) if self.ended_at else datetime.now(timezone.utc)
        return (end_dt - start_dt).total_seconds()

    @property
    def average_reward(self) -> float:
        """Get average reward per recorded step (0.0 when there are no steps)."""
        if not self.steps:
            return 0.0
        return self.total_reward / len(self.steps)

    def get_summary(self) -> dict[str, Any]:
        """Get a JSON-friendly summary of the episode's key metrics."""
        return {
            "episode_id": self.episode_id,
            "task_id": self.task_id,
            "status": self.status.value,
            "steps": self.current_step,
            "total_reward": self.total_reward,
            "average_reward": self.average_reward,
            "duration_seconds": self.duration_seconds,
            "tokens_used": self.tokens_used,
            "estimated_cost_usd": self.estimated_cost_usd,
            "success": self.success,
            "fields_extracted": len(self.extracted_data),
        }

    def get_step_history(
        self,
        start: int = 0,
        end: int | None = None,
    ) -> list[EpisodeStep]:
        """Get ``steps[start:end]``; standard list slicing rules apply."""
        return self.steps[start:end]

    def get_action_sequence(self) -> list[str]:
        """Get the sequence of action types taken, in step order."""
        return [step.action_type for step in self.steps]

    def get_reward_history(self) -> list[float]:
        """Get the sequence of per-step rewards, in step order."""
        return [step.reward for step in self.steps]
class EpisodeManager:
    """Manager for episode lifecycle.

    Holds episodes in an in-memory dict keyed by ``episode_id``.
    """

    def __init__(self) -> None:
        """Initialize the episode manager with no episodes."""
        self._episodes: dict[str, Episode] = {}

    def create_episode(
        self,
        episode_id: str,
        task_id: str,
        max_steps: int = 50,
        seed: int | None = None,
        config: dict[str, Any] | None = None,
    ) -> Episode:
        """Create a new episode and register it under ``episode_id``.

        An existing episode with the same id is replaced without warning.

        Args:
            episode_id: Unique identifier for the episode.
            task_id: Identifier of the task the episode runs.
            max_steps: Step budget stored on the episode.
            seed: Optional seed for reproducibility.
            config: Optional configuration; ``None`` becomes an empty dict.

        Returns:
            The newly created :class:`Episode`.
        """
        episode = Episode(
            episode_id=episode_id,
            task_id=task_id,
            max_steps=max_steps,
            seed=seed,
            config=config or {},
        )
        self._episodes[episode_id] = episode
        return episode

    def get_episode(self, episode_id: str) -> Episode | None:
        """Get an episode by ID, or ``None`` if it is not registered."""
        return self._episodes.get(episode_id)

    def remove_episode(self, episode_id: str) -> bool:
        """Remove an episode.

        Returns:
            True if an episode was removed, False if the id was unknown.
        """
        try:
            del self._episodes[episode_id]
        except KeyError:
            return False
        return True

    def list_episodes(
        self,
        status: EpisodeStatus | None = None,
        task_id: str | None = None,
    ) -> list[Episode]:
        """List episodes, optionally filtered by status and/or task id.

        A filter is applied only when it is not ``None``. Filters are compared
        by equality, so e.g. ``task_id=""`` matches only episodes whose task
        id is the empty string.

        Returns:
            Matching episodes in insertion order.
        """
        return [
            episode
            for episode in self._episodes.values()
            if (status is None or episode.status == status)
            and (task_id is None or episode.task_id == task_id)
        ]