Spaces:
Sleeping
Sleeping
| """Typed client for the Social Influence Arena.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, Optional | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from .models import ArenaAction, ArenaObservation, ArenaState, BeliefState, DialogTurn | |
| class SocialInfluenceEnv(EnvClient[ArenaAction, ArenaObservation, ArenaState]): | |
| """HTTP client. Use ``reset(task_id=...)`` to select a task per episode.""" | |
| def reset( | |
| self, | |
| seed: Optional[int] = None, | |
| episode_id: Optional[str] = None, | |
| task_id: Optional[str] = None, | |
| domain: Optional[str] = None, | |
| ) -> StepResult[ArenaObservation]: | |
| kwargs: Dict[str, Any] = {} | |
| if task_id is not None: | |
| kwargs["task_id"] = task_id | |
| if domain is not None: | |
| kwargs["domain"] = domain | |
| return super().reset(seed=seed, episode_id=episode_id, **kwargs) | |
| # ----------------------------------------------------------------- | |
| # Payload / response shaping | |
| # ----------------------------------------------------------------- | |
| def _step_payload(self, action: ArenaAction) -> Dict[str, Any]: | |
| return { | |
| "belief": action.belief.model_dump(), | |
| "public_response": action.public_response, | |
| "metadata": action.metadata, | |
| } | |
| def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ArenaObservation]: | |
| obs_data = payload.get("observation", {}) | |
| observation = ArenaObservation( | |
| turn=obs_data.get("turn", 0), | |
| task_id=obs_data.get("task_id"), | |
| scenario=obs_data.get("scenario", "BASELINE"), | |
| attacker_persona=obs_data.get("attacker_persona", "NEUTRAL"), | |
| attacker_message=obs_data.get("attacker_message", ""), | |
| question=obs_data.get("question", ""), | |
| dialog_history=[ | |
| DialogTurn(**t) for t in obs_data.get("dialog_history", []) | |
| ], | |
| reward_breakdown=obs_data.get("reward_breakdown", {}), | |
| ground_truth=obs_data.get("ground_truth"), | |
| reward=payload.get("reward"), | |
| done=payload.get("done", False), | |
| metadata=obs_data.get("metadata", {}), | |
| ) | |
| return StepResult( | |
| observation=observation, | |
| reward=payload.get("reward"), | |
| done=payload.get("done", False), | |
| ) | |
| def _parse_state(self, payload: Dict[str, Any]) -> ArenaState: | |
| return ArenaState( | |
| episode_id=payload.get("episode_id"), | |
| step_count=payload.get("step_count", 0), | |
| task_id=payload.get("task_id"), | |
| question_id=payload.get("question_id"), | |
| ground_truth=payload.get("ground_truth"), | |
| cumulative_reward=payload.get("cumulative_reward", 0.0), | |
| ) | |
| __all__ = ["SocialInfluenceEnv"] | |