# Spaces: Runtime error
import time
from dataclasses import dataclass
from typing import Dict, Any, Optional, List
from uuid import uuid4

from pydantic import BaseModel, Field
@dataclass
class State:
    """Mutable per-episode bookkeeping.

    The ``@dataclass`` decorator generates the keyword/positional
    ``__init__`` that callers rely on (``State(episode_id=..., step_count=0)``);
    without it, construction raises ``TypeError``.

    Attributes:
        episode_id: Identifier of the current episode.
        step_count: Number of steps taken so far in the current episode.
    """

    episode_id: str
    step_count: int
class Environment:
    """Minimal environment interface exposing ``state``, ``reset`` and ``step``.

    Every method here raises ``NotImplementedError``; concrete environments
    are expected to override all three.
    """

    # Capability flag advertised on the class; its consumer is not visible here.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def state(self) -> State:
        """Return the current episode state. Must be overridden."""
        raise NotImplementedError

    def reset(self):
        """Begin a new episode. Must be overridden."""
        raise NotImplementedError

    def step(self, action):
        """Apply ``action`` to the environment. Must be overridden."""
        raise NotImplementedError
# Result object returned by reset()/step().
#   done        -- True once the episode has ended.
#   reward      -- scalar reward for the transition.
#   observation -- environment-specific payload.
#   metadata    -- auxiliary information (e.g. the step number).
# Documented with comments rather than a docstring so the model's generated
# schema/description stays exactly as before.
class Observation(BaseModel):
    done: bool = Field(default=False)
    reward: float = Field(default=0.0)
    # default_factory gives every instance its own fresh dict.
    observation: Dict[str, Any] = Field(default_factory=dict)
    metadata: Dict[str, Any] = Field(default_factory=dict)
class BaseWorkflowEnvironment(Environment):
    """Shared scaffolding for step-limited workflow environments.

    Tracks the episode id, step counter, action history and a per-episode
    ``task_state`` dict. Rejects non-dict actions and enforces ``max_steps``.
    Subclasses implement :meth:`_execute_action`.

    Each ``history`` entry is a dict with keys ``step``, ``action``,
    ``episode_id`` and ``timestamp`` (seconds since the epoch, from
    ``time.time()``).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, seed: Optional[int] = None):
        """Initialise the environment with a fresh episode.

        Args:
            seed: Stored as ``self.seed``. This base class does not use it
                itself; presumably it is consumed by subclasses.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.seed = seed
        self.history: List[Dict[str, Any]] = []
        self.max_steps: int = 20
        self.task_state: Dict[str, Any] = {}
        # Number of reset() calls so far, reported as metadata["reset_count"].
        self._reset_count: int = 0

    def reset(self) -> Observation:
        """Start a new episode.

        Assigns a new episode id, zeroes the step counter and clears
        ``history`` and ``task_state``.

        Returns:
            A non-terminal Observation with ``status == "ready"`` and the new
            episode id. ``metadata["reset_count"]`` is the number of resets
            performed on this instance, including this one.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.history = []
        self.task_state = {}
        # getattr tolerates subclasses whose __init__ skips super().__init__().
        self._reset_count = getattr(self, "_reset_count", 0) + 1
        return Observation(
            done=False,
            reward=0.0,
            observation={"status": "ready", "episode_id": self._state.episode_id},
            metadata={"reset_count": self._reset_count},
        )

    def step(self, action: Dict[str, Any]) -> Observation:
        """Validate, record and execute one action.

        Every call increments the step counter, including rejected calls.

        Args:
            action: Action payload; must be a ``dict``.

        Returns:
            - A terminal Observation with reward -0.5 if ``action`` is not a dict.
            - A terminal ``max_steps_reached`` Observation if the step budget
              was already used up. In that case the action is neither
              executed nor recorded.
            - Otherwise, the Observation from :meth:`_execute_action`. If this
              was the last permitted step, it is forced to ``done=True`` and
              carries ``metadata["max_steps_reached"] = True``.
        """
        self._state.step_count += 1
        step_no = self._state.step_count

        if not isinstance(action, dict):
            return Observation(
                done=True,
                reward=-0.5,
                observation={"error": "Invalid action format"},
                metadata={"step": step_no},
            )

        # Budget exhausted by earlier steps: refuse without executing.
        # (Previously step number max_steps itself was refused, so only
        # max_steps - 1 actions could ever run.)
        if step_no > self.max_steps:
            return Observation(
                done=True,
                reward=0.0,
                observation={"status": "max_steps_reached"},
                metadata={"step": step_no},
            )

        self.history.append(
            {
                "step": step_no,
                "action": action,
                "episode_id": self._state.episode_id,
                "timestamp": time.time(),
            }
        )

        obs = self._execute_action(action)
        # The final permitted action still runs; it just ends the episode.
        if step_no >= self.max_steps and not obs.done:
            obs.done = True
            obs.metadata = {**obs.metadata, "max_steps_reached": True}
        return obs

    def _execute_action(self, action: Dict[str, Any]) -> Observation:
        """Apply a validated dict action and return its Observation.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement _execute_action")

    def state(self) -> State:
        """Return the live State object for the current episode."""
        return self._state