Spaces:
Sleeping
Sleeping
updated
Browse files- envs/social_stream_moderation/environment.py +18 -5
- envs/social_stream_moderation/graders.py +208 -14
- envs/social_stream_moderation/tasks.py +7 -3
- openenv.yaml +3 -0
- server/app.py +21 -28
envs/social_stream_moderation/environment.py
CHANGED
|
@@ -4,7 +4,7 @@ import random
|
|
| 4 |
from typing import List, Dict, Any, Tuple, Optional
|
| 5 |
from .models import HarmLabel, ModerationAction, State, PolicyMode, Post, UserGroup
|
| 6 |
from .tasks import TASKS, TaskConfig
|
| 7 |
-
from .graders import compute_per_post_reward, grade_episode
|
| 8 |
|
| 9 |
class SocialStreamModerationEnv:
|
| 10 |
def __init__(self, data_dir: Optional[str] = None):
|
|
@@ -17,7 +17,8 @@ class SocialStreamModerationEnv:
|
|
| 17 |
self.done = False
|
| 18 |
self.episode_history: List[Dict[str, Any]] = []
|
| 19 |
self.policy_mode = PolicyMode.NORMAL
|
| 20 |
-
|
|
|
|
| 21 |
@classmethod
|
| 22 |
async def from_docker_image(cls, image_name: Optional[str] = None):
|
| 23 |
"""Standard OpenEnv V4 interface for initializing the environment."""
|
|
@@ -47,7 +48,11 @@ class SocialStreamModerationEnv:
|
|
| 47 |
self.done = False
|
| 48 |
self.episode_history = []
|
| 49 |
self.policy_mode = self.current_task.policy_mode
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
return self._get_state()
|
| 52 |
|
| 53 |
def _get_state(self) -> State:
|
|
@@ -88,11 +93,14 @@ class SocialStreamModerationEnv:
|
|
| 88 |
# Compute reward
|
| 89 |
reward = compute_per_post_reward(current_post.harm_label, action, self.policy_mode)
|
| 90 |
|
| 91 |
-
# Log to history for final grading
|
|
|
|
| 92 |
self.episode_history.append({
|
| 93 |
"post_id": current_post.post_id,
|
| 94 |
"harm_label": current_post.harm_label,
|
| 95 |
"user_group": current_post.user_group,
|
|
|
|
|
|
|
| 96 |
"action": action,
|
| 97 |
"reward": reward
|
| 98 |
})
|
|
@@ -114,7 +122,12 @@ class SocialStreamModerationEnv:
|
|
| 114 |
}
|
| 115 |
|
| 116 |
if self.done:
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
info["score"] = final_score
|
|
|
|
| 119 |
|
| 120 |
return next_state, reward, self.done, info
|
|
|
|
| 4 |
from typing import List, Dict, Any, Tuple, Optional
|
| 5 |
from .models import HarmLabel, ModerationAction, State, PolicyMode, Post, UserGroup
|
| 6 |
from .tasks import TASKS, TaskConfig
|
| 7 |
+
from .graders import compute_per_post_reward, grade_episode, get_grader
|
| 8 |
|
| 9 |
class SocialStreamModerationEnv:
|
| 10 |
def __init__(self, data_dir: Optional[str] = None):
|
|
|
|
| 17 |
self.done = False
|
| 18 |
self.episode_history: List[Dict[str, Any]] = []
|
| 19 |
self.policy_mode = PolicyMode.NORMAL
|
| 20 |
+
self._grader = None
|
| 21 |
+
|
| 22 |
@classmethod
|
| 23 |
async def from_docker_image(cls, image_name: Optional[str] = None):
|
| 24 |
"""Standard OpenEnv V4 interface for initializing the environment."""
|
|
|
|
| 48 |
self.done = False
|
| 49 |
self.episode_history = []
|
| 50 |
self.policy_mode = self.current_task.policy_mode
|
| 51 |
+
|
| 52 |
+
# Initialise the grader for this task
|
| 53 |
+
self._grader = get_grader(self.current_task.grader_id)
|
| 54 |
+
self._grader.reset()
|
| 55 |
+
|
| 56 |
return self._get_state()
|
| 57 |
|
| 58 |
def _get_state(self) -> State:
|
|
|
|
| 93 |
# Compute reward
|
| 94 |
reward = compute_per_post_reward(current_post.harm_label, action, self.policy_mode)
|
| 95 |
|
| 96 |
+
# Log to history for final grading (include context_type and
|
| 97 |
+
# policy_mode so context-aware graders can use them)
|
| 98 |
self.episode_history.append({
|
| 99 |
"post_id": current_post.post_id,
|
| 100 |
"harm_label": current_post.harm_label,
|
| 101 |
"user_group": current_post.user_group,
|
| 102 |
+
"context_type": current_post.context_type,
|
| 103 |
+
"policy_mode": self.policy_mode,
|
| 104 |
"action": action,
|
| 105 |
"reward": reward
|
| 106 |
})
|
|
|
|
| 122 |
}
|
| 123 |
|
| 124 |
if self.done:
|
| 125 |
+
# Use the task-specific grader when available
|
| 126 |
+
if self._grader is not None:
|
| 127 |
+
final_score = self._grader.grade(self.episode_history)
|
| 128 |
+
else:
|
| 129 |
+
final_score = grade_episode(self.episode_history, self.current_task.use_fairness)
|
| 130 |
info["score"] = final_score
|
| 131 |
+
info["grader_id"] = self.current_task.grader_id
|
| 132 |
|
| 133 |
return next_state, reward, self.done, info
|
envs/social_stream_moderation/graders.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
-
from typing import List, Dict, Any
|
| 2 |
-
from .
|
|
|
|
| 3 |
|
| 4 |
REWARD_TABLE = {
|
| 5 |
HarmLabel.SAFE: {
|
|
@@ -39,28 +40,33 @@ REWARD_TABLE = {
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def compute_per_post_reward(harm_label: HarmLabel, action: ModerationAction, policy_mode: PolicyMode = PolicyMode.NORMAL) -> float:
|
| 43 |
base_reward = REWARD_TABLE.get(harm_label, {}).get(action, 0.01)
|
| 44 |
-
|
| 45 |
# Adjust for policy mode if needed
|
| 46 |
if policy_mode == PolicyMode.LENIENT:
|
| 47 |
-
#
|
| 48 |
if harm_label == HarmLabel.SAFE and action in [ModerationAction.SOFT_HIDE, ModerationAction.BAN_USER]:
|
| 49 |
base_reward -= 0.1
|
| 50 |
elif policy_mode == PolicyMode.STRICT:
|
| 51 |
# Strict: penalize under-moderation more
|
| 52 |
if harm_label != HarmLabel.SAFE and action == ModerationAction.ALLOW:
|
| 53 |
base_reward -= 0.2
|
| 54 |
-
|
| 55 |
return max(0.01, min(0.99, base_reward))
|
| 56 |
|
|
|
|
| 57 |
def calculate_fairness_penalty(episode_history: List[Dict[str, Any]], threshold: float = 0.3) -> float:
|
| 58 |
"""Calculates fairness penalty if group ban rate imbalance is high."""
|
| 59 |
-
group_stats = {
|
| 60 |
UserGroup.A: {"bans": 0, "total": 0},
|
| 61 |
UserGroup.B: {"bans": 0, "total": 0}
|
| 62 |
}
|
| 63 |
-
|
| 64 |
for entry in episode_history:
|
| 65 |
group = entry.get("user_group")
|
| 66 |
action = entry.get("action")
|
|
@@ -68,32 +74,220 @@ def calculate_fairness_penalty(episode_history: List[Dict[str, Any]], threshold:
|
|
| 68 |
group_stats[group]["total"] += 1
|
| 69 |
if action == ModerationAction.BAN_USER:
|
| 70 |
group_stats[group]["bans"] += 1
|
| 71 |
-
|
| 72 |
-
ban_rates = {}
|
| 73 |
for group, stats in group_stats.items():
|
| 74 |
if stats["total"] > 0:
|
| 75 |
ban_rates[group] = stats["bans"] / stats["total"]
|
| 76 |
else:
|
| 77 |
ban_rates[group] = 0.0
|
| 78 |
-
|
| 79 |
diff = abs(ban_rates.get(UserGroup.A, 0.0) - ban_rates.get(UserGroup.B, 0.0))
|
| 80 |
if diff > threshold:
|
| 81 |
# Return penalty up to 0.2
|
| 82 |
return min(0.2, (diff - threshold) * 0.5)
|
| 83 |
-
|
| 84 |
return 0.0
|
| 85 |
|
|
|
|
| 86 |
def grade_episode(episode_history: List[Dict[str, Any]], use_fairness: bool = False) -> float:
|
| 87 |
if not episode_history:
|
| 88 |
return 0.01
|
| 89 |
-
|
| 90 |
rewards = [entry.get("reward", 0.01) for entry in episode_history]
|
| 91 |
mean_reward = sum(rewards) / len(rewards)
|
| 92 |
-
|
| 93 |
if use_fairness:
|
| 94 |
penalty = calculate_fairness_penalty(episode_history)
|
| 95 |
mean_reward = max(0.01, mean_reward - penalty)
|
| 96 |
-
|
| 97 |
# Strictly between 0 and 1
|
| 98 |
final_score = max(0.001, min(0.999, mean_reward))
|
| 99 |
return round(final_score, 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional
|
| 2 |
+
from openenv.core.rubrics import Rubric
|
| 3 |
+
from .models import HarmLabel, ModerationAction, UserGroup, PolicyMode, ContextType
|
| 4 |
|
| 5 |
REWARD_TABLE = {
|
| 6 |
HarmLabel.SAFE: {
|
|
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# Utility functions (preserved for backward compatibility)
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
def compute_per_post_reward(harm_label: HarmLabel, action: ModerationAction, policy_mode: PolicyMode = PolicyMode.NORMAL) -> float:
|
| 48 |
base_reward = REWARD_TABLE.get(harm_label, {}).get(action, 0.01)
|
| 49 |
+
|
| 50 |
# Adjust for policy mode if needed
|
| 51 |
if policy_mode == PolicyMode.LENIENT:
|
| 52 |
+
# Lenient: penalize over-censorship (Hiding safe content) more
|
| 53 |
if harm_label == HarmLabel.SAFE and action in [ModerationAction.SOFT_HIDE, ModerationAction.BAN_USER]:
|
| 54 |
base_reward -= 0.1
|
| 55 |
elif policy_mode == PolicyMode.STRICT:
|
| 56 |
# Strict: penalize under-moderation more
|
| 57 |
if harm_label != HarmLabel.SAFE and action == ModerationAction.ALLOW:
|
| 58 |
base_reward -= 0.2
|
| 59 |
+
|
| 60 |
return max(0.01, min(0.99, base_reward))
|
| 61 |
|
| 62 |
+
|
| 63 |
def calculate_fairness_penalty(episode_history: List[Dict[str, Any]], threshold: float = 0.3) -> float:
|
| 64 |
"""Calculates fairness penalty if group ban rate imbalance is high."""
|
| 65 |
+
group_stats: Dict[UserGroup, Dict[str, int]] = {
|
| 66 |
UserGroup.A: {"bans": 0, "total": 0},
|
| 67 |
UserGroup.B: {"bans": 0, "total": 0}
|
| 68 |
}
|
| 69 |
+
|
| 70 |
for entry in episode_history:
|
| 71 |
group = entry.get("user_group")
|
| 72 |
action = entry.get("action")
|
|
|
|
| 74 |
group_stats[group]["total"] += 1
|
| 75 |
if action == ModerationAction.BAN_USER:
|
| 76 |
group_stats[group]["bans"] += 1
|
| 77 |
+
|
| 78 |
+
ban_rates: Dict[UserGroup, float] = {}
|
| 79 |
for group, stats in group_stats.items():
|
| 80 |
if stats["total"] > 0:
|
| 81 |
ban_rates[group] = stats["bans"] / stats["total"]
|
| 82 |
else:
|
| 83 |
ban_rates[group] = 0.0
|
| 84 |
+
|
| 85 |
diff = abs(ban_rates.get(UserGroup.A, 0.0) - ban_rates.get(UserGroup.B, 0.0))
|
| 86 |
if diff > threshold:
|
| 87 |
# Return penalty up to 0.2
|
| 88 |
return min(0.2, (diff - threshold) * 0.5)
|
| 89 |
+
|
| 90 |
return 0.0
|
| 91 |
|
| 92 |
+
|
| 93 |
def grade_episode(episode_history: List[Dict[str, Any]], use_fairness: bool = False) -> float:
|
| 94 |
if not episode_history:
|
| 95 |
return 0.01
|
| 96 |
+
|
| 97 |
rewards = [entry.get("reward", 0.01) for entry in episode_history]
|
| 98 |
mean_reward = sum(rewards) / len(rewards)
|
| 99 |
+
|
| 100 |
if use_fairness:
|
| 101 |
penalty = calculate_fairness_penalty(episode_history)
|
| 102 |
mean_reward = max(0.01, mean_reward - penalty)
|
| 103 |
+
|
| 104 |
# Strictly between 0 and 1
|
| 105 |
final_score = max(0.001, min(0.999, mean_reward))
|
| 106 |
return round(final_score, 4)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# Grader classes – each wraps the scoring logic for one task and extends
|
| 111 |
+
# the openenv ``Rubric`` base class so the graders are first-class objects
|
| 112 |
+
# in the OpenEnv evaluation framework.
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class BasicSafetyGrader(Rubric):
|
| 116 |
+
"""Grader for Task 1: Basic Safety.
|
| 117 |
+
|
| 118 |
+
Evaluates whether the agent correctly classifies clearly safe vs.
|
| 119 |
+
clearly harmful posts using the reward table. No context-aware or
|
| 120 |
+
fairness adjustments are applied – the score is purely accuracy-based.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
grader_id: str = "basic_safety_grader"
|
| 124 |
+
description: str = "Grader for basic safety checks"
|
| 125 |
+
|
| 126 |
+
def __init__(self) -> None:
|
| 127 |
+
super().__init__()
|
| 128 |
+
self._episode_history: List[Dict[str, Any]] = []
|
| 129 |
+
|
| 130 |
+
def forward(self, action: Any, observation: Any) -> float:
|
| 131 |
+
"""Score a single moderation step.
|
| 132 |
+
|
| 133 |
+
``action`` is a dict with at least ``harm_label``, ``action``, and
|
| 134 |
+
``policy_mode`` keys (matching the episode-history entries produced
|
| 135 |
+
by ``SocialStreamModerationEnv``).
|
| 136 |
+
``observation`` is unused but required by the Rubric interface.
|
| 137 |
+
"""
|
| 138 |
+
harm_label = action.get("harm_label", HarmLabel.SAFE)
|
| 139 |
+
mod_action = action.get("action", ModerationAction.ALLOW)
|
| 140 |
+
policy_mode = action.get("policy_mode", PolicyMode.NORMAL)
|
| 141 |
+
|
| 142 |
+
reward = compute_per_post_reward(harm_label, mod_action, policy_mode)
|
| 143 |
+
step_entry = {**action, "reward": reward}
|
| 144 |
+
self._episode_history.append(step_entry)
|
| 145 |
+
return reward
|
| 146 |
+
|
| 147 |
+
def grade(self, episode_history: Optional[List[Dict[str, Any]]] = None) -> float:
|
| 148 |
+
"""Return the final episode score.
|
| 149 |
+
|
| 150 |
+
If ``episode_history`` is provided it is used directly; otherwise the
|
| 151 |
+
internally accumulated history is scored.
|
| 152 |
+
"""
|
| 153 |
+
history = episode_history if episode_history is not None else self._episode_history
|
| 154 |
+
return grade_episode(history, use_fairness=False)
|
| 155 |
+
|
| 156 |
+
def reset(self) -> None:
|
| 157 |
+
self._episode_history = []
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class ContextNuanceGrader(Rubric):
|
| 161 |
+
"""Grader for Task 2: Context & Nuance.
|
| 162 |
+
|
| 163 |
+
Extends the basic reward-table scoring with context-aware adjustments:
|
| 164 |
+
* **Quote posts** that are labeled SAFE (e.g. quoting hate speech to
|
| 165 |
+
condemn it) receive a bonus when correctly allowed.
|
| 166 |
+
* **Sarcastic / heated replies** labeled SAFE receive a bonus for
|
| 167 |
+
correct identification.
|
| 168 |
+
* Over-moderation of contextually safe content is penalised more
|
| 169 |
+
heavily than in the basic grader.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
grader_id: str = "context_nuance_grader"
|
| 173 |
+
description: str = "Grader for contextual and sarcastic content"
|
| 174 |
+
|
| 175 |
+
# Bonus / penalty multipliers for context-aware scoring
|
| 176 |
+
CONTEXT_BONUS: float = 0.05
|
| 177 |
+
OVER_MODERATION_PENALTY: float = 0.15
|
| 178 |
+
|
| 179 |
+
def __init__(self) -> None:
|
| 180 |
+
super().__init__()
|
| 181 |
+
self._episode_history: List[Dict[str, Any]] = []
|
| 182 |
+
|
| 183 |
+
def _context_adjusted_reward(self, entry: Dict[str, Any]) -> float:
|
| 184 |
+
"""Compute a context-aware reward for a single step."""
|
| 185 |
+
harm_label = entry.get("harm_label", HarmLabel.SAFE)
|
| 186 |
+
mod_action = entry.get("action", ModerationAction.ALLOW)
|
| 187 |
+
policy_mode = entry.get("policy_mode", PolicyMode.NORMAL)
|
| 188 |
+
context_type = entry.get("context_type", ContextType.ROOT_POST)
|
| 189 |
+
|
| 190 |
+
base = compute_per_post_reward(harm_label, mod_action, policy_mode)
|
| 191 |
+
|
| 192 |
+
# Bonus: correctly allowing safe quote_posts or sarcastic replies
|
| 193 |
+
if harm_label == HarmLabel.SAFE and context_type in (ContextType.QUOTE_POST, ContextType.REPLY_HEATED):
|
| 194 |
+
if mod_action == ModerationAction.ALLOW:
|
| 195 |
+
base += self.CONTEXT_BONUS
|
| 196 |
+
|
| 197 |
+
# Penalty: over-moderating contextually safe content
|
| 198 |
+
if harm_label == HarmLabel.SAFE and context_type in (ContextType.QUOTE_POST, ContextType.REPLY_HEATED):
|
| 199 |
+
if mod_action in (ModerationAction.SOFT_HIDE, ModerationAction.BAN_USER):
|
| 200 |
+
base -= self.OVER_MODERATION_PENALTY
|
| 201 |
+
|
| 202 |
+
return max(0.01, min(0.99, base))
|
| 203 |
+
|
| 204 |
+
def forward(self, action: Any, observation: Any) -> float:
|
| 205 |
+
reward = self._context_adjusted_reward(action)
|
| 206 |
+
step_entry = {**action, "reward": reward}
|
| 207 |
+
self._episode_history.append(step_entry)
|
| 208 |
+
return reward
|
| 209 |
+
|
| 210 |
+
def grade(self, episode_history: Optional[List[Dict[str, Any]]] = None) -> float:
|
| 211 |
+
"""Return the final episode score with context-aware adjustments."""
|
| 212 |
+
history = episode_history if episode_history is not None else self._episode_history
|
| 213 |
+
if not history:
|
| 214 |
+
return 0.01
|
| 215 |
+
|
| 216 |
+
# Re-compute rewards with context adjustments for externally supplied history
|
| 217 |
+
if episode_history is not None:
|
| 218 |
+
rewards = [self._context_adjusted_reward(e) for e in history]
|
| 219 |
+
else:
|
| 220 |
+
rewards = [e.get("reward", 0.01) for e in history]
|
| 221 |
+
|
| 222 |
+
mean_reward = sum(rewards) / len(rewards)
|
| 223 |
+
return round(max(0.001, min(0.999, mean_reward)), 4)
|
| 224 |
+
|
| 225 |
+
def reset(self) -> None:
|
| 226 |
+
self._episode_history = []
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class FairnessBiasGrader(Rubric):
|
| 230 |
+
"""Grader for Task 3: Fairness & Bias.
|
| 231 |
+
|
| 232 |
+
Uses the same per-post reward table as the basic grader but adds a
|
| 233 |
+
**fairness penalty** at the episode level based on ban-rate disparity
|
| 234 |
+
between user groups. Also applies stricter policy-mode adjustments
|
| 235 |
+
to penalise under-moderation.
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
grader_id: str = "fairness_bias_grader"
|
| 239 |
+
description: str = "Grader for fairness and bias parity"
|
| 240 |
+
|
| 241 |
+
FAIRNESS_THRESHOLD: float = 0.3
|
| 242 |
+
|
| 243 |
+
def __init__(self) -> None:
|
| 244 |
+
super().__init__()
|
| 245 |
+
self._episode_history: List[Dict[str, Any]] = []
|
| 246 |
+
|
| 247 |
+
def forward(self, action: Any, observation: Any) -> float:
|
| 248 |
+
harm_label = action.get("harm_label", HarmLabel.SAFE)
|
| 249 |
+
mod_action = action.get("action", ModerationAction.ALLOW)
|
| 250 |
+
policy_mode = action.get("policy_mode", PolicyMode.NORMAL)
|
| 251 |
+
|
| 252 |
+
reward = compute_per_post_reward(harm_label, mod_action, policy_mode)
|
| 253 |
+
step_entry = {**action, "reward": reward}
|
| 254 |
+
self._episode_history.append(step_entry)
|
| 255 |
+
return reward
|
| 256 |
+
|
| 257 |
+
def grade(self, episode_history: Optional[List[Dict[str, Any]]] = None) -> float:
|
| 258 |
+
"""Return the final episode score including the fairness penalty."""
|
| 259 |
+
history = episode_history if episode_history is not None else self._episode_history
|
| 260 |
+
return grade_episode(history, use_fairness=True)
|
| 261 |
+
|
| 262 |
+
def reset(self) -> None:
|
| 263 |
+
self._episode_history = []
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------------------------------------------------------------------------
|
| 267 |
+
# Grader registry – maps grader IDs declared in openenv.yaml to instances.
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
GRADER_REGISTRY: Dict[str, Rubric] = {
|
| 271 |
+
BasicSafetyGrader.grader_id: BasicSafetyGrader(),
|
| 272 |
+
ContextNuanceGrader.grader_id: ContextNuanceGrader(),
|
| 273 |
+
FairnessBiasGrader.grader_id: FairnessBiasGrader(),
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def get_grader(grader_id: str) -> Rubric:
|
| 278 |
+
"""Look up a grader by its ID. Raises ``KeyError`` if not found."""
|
| 279 |
+
if grader_id not in GRADER_REGISTRY:
|
| 280 |
+
raise KeyError(f"Grader '{grader_id}' not found. Available: {list(GRADER_REGISTRY.keys())}")
|
| 281 |
+
return GRADER_REGISTRY[grader_id]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def list_graders() -> List[Dict[str, str]]:
|
| 285 |
+
"""Return metadata for all registered graders."""
|
| 286 |
+
result: List[Dict[str, str]] = []
|
| 287 |
+
for grader_id, grader in GRADER_REGISTRY.items():
|
| 288 |
+
result.append({
|
| 289 |
+
"id": grader_id,
|
| 290 |
+
"description": grader.description,
|
| 291 |
+
"class": type(grader).__name__,
|
| 292 |
+
})
|
| 293 |
+
return result
|
envs/social_stream_moderation/tasks.py
CHANGED
|
@@ -9,6 +9,7 @@ class TaskConfig(BaseModel):
|
|
| 9 |
episode_length: int
|
| 10 |
policy_mode: PolicyMode
|
| 11 |
use_fairness: bool = False
|
|
|
|
| 12 |
|
| 13 |
TASKS = {
|
| 14 |
"Task 1: Basic Safety": TaskConfig(
|
|
@@ -17,7 +18,8 @@ TASKS = {
|
|
| 17 |
data_file="data_easy.json",
|
| 18 |
episode_length=10,
|
| 19 |
policy_mode=PolicyMode.NORMAL,
|
| 20 |
-
use_fairness=False
|
|
|
|
| 21 |
),
|
| 22 |
"Task 2: Context & Nuance": TaskConfig(
|
| 23 |
name="Task 2: Context & Nuance",
|
|
@@ -25,7 +27,8 @@ TASKS = {
|
|
| 25 |
data_file="data_medium.json",
|
| 26 |
episode_length=15,
|
| 27 |
policy_mode=PolicyMode.NORMAL,
|
| 28 |
-
use_fairness=False
|
|
|
|
| 29 |
),
|
| 30 |
"Task 3: Fairness & Bias": TaskConfig(
|
| 31 |
name="Task 3: Fairness & Bias",
|
|
@@ -33,6 +36,7 @@ TASKS = {
|
|
| 33 |
data_file="data_hard.json",
|
| 34 |
episode_length=20,
|
| 35 |
policy_mode=PolicyMode.NORMAL,
|
| 36 |
-
use_fairness=True
|
|
|
|
| 37 |
)
|
| 38 |
}
|
|
|
|
| 9 |
episode_length: int
|
| 10 |
policy_mode: PolicyMode
|
| 11 |
use_fairness: bool = False
|
| 12 |
+
grader_id: str = "basic_safety_grader"
|
| 13 |
|
| 14 |
TASKS = {
|
| 15 |
"Task 1: Basic Safety": TaskConfig(
|
|
|
|
| 18 |
data_file="data_easy.json",
|
| 19 |
episode_length=10,
|
| 20 |
policy_mode=PolicyMode.NORMAL,
|
| 21 |
+
use_fairness=False,
|
| 22 |
+
grader_id="basic_safety_grader"
|
| 23 |
),
|
| 24 |
"Task 2: Context & Nuance": TaskConfig(
|
| 25 |
name="Task 2: Context & Nuance",
|
|
|
|
| 27 |
data_file="data_medium.json",
|
| 28 |
episode_length=15,
|
| 29 |
policy_mode=PolicyMode.NORMAL,
|
| 30 |
+
use_fairness=False,
|
| 31 |
+
grader_id="context_nuance_grader"
|
| 32 |
),
|
| 33 |
"Task 3: Fairness & Bias": TaskConfig(
|
| 34 |
name="Task 3: Fairness & Bias",
|
|
|
|
| 36 |
data_file="data_hard.json",
|
| 37 |
episode_length=20,
|
| 38 |
policy_mode=PolicyMode.NORMAL,
|
| 39 |
+
use_fairness=True,
|
| 40 |
+
grader_id="fairness_bias_grader"
|
| 41 |
)
|
| 42 |
}
|
openenv.yaml
CHANGED
|
@@ -3,10 +3,13 @@ version: 1.0.0
|
|
| 3 |
graders:
|
| 4 |
- id: basic_safety_grader
|
| 5 |
description: "Grader for basic safety checks"
|
|
|
|
| 6 |
- id: context_nuance_grader
|
| 7 |
description: "Grader for contextual and sarcastic content"
|
|
|
|
| 8 |
- id: fairness_bias_grader
|
| 9 |
description: "Grader for fairness and bias parity"
|
|
|
|
| 10 |
|
| 11 |
tasks:
|
| 12 |
- id: "Task 1: Basic Safety"
|
|
|
|
| 3 |
graders:
|
| 4 |
- id: basic_safety_grader
|
| 5 |
description: "Grader for basic safety checks"
|
| 6 |
+
entry_point: envs.social_stream_moderation.graders:BasicSafetyGrader
|
| 7 |
- id: context_nuance_grader
|
| 8 |
description: "Grader for contextual and sarcastic content"
|
| 9 |
+
entry_point: envs.social_stream_moderation.graders:ContextNuanceGrader
|
| 10 |
- id: fairness_bias_grader
|
| 11 |
description: "Grader for fairness and bias parity"
|
| 12 |
+
entry_point: envs.social_stream_moderation.graders:FairnessBiasGrader
|
| 13 |
|
| 14 |
tasks:
|
| 15 |
- id: "Task 1: Basic Safety"
|
server/app.py
CHANGED
|
@@ -10,6 +10,8 @@ from typing import Optional, Dict, Any, List
|
|
| 10 |
from enum import Enum
|
| 11 |
from envs.social_stream_moderation.environment import SocialStreamModerationEnv
|
| 12 |
from envs.social_stream_moderation.models import State, ModerationAction
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Enums for Swagger Dropdowns
|
| 15 |
class TaskName(str, Enum):
|
|
@@ -772,33 +774,18 @@ async def list_tasks():
|
|
| 772 |
"""Returns the list of tasks available in the environment for discovery."""
|
| 773 |
return [
|
| 774 |
{
|
| 775 |
-
"id":
|
| 776 |
-
"difficulty":
|
| 777 |
-
"description": "
|
| 778 |
-
"grader_id":
|
| 779 |
-
},
|
| 780 |
-
{
|
| 781 |
-
"id": "Task 2: Context & Nuance",
|
| 782 |
-
"difficulty": "medium",
|
| 783 |
-
"description": "Handle sarcastic content and quotes of harmful material with condemnation.",
|
| 784 |
-
"grader_id": "context_nuance_grader"
|
| 785 |
-
},
|
| 786 |
-
{
|
| 787 |
-
"id": "Task 3: Fairness & Bias",
|
| 788 |
-
"difficulty": "hard",
|
| 789 |
-
"description": "Ensure fairness across user groups and adhere to stricter policy regimes.",
|
| 790 |
-
"grader_id": "fairness_bias_grader"
|
| 791 |
}
|
|
|
|
| 792 |
]
|
| 793 |
|
| 794 |
@app.get("/graders", tags=["🛡️ Automated Benchmarking"])
|
| 795 |
-
async def
|
| 796 |
"""Returns the list of graders available in the environment for discovery."""
|
| 797 |
-
return
|
| 798 |
-
{"id": "basic_safety_grader", "description": "Grader for basic safety checks"},
|
| 799 |
-
{"id": "context_nuance_grader", "description": "Grader for contextual and sarcastic content"},
|
| 800 |
-
{"id": "fairness_bias_grader", "description": "Grader for fairness and bias parity"}
|
| 801 |
-
]
|
| 802 |
|
| 803 |
@app.post("/evaluate", tags=["🧪 Interactive Lab"], summary="Test Model Logic (XAI Insight)")
|
| 804 |
async def evaluate_text(
|
|
@@ -864,17 +851,20 @@ async def step_env(req: StepRequest):
|
|
| 864 |
next_state, reward, done, info = await env.step(req.action)
|
| 865 |
|
| 866 |
final_score = 0.0
|
|
|
|
| 867 |
if done:
|
| 868 |
-
|
| 869 |
-
#
|
| 870 |
-
final_score =
|
|
|
|
| 871 |
|
| 872 |
return {
|
| 873 |
"next_state": next_state,
|
| 874 |
"reward": reward,
|
| 875 |
"done": done,
|
| 876 |
"info": info,
|
| 877 |
-
"final_score": final_score
|
|
|
|
| 878 |
}
|
| 879 |
|
| 880 |
except RuntimeError as e:
|
|
@@ -900,9 +890,11 @@ async def predict_and_step(req: Optional[LLMConfigRequest] = Body(None)):
|
|
| 900 |
next_state, reward, done, info = await env.step(action)
|
| 901 |
|
| 902 |
final_score = 0.0
|
|
|
|
| 903 |
if done:
|
| 904 |
-
|
| 905 |
-
final_score =
|
|
|
|
| 906 |
|
| 907 |
return {
|
| 908 |
"prediction": action.value,
|
|
@@ -910,6 +902,7 @@ async def predict_and_step(req: Optional[LLMConfigRequest] = Body(None)):
|
|
| 910 |
"reward": reward,
|
| 911 |
"done": done,
|
| 912 |
"final_score": final_score,
|
|
|
|
| 913 |
"next_state": next_state,
|
| 914 |
"info": info
|
| 915 |
}
|
|
|
|
| 10 |
from enum import Enum
|
| 11 |
from envs.social_stream_moderation.environment import SocialStreamModerationEnv
|
| 12 |
from envs.social_stream_moderation.models import State, ModerationAction
|
| 13 |
+
from envs.social_stream_moderation.graders import list_graders as _list_graders, get_grader
|
| 14 |
+
from envs.social_stream_moderation.tasks import TASKS
|
| 15 |
|
| 16 |
# Enums for Swagger Dropdowns
|
| 17 |
class TaskName(str, Enum):
|
|
|
|
| 774 |
"""Returns the list of tasks available in the environment for discovery."""
|
| 775 |
return [
|
| 776 |
{
|
| 777 |
+
"id": task_cfg.name,
|
| 778 |
+
"difficulty": task_cfg.difficulty,
|
| 779 |
+
"description": f"Episode length: {task_cfg.episode_length} posts. Policy mode: {task_cfg.policy_mode.value}.",
|
| 780 |
+
"grader_id": task_cfg.grader_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
}
|
| 782 |
+
for task_cfg in TASKS.values()
|
| 783 |
]
|
| 784 |
|
| 785 |
@app.get("/graders", tags=["🛡️ Automated Benchmarking"])
|
| 786 |
+
async def list_graders_endpoint():
|
| 787 |
"""Returns the list of graders available in the environment for discovery."""
|
| 788 |
+
return _list_graders()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
|
| 790 |
@app.post("/evaluate", tags=["🧪 Interactive Lab"], summary="Test Model Logic (XAI Insight)")
|
| 791 |
async def evaluate_text(
|
|
|
|
| 851 |
next_state, reward, done, info = await env.step(req.action)
|
| 852 |
|
| 853 |
final_score = 0.0
|
| 854 |
+
grader_id = None
|
| 855 |
if done:
|
| 856 |
+
# The environment now uses the task-specific grader internally;
|
| 857 |
+
# the final score and grader_id are returned in ``info``.
|
| 858 |
+
final_score = info.get("score", 0.0)
|
| 859 |
+
grader_id = info.get("grader_id")
|
| 860 |
|
| 861 |
return {
|
| 862 |
"next_state": next_state,
|
| 863 |
"reward": reward,
|
| 864 |
"done": done,
|
| 865 |
"info": info,
|
| 866 |
+
"final_score": final_score,
|
| 867 |
+
"grader_id": grader_id,
|
| 868 |
}
|
| 869 |
|
| 870 |
except RuntimeError as e:
|
|
|
|
| 890 |
next_state, reward, done, info = await env.step(action)
|
| 891 |
|
| 892 |
final_score = 0.0
|
| 893 |
+
grader_id = None
|
| 894 |
if done:
|
| 895 |
+
# The environment now uses the task-specific grader internally
|
| 896 |
+
final_score = info.get("score", 0.0)
|
| 897 |
+
grader_id = info.get("grader_id")
|
| 898 |
|
| 899 |
return {
|
| 900 |
"prediction": action.value,
|
|
|
|
| 902 |
"reward": reward,
|
| 903 |
"done": done,
|
| 904 |
"final_score": final_score,
|
| 905 |
+
"grader_id": grader_id,
|
| 906 |
"next_state": next_state,
|
| 907 |
"info": info
|
| 908 |
}
|