# SafeSpace / client.py
# Hosting-page metadata captured with this file (not part of the program):
#   Ishangtxl's picture
#   Upload folder using huggingface_hub
#   c7a9ff7 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""SafeSpace Content Moderation Environment Client."""
from typing import Any, Dict, List, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
try:
from .models import (
ContentItem,
GatheredContext,
ModerationAction,
ModerationObservation,
ModerationState,
RewardBreakdown,
TaskGradeBreakdown,
TriggerInfo,
)
except ImportError: # pragma: no cover
from models import (
ContentItem,
GatheredContext,
ModerationAction,
ModerationObservation,
ModerationState,
RewardBreakdown,
TaskGradeBreakdown,
TriggerInfo,
)
class SafeSpaceEnv(
    EnvClient[ModerationAction, ModerationObservation, ModerationState]
):
    """
    Client for the SafeSpace Content Moderation Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server.

    The class only supplies the three (de)serialization hooks used by the
    base client: ``_step_payload`` (action -> JSON dict), ``_parse_result``
    (step/reset response -> ``StepResult``) and ``_parse_state`` (state
    response -> ``ModerationState``).

    Example:
        >>> # Connect to a running server
        >>> with SafeSpaceEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     print(result.observation.content_item.text)
        ...
        ...     # Investigate
        ...     result = client.step(ModerationAction(action_type="request_thread_context"))
        ...     print(result.observation.gathered_context.thread_context)
        ...
        ...     # Make decision
        ...     result = client.step(ModerationAction(
        ...         action_type="decide",
        ...         decision="approve",
        ...         primary_violation="none",
        ...         severity="none",
        ...         confidence=0.9,
        ...         key_factors=["gaming_or_competition_context"]
        ...     ))
        ...     print(f"Reward: {result.reward}")

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = SafeSpaceEnv.from_docker_image("safespace-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(ModerationAction(
        ...         action_type="decide",
        ...         decision="remove",
        ...         primary_violation="5.1",
        ...         severity="high",
        ...         confidence=0.95,
        ...         key_factors=["spam_commercial"]
        ...     ))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: ModerationAction) -> Dict[str, Any]:
        """
        Convert ModerationAction to JSON payload for step message.

        ``action_type`` is always sent. The decision fields (``decision``,
        ``primary_violation``, ``severity``, ``confidence``, ``key_factors``)
        are attached only when ``action_type == "decide"``; for every other
        action type they are omitted from the payload.

        Args:
            action: ModerationAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        payload: Dict[str, Any] = {
            "action_type": action.action_type,
        }

        # Include decision fields only for decide action
        if action.action_type == "decide":
            payload["decision"] = action.decision
            payload["primary_violation"] = action.primary_violation
            payload["severity"] = action.severity
            payload["confidence"] = action.confidence
            payload["key_factors"] = action.key_factors

        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ModerationObservation]:
        """
        Parse server response into StepResult[ModerationObservation].

        Missing observation fields fall back to the defaults used below
        (empty strings/lists, ``0`` actions taken, ``max_actions`` of 8).
        The reward is taken from the observation when it carries a non-null
        value and from the top-level ``reward`` key otherwise; ``done`` is
        always read from the top level of the payload.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with ModerationObservation
        """
        # `or {}` (rather than a .get default) also covers an explicit JSON
        # null for "observation", which would otherwise break the .get calls.
        obs_data = payload.get("observation") or {}

        # Parse nested objects. A missing, null or empty value leaves the
        # field at its default.
        content_item = None
        if obs_data.get("content_item"):
            content_item = ContentItem(**obs_data["content_item"])

        trigger_info = None
        if obs_data.get("trigger_info"):
            trigger_info = TriggerInfo(**obs_data["trigger_info"])

        gathered_context = GatheredContext()
        if obs_data.get("gathered_context"):
            gathered_context = GatheredContext(**obs_data["gathered_context"])

        reward_breakdown = None
        if obs_data.get("reward_breakdown") is not None:
            reward_breakdown = RewardBreakdown.model_validate(
                obs_data["reward_breakdown"]
            )

        grade_breakdown = None
        if obs_data.get("grade_breakdown") is not None:
            grade_breakdown = TaskGradeBreakdown.model_validate(
                obs_data["grade_breakdown"]
            )

        # Prefer the observation-level reward, but treat an explicit null the
        # same as a missing key so it cannot mask the top-level reward.
        # Compare against None (not truthiness) so a reward of 0 is kept.
        reward_value = obs_data.get("reward")
        if reward_value is None:
            reward_value = payload.get("reward")

        done = payload.get("done", False)

        observation = ModerationObservation(
            content_item=content_item,
            trigger_info=trigger_info,
            gathered_context=gathered_context,
            platform_policy=obs_data.get("platform_policy", ""),
            available_factors=obs_data.get("available_factors", []),
            actions_taken=obs_data.get("actions_taken", 0),
            max_actions=obs_data.get("max_actions", 8),
            action_history=obs_data.get("action_history", []),
            feedback=obs_data.get("feedback", ""),
            error_code=obs_data.get("error_code"),
            done=done,
            reward=reward_value,
            reward_breakdown=reward_breakdown,
            task_grade=obs_data.get("task_grade"),
            grade_breakdown=grade_breakdown,
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward_value,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> ModerationState:
        """
        Parse server response into ModerationState object.

        Every field is read from the top level of ``payload``. Keys that are
        absent fall back to ``None``, or to the explicit defaults given below
        for counters, flags, lists and reward totals.

        Args:
            payload: JSON response from state request

        Returns:
            ModerationState object
        """
        return ModerationState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            scenario_id=payload.get("scenario_id"),
            task_id=payload.get("task_id"),
            difficulty=payload.get("difficulty"),
            trigger_type=payload.get("trigger_type"),
            actions_taken=payload.get("actions_taken", 0),
            max_actions=payload.get("max_actions", 8),
            context_requested=payload.get("context_requested", []),
            decision_made=payload.get("decision_made", False),
            episode_reward=payload.get("episode_reward", 0.0),
            raw_episode_reward=payload.get("raw_episode_reward", 0.0),
            done=payload.get("done", False),
            last_error_code=payload.get("last_error_code"),
        )
# Alias for backward compatibility: ContentModerationEnv is bound to the very
# same class object as SafeSpaceEnv, so either name can be used to construct
# the client.
ContentModerationEnv = SafeSpaceEnv