NDGCodes's picture
Upload folder using huggingface_hub
69c0b6d verified
"""Typed client for the Social Influence Arena."""
from __future__ import annotations
from typing import Any, Dict, Optional
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import ArenaAction, ArenaObservation, ArenaState, BeliefState, DialogTurn
class SocialInfluenceEnv(EnvClient[ArenaAction, ArenaObservation, ArenaState]):
    """HTTP client. Use ``reset(task_id=...)`` to select a task per episode.

    The class customises three ``EnvClient`` hooks: how an ``ArenaAction`` is
    turned into a request payload, and how response payloads are turned back
    into ``ArenaObservation`` / ``ArenaState`` objects.
    """

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        domain: Optional[str] = None,
    ) -> StepResult[ArenaObservation]:
        """Start a new episode, optionally pinning the task and/or domain.

        Args:
            seed: Forwarded unchanged to the base ``reset``.
            episode_id: Forwarded unchanged to the base ``reset``.
            task_id: Forwarded as an extra keyword only when not ``None``.
            domain: Forwarded as an extra keyword only when not ``None``.

        Returns:
            Whatever the base class ``reset`` returns.
        """
        # Only forward the selectors the caller actually supplied, so that an
        # omitted argument is not sent as an explicit ``None``.
        kwargs: Dict[str, Any] = {}
        if task_id is not None:
            kwargs["task_id"] = task_id
        if domain is not None:
            kwargs["domain"] = domain
        return super().reset(seed=seed, episode_id=episode_id, **kwargs)

    # -----------------------------------------------------------------
    # Payload / response shaping
    # -----------------------------------------------------------------

    def _step_payload(self, action: ArenaAction) -> Dict[str, Any]:
        """Convert *action* into the plain dict sent with a step request.

        The belief is dumped via ``model_dump()``; ``public_response`` and
        ``metadata`` are passed through as they are.
        """
        return {
            "belief": action.belief.model_dump(),
            "public_response": action.public_response,
            "metadata": action.metadata,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ArenaObservation]:
        """Build a ``StepResult`` from a reset/step response payload.

        Observation fields are read from ``payload["observation"]``; ``reward``
        and ``done`` are read from the top level of *payload* and mirrored onto
        both the observation and the returned ``StepResult``. Missing keys fall
        back to the defaults spelled out below.
        """
        # ``or {}`` / ``or []`` rather than a ``.get`` default: a default only
        # applies when the key is absent, whereas an explicit ``null`` in the
        # payload would otherwise raise on ``None.get`` / iterating ``None``.
        obs_data = payload.get("observation") or {}
        history = obs_data.get("dialog_history") or []

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = ArenaObservation(
            turn=obs_data.get("turn", 0),
            task_id=obs_data.get("task_id"),
            scenario=obs_data.get("scenario", "BASELINE"),
            attacker_persona=obs_data.get("attacker_persona", "NEUTRAL"),
            attacker_message=obs_data.get("attacker_message", ""),
            question=obs_data.get("question", ""),
            dialog_history=[DialogTurn(**turn) for turn in history],
            reward_breakdown=obs_data.get("reward_breakdown", {}),
            ground_truth=obs_data.get("ground_truth"),
            reward=reward,
            done=done,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> ArenaState:
        """Build an ``ArenaState`` from a state response payload.

        Absent keys become ``None``, except ``step_count`` (``0``) and
        ``cumulative_reward`` (``0.0``).
        """
        return ArenaState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id"),
            question_id=payload.get("question_id"),
            ground_truth=payload.get("ground_truth"),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )
# Names exported by ``from <module> import *``; only the client class is public.
__all__ = ["SocialInfluenceEnv"]