Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| from copy import deepcopy | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
@dataclass
class StepOutcome:
    """Result of applying a single action inside a domain.

    A domain's ``apply_action()`` returns one of these; the generic
    ``BaseDomain.step()`` then folds it into the episode state.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    defaults below would be raw ``Field`` objects shared by every instance and
    the class would accept no constructor arguments.
    """

    # Reward earned by this step; accumulated into the state's "score".
    reward: float = 0.0
    # Amount added to the state's "progress" counter.
    progress_delta: int = 0
    # True when the domain itself declares the episode finished.
    done: bool = False
    # True when the domain declares the objective achieved.
    success: bool = False
    # True when the domain cut the episode short.
    truncated: bool = False
    # Short label recorded in the step history and in the step info.
    event: str = "progress"
    # Extra diagnostic payload merged into the step info.
    info: dict[str, Any] = field(default_factory=dict)
    # Key/value pairs written into the state after the standard bookkeeping.
    state_updates: dict[str, Any] = field(default_factory=dict)
class BaseDomain(ABC):
    """
    Base reusable contract for OmniBench Aegis OpenEnv domains.

    Every domain should either:
    - implement `build_initial_state()` + `apply_action()` and use the generic
      `reset()` / `step()` from this base class, or
    - override `reset()` / `step()` if it needs a custom execution contract.

    Standard state keys maintained here:
    - domain
    - seed
    - mission
    - max_steps
    - step_count
    - progress
    - target_progress
    - score
    - done
    - success
    - last_action
    - history
    - metadata
    """

    domain_name = "base"
    env_name = "omnibench_aegis_env"
    default_max_steps = 20
    default_target_progress = 100

    def list_task_categories(self) -> list[str]:
        """Return the task categories offered by this domain (none by default)."""
        return []

    def build_initial_state(self, **kwargs: Any) -> dict[str, Any]:
        """Return the raw initial state for the domain.

        Raises:
            NotImplementedError: always, unless a subclass overrides this hook.
        """
        raise NotImplementedError

    def apply_action(self, state: dict[str, Any], action: dict[str, Any]) -> StepOutcome:
        """Apply one environment action to the current state and return a StepOutcome.

        Raises:
            NotImplementedError: always, unless a subclass overrides this hook.
        """
        raise NotImplementedError

    def normalize_action(self, action: dict[str, Any] | None) -> dict[str, Any]:
        """Return a shallow copy of ``action``; ``None`` or an empty action becomes ``{}``."""
        return dict(action or {})

    def default_action(self) -> dict[str, Any]:
        """Return the value stored as ``last_action`` before any step was taken."""
        return {}

    def get_observation(self, state: dict[str, Any]) -> dict[str, Any]:
        """Build the observation dict for ``state``.

        Missing keys fall back to the class defaults; ``last_event`` is the
        event of the most recent history entry, or ``"reset"`` when the
        history is empty or absent.
        """
        domain = state.get("domain", self.domain_name)
        step_count = state.get("step_count", 0)
        max_steps = state.get("max_steps", self.default_max_steps)
        progress = state.get("progress", 0)
        target_progress = state.get("target_progress", self.default_target_progress)
        history = state.get("history")
        return {
            "text": (
                f"Domain: {domain}. "
                f"Mission: {state.get('mission', 'N/A')} "
                f"Step {step_count}/{max_steps}. "
                f"Progress {progress}/{target_progress}."
            ),
            "domain": domain,
            "mission": state.get("mission"),
            "step_count": step_count,
            "max_steps": max_steps,
            "progress": progress,
            "target_progress": target_progress,
            "last_event": history[-1]["event"] if history else "reset",
        }

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Start a new episode.

        Keyword arguments are forwarded to ``build_initial_state()``.

        Returns:
            A dict with ``"observation"``, ``"state"`` (the raw state with the
            standard keys filled in) and ``"info"``.
        """
        state = self._ensure_common_state(self.build_initial_state(**kwargs))
        return {
            "observation": self.get_observation(state),
            "state": state,
            "info": self._build_reset_info(state),
        }

    def step(self, state: dict[str, Any], action: dict[str, Any] | None) -> dict[str, Any]:
        """Advance the episode by one action.

        The incoming ``state`` is deep-copied and never mutated; the updated
        copy is returned under ``"state"``.

        Returns:
            A dict with ``"observation"``, ``"reward"``, ``"done"``,
            ``"truncated"``, ``"info"`` and ``"state"``. Stepping an episode
            that is already done is a no-op reported with the
            ``"episode_already_finished"`` event and zero reward.
        """
        state = deepcopy(state)

        if state.get("done", False):
            return {
                "observation": self.get_observation(state),
                "reward": 0.0,
                "done": True,
                "truncated": False,
                "info": {
                    "event": "episode_already_finished",
                    "success": state.get("success", False),
                    **self._common_info(state),
                },
                "state": state,
            }

        normalized_action = self.normalize_action(action)
        outcome = self.apply_action(state, normalized_action)
        reward = float(outcome.reward)

        # Standard bookkeeping; progress is capped at the current target.
        state["step_count"] = int(state.get("step_count", 0)) + 1
        state["progress"] = min(
            int(state.get("target_progress", self.default_target_progress)),
            int(state.get("progress", 0)) + int(outcome.progress_delta),
        )
        state["score"] = round(float(state.get("score", 0.0)) + reward, 6)
        state["last_action"] = normalized_action

        # Domain-specific updates are applied last, so they may override any
        # of the standard keys written above.
        state.update(outcome.state_updates)

        # Thresholds are re-read here because state_updates may have changed them.
        auto_success = state["progress"] >= int(
            state.get("target_progress", self.default_target_progress)
        )
        state["success"] = bool(outcome.success or auto_success)
        reached_max_steps = state["step_count"] >= int(
            state.get("max_steps", self.default_max_steps)
        )
        truncated = bool(outcome.truncated or (reached_max_steps and not state["success"]))
        # A truncated episode is over: without `truncated` here a domain-signalled
        # truncation would report truncated=True but leave the episode running.
        state["done"] = bool(
            outcome.done or state["success"] or reached_max_steps or truncated
        )

        history_entry = {
            "step": state["step_count"],
            "event": outcome.event,
            "reward": reward,
            "progress": int(state["progress"]),
        }
        if outcome.info:
            history_entry["info"] = deepcopy(outcome.info)
        history = state.get("history")
        if history is None:
            # Tolerate states whose history is missing or explicitly None.
            history = state["history"] = []
        history.append(history_entry)

        # outcome.info is merged last, so a domain may override the standard keys.
        info = {
            "event": outcome.event,
            "success": state["success"],
            **self._common_info(state),
            **deepcopy(outcome.info),
        }
        return {
            "observation": self.get_observation(state),
            "reward": reward,
            "done": bool(state["done"]),
            "truncated": truncated,
            "info": info,
            "state": state,
        }

    def clone_state(self, state: dict[str, Any]) -> dict[str, Any]:
        """Return an independent deep copy of ``state``."""
        return deepcopy(state)

    def _common_info(self, state: dict[str, Any]) -> dict[str, Any]:
        """Return the identification fields shared by every info payload."""
        return {
            "domain": state.get("domain", self.domain_name),
            "env_name": self.env_name,
            "env_id": (state.get("metadata") or {}).get("env_id"),
        }

    def _ensure_common_state(self, state: dict[str, Any]) -> dict[str, Any]:
        """Return a deep copy of ``state`` with every standard key present.

        Keys already supplied by the domain are kept as they are; only missing
        keys receive a default.
        """
        normalized = deepcopy(state)
        defaults: dict[str, Any] = {
            "domain": self.domain_name,
            "seed": 0,
            "mission": f"Complete the {self.domain_name} objective.",
            "max_steps": self.default_max_steps,
            "step_count": 0,
            "progress": 0,
            "target_progress": self.default_target_progress,
            "score": 0.0,
            "done": False,
            "success": False,
            "last_action": self.default_action(),
            "history": [],
            "metadata": {},
        }
        for key, value in defaults.items():
            normalized.setdefault(key, value)
        return normalized

    def _build_reset_info(self, state: dict[str, Any]) -> dict[str, Any]:
        """Build the ``info`` payload returned by ``reset()``."""
        return {
            **self._common_info(state),
            "mission": state.get("mission"),
            "task_category": state.get("task_category"),
        }