"""
models.py
=========
Typed Pydantic data-models for the ContentModerationEnv OpenEnv spec.
These models are used for:
  • runtime validation of agent actions
  • serialisation of environment state
  • documentation (JSON Schema derivable via .model_json_schema())
Requires: pydantic >= 2.0
"""
from __future__ import annotations
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator, model_validator
# ── Enumerations ──────────────────────────────────────────────────────────────
class Label(str, Enum):
    """Content classification label.

    Mixes in ``str`` so each member is itself a string equal to its value
    (``Label.toxic == "toxic"``), and ``Label("toxic")`` looks a member up
    by that value.
    """

    safe = "safe"
    toxic = "toxic"
    spam = "spam"
    misleading = "misleading"
class ModerationAction(str, Enum):
    """Moderation action to apply to the content.

    Mixes in ``str`` so each member is itself a string equal to its value
    (``ModerationAction.warn == "warn"``).
    """

    allow = "allow"
    warn = "warn"
    remove = "remove"
    shadowban = "shadowban"
    escalate = "escalate"
class PlatformPolicy(str, Enum):
    """Enforcement level of the platform.

    Mixes in ``str`` so each member is itself a string equal to its value
    (``PlatformPolicy.strict == "strict"``).
    """

    strict = "strict"
    moderate = "moderate"
    lenient = "lenient"
class Platform(str, Enum):
    """Social media platform context.

    Mixes in ``str`` so each member is itself a string equal to its value
    (``Platform.reddit == "reddit"``).

    NOTE(review): ``Observation.platform`` is annotated ``Optional[str]``,
    not ``Optional[Platform]``, so this enum is not what validates that
    field - confirm whether that is intentional.
    """

    reddit = "reddit"
    twitter = "twitter"
    youtube = "youtube"
    linkedin = "linkedin"
class Tier(str, Enum):
    """Benchmark difficulty tier.

    Mixes in ``str`` so each member is itself a string equal to its value
    (``Tier.hard == "hard"``).
    """

    easy = "easy"
    medium = "medium"
    hard = "hard"
# ── Observation (state) ───────────────────────────────────────────────────────
class Observation(BaseModel):
    """
    The environment's observation returned by reset() / state().

    Attributes
    ----------
    text : str
        The user-generated text content to be reviewed.
    audio_transcript : str | None
        Transcript of any accompanying audio/video (None for text-only).
    visual_tags : list[str]
        Machine-detected visual content tags (empty list if no visual media).
    previous_flags : int
        Number of prior platform violations by this account (>= 0).
    platform_policy : PlatformPolicy
        Policy enforcement level the moderation decision must respect.
    platform : str | None
        Social media platform context: reddit / twitter / youtube / linkedin.
        None for legacy scenarios that predate the platform field.
    """

    text: str = Field(..., description="User-generated content text")
    audio_transcript: Optional[str] = Field(None, description="Audio/video transcript (nullable)")
    # default_factory gives every instance its own fresh list.
    visual_tags: List[str] = Field(default_factory=list, description="Detected visual content tags")
    previous_flags: int = Field(..., ge=0, description="Prior policy violations count")
    platform_policy: PlatformPolicy = Field(..., description="Platform enforcement level")
    # NOTE(review): typed as a free-form string, so values outside the four
    # documented platforms are not rejected here - confirm this is intended.
    platform: Optional[str] = Field(
        None,
        description="Platform: reddit/twitter/youtube/linkedin",
    )

    # Immutable - agents must not mutate state.
    model_config = {"frozen": True}
# ── Action ────────────────────────────────────────────────────────────────────
class AgentAction(BaseModel):
    """
    The action an agent submits via env.step().

    Required
    --------
    label : Label
    action : ModerationAction

    Optional
    --------
    severity : int in [1, 5]
        Scored only in the hard tier.
    rationale : str
        Brief free-text reasoning; not scored.
    """

    label: Label = Field(..., description="Content classification")
    action: ModerationAction = Field(..., description="Moderation action to apply")
    severity: Optional[int] = Field(None, ge=1, le=5, description="Severity 1-5 (hard tier only)")
    rationale: Optional[str] = Field(None, description="Brief reasoning (not scored)")

    @field_validator("severity", mode="before")
    @classmethod
    def coerce_severity(cls, v):
        """Accept string integers gracefully.

        Runs before the field's own validation. ``None`` passes through
        untouched; anything else is converted with ``int()`` so that values
        such as ``"3"`` are accepted. The ``ge``/``le`` bounds declared on the
        field are applied afterwards to the converted value.

        Raises
        ------
        ValueError
            If the value cannot be converted to an integer. ``TypeError`` and
            ``OverflowError`` from ``int()`` are re-raised as ``ValueError``
            so the failure surfaces as a normal validation error rather than
            escaping as a raw exception.
        """
        if v is None:
            return v
        try:
            return int(v)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(
                f"severity must be an integer in [1, 5]; got {v!r}"
            ) from exc

    def to_env_dict(self) -> dict:
        """Convert to the plain dict format expected by ContentModerationEnv.step().

        ``label`` and ``action`` are always present as their string values;
        ``severity`` and ``rationale`` are included only when they are set.
        """
        d: dict = {"label": self.label.value, "action": self.action.value}
        if self.severity is not None:
            d["severity"] = self.severity
        if self.rationale is not None:
            d["rationale"] = self.rationale
        return d
# ── Score breakdown ───────────────────────────────────────────────────────────
class ScoreBreakdown(BaseModel):
    """Per-component reward breakdown returned in step() info dict.

    Each component is either ``None`` (not set) or a float in [0.0, 1.0].
    """

    label_correct: Optional[float] = Field(None, ge=0.0, le=1.0)
    action_correct: Optional[float] = Field(None, ge=0.0, le=1.0)
    severity_within_1: Optional[float] = Field(None, ge=0.0, le=1.0)

    @property
    def total(self) -> float:
        """Sum of the components that are set; ``None`` components are skipped.

        The sum starts from ``0.0`` so the result is always a float, including
        when every component is ``None``.
        """
        components = (self.label_correct, self.action_correct, self.severity_within_1)
        return sum((c for c in components if c is not None), 0.0)
# ── Step result ───────────────────────────────────────────────────────────────
class GroundTruth(BaseModel):
    """Ground truth record stored in each scenario.

    ``label`` and ``action`` are required; ``severity`` (1-5) and
    ``rationale`` are optional and default to ``None``.
    """

    label: Label
    action: ModerationAction
    severity: Optional[int] = Field(None, ge=1, le=5)
    rationale: Optional[str] = None
# StepInfo is declared before StepResult so that the ``info: StepInfo``
# annotation resolves when StepResult is created, instead of depending on a
# deferred forward-reference rebuild.
class StepInfo(BaseModel):
    """Metadata returned inside StepResult.info."""

    scenario_id: str
    tier: Tier
    ground_truth: GroundTruth
    score_rubric: dict
    score_breakdown: ScoreBreakdown
    submitted_action: AgentAction
    warnings: List[str] = Field(default_factory=list)


class StepResult(BaseModel):
    """
    Full result returned by env.step().

    state  - next observation (next post in queue, or final post state)
    reward - [-0.3, 1.0] partial-credit score (penalties may go negative)
    done   - False until all queue posts processed; True after final step
    info   - breakdown, ground truth, submitted action, warnings
    """

    state: Observation
    reward: float = Field(..., ge=-0.3, le=1.0)
    done: bool
    info: StepInfo
# ── Scenario (internal) ───────────────────────────────────────────────────────
class Scenario(BaseModel):
    """
    One benchmark scenario as stored in moderation_benchmark.json.

    Used internally by the environment loader. Bundles the observation shown
    to the agent (``state``) with its ``ground_truth`` and ``score_rubric``.
    """

    id: str
    tier: Tier
    state: Observation
    ground_truth: GroundTruth
    score_rubric: dict
|